How to create segmentation masks on the fly?

I am trying to do pixel-wise segmentation of text in images. To do that, I take images without text and synthetically add text to them with random transforms. I want to generate the segmentation mask from the text I added, without having to generate x examples and store them on disk.
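One way to avoid storing masks on disk is to produce the mask in the same call that draws the text, instead of recovering it afterwards. The sketch below assumes NumPy images of shape (H, W, 3). `add_synthetic_text` is a hypothetical stand-in that paints random bars rather than real glyphs; a real version would render text (for example with PIL's `ImageDraw.text`) onto a separate layer and use that layer as the mask.

```python
import numpy as np

def add_synthetic_text(image, rng, n_strokes=3, color=(255, 255, 255)):
    """Hypothetical stand-in for a text renderer.

    Paints a few random horizontal bars onto a copy of `image` and returns
    (edited, mask). The mask is built *before* compositing, so it is exact
    even where the paint colour happens to match the underlying pixel.
    """
    h, w, _ = image.shape
    mask = np.zeros((h, w), dtype=bool)
    for _ in range(n_strokes):
        y0 = rng.integers(0, h - 2)
        x0 = rng.integers(0, w // 2)
        mask[y0:y0 + 2, x0:x0 + w // 4] = True  # one 2-pixel-high "stroke"
    edited = image.copy()
    edited[mask] = color                        # composite the "text" layer
    return edited, mask.astype(np.int64)        # 0 = background, 1 = text

rng = np.random.default_rng(42)
clean = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
edited, mask = add_synthetic_text(clean, rng)
```

Because the mask and the edited image come out of the same call, they can be generated per item inside the data pipeline and never touch the disk.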

My code:

from functools import partial
from pathlib import Path
import torch
from fastai.vision import *  # fastai v1: ItemBase, SegmentationItemList, imagenet_stats, ...

class DanbooruImage(ItemBase):
    def __init__(self, image, fileDir, idx):
        self.idx = idx
        self.image = image
        self.fileDir = fileDir
    def __str__(self): return str(self.image)

    def apply_tfms(self, tfms, **kwargs):
        # each tfm (e.g. textify) mutates self, filling x_tensor / y_tensor
        for tfm in tfms:
            tfm(self, **kwargs)
        return self

class DanbooruSegmentationList(SegmentationItemList):
    @classmethod
    def from_textInfo(cls, textInfo: dict, maxItems=10, maxArea=0, **kwargs):
        # keep up to maxItems files whose textInfo value (text area) is <= maxArea
        gen = filter(lambda k: textInfo[k] <= maxArea, textInfo.keys())
        items = [x for _, x in zip(range(maxItems), gen)]
        return cls(items, **kwargs)

    def get(self, i):
        fileDir = self.items[i]
        image = self.open(fileDir)
        return DanbooruImage(image, fileDir, i)

    def open(self, fileDir):
        # assumption: the original body was not shown; this defers to
        # ImageList.open, which loads the file with open_image
        return super().open(fileDir)

def getSegmentationMask(dan):
    y, x = dan[0].y_tensor, dan[0].x_tensor
    # 1 where textify changed any channel (text), 0 elsewhere (background),
    # matching classes=['background', 'text']
    res = (x != y).any(dim=0)
    return res.unsqueeze(0)  # shape (1, H, W)

def custom_collate(batch):
    # inputs: text-edited images; targets: masks computed on the fly
    xs = torch.stack([item[0].x_tensor for item in batch])
    ys = torch.stack([getSegmentationMask(item) for item in batch]).long()
    return xs, ys

fonts = Fonts.load(Path('fonts'))
train_fonts, valid_fonts = Fonts(fonts[len(fonts)//10:]), Fonts(fonts[0:len(fonts)//10])

ds = DanbooruSegmentationList.from_textInfo(text_info, maxItems=1000).split_by_rand_pct(0.1, seed=42).label_empty()
data = ds.transform(([partial(textify, fonts=train_fonts)], [partial(textify, fonts=valid_fonts)]))
# custom_collate must already be defined at this point
data = data.databunch(bs=64, collate_fn=custom_collate).normalize(imagenet_stats)
data.c = 2

Note that textify is a function that takes a DanbooruImage and a set of fonts and draws text on it, storing the text-edited version in x_tensor and the original version in y_tensor.
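The diff idea behind getSegmentationMask can be checked in isolation. This NumPy sketch (the function name is illustrative, not from fastai) compares two channel-first images and labels a pixel as class 1 (text) if any channel changed, so the result lines up with classes=['background', 'text']. One caveat of any diff-based mask: a text pixel whose colour exactly equals the pixel underneath it is labelled background.

```python
import numpy as np

def segmentation_mask_from_diff(edited, original):
    """Recover a text mask from two channel-first images of shape (C, H, W).

    Returns a single channel of class indices, shape (1, H, W), dtype int64:
    1 where any channel differs (text), 0 elsewhere (background).
    """
    changed = (edited != original).any(axis=0)  # (H, W) bool
    return changed[None].astype(np.int64)       # (1, H, W)

original = np.zeros((3, 4, 4), dtype=np.float32)
edited = original.copy()
edited[:, 1, 2] = 1.0   # a pixel where every channel changed
edited[0, 3, 0] = 0.5   # a pixel where only the first channel changed
mask = segmentation_mask_from_diff(edited, original)
```

Using `any` over the channel axis avoids adding boolean comparisons together, whose behaviour depends on the tensor library and dtype.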

data.show_batch appears to be working:


But then trying to train a U-Net gives an error:

learn = unet_learner(data, models.resnet18) 

Am I doing something wrong in custom_collate, or is there another issue?


Hard to tell how you're returning masks, but for fastai they should be single-channel, with values giving the class (i.e. a pixel value of 0 is the background class and 1 the foreground). Also note that there's a lot of special handling of masks so that transforms are applied to them properly. You are best off using the existing ImageSegment (or a subclass) to inherit this handling; look at the code around it. You can construct an ImageSegment from an array/tensor of data. In particular, you might look at the RLE mask handling, as it generates ImageSegments dynamically.
You are also probably better off using the standard label_from_* functions (probably label_from_func here) to do the labelling rather than a custom collate function, unless I'm missing something that makes using them hard.
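The mask layout described above (one channel of integer class indices) can be checked with a small helper before a batch reaches the learner. This is a hypothetical NumPy sketch, not part of fastai; it only validates shape, dtype and value range, not the transform handling that ImageSegment provides.

```python
import numpy as np

def check_mask(mask, n_classes):
    """Raise ValueError unless `mask` is a (1, H, W) array of integer
    class indices in [0, n_classes)."""
    if mask.ndim != 3 or mask.shape[0] != 1:
        raise ValueError(f"expected shape (1, H, W), got {mask.shape}")
    if not np.issubdtype(mask.dtype, np.integer):
        raise ValueError(f"expected integer class indices, got {mask.dtype}")
    if mask.min() < 0 or mask.max() >= n_classes:
        raise ValueError("class index out of range")

# a single-channel index mask passes; an RGB-style or float mask would not
good = np.zeros((1, 64, 64), dtype=np.int64)
good[0, 10:20, 5:40] = 1
check_mask(good, n_classes=2)
```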

Yeah, res.shape is [1, 64, 64], the values are either 0 or 1, and the dtype is long, which is the same shape I got with the standard segmentation approach. The problem is that I don't know my mask until the image transforms (textify) have run, so I can't use label_from_func: my images have no text initially, so without the transforms every mask would be all zeros. I could try loading these “empty masks” instead of using label_empty, but I would still need to override them later. I don't use any standard fastai transform, so the special handling isn't an issue; I know what the mask should be after the transforms. I'll try creating empty masks instead of using label_empty.

Update: I replaced label_empty with

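class CustomLabel(SegmentationLabelList):
    # every label is an all-zero (background) 64x64 placeholder mask;
    # the real mask is computed later in custom_collate, after textify runs
    def open(self, fn):
        return ImageSegment(torch.zeros(1, 64, 64))

DanbooruSegmentationList._label_cls = CustomLabel

# replaces .label_empty() in the chain; 'a' is just a dummy constant item
ds = (DanbooruSegmentationList.from_textInfo(text_info, maxItems=1000)
      .split_by_rand_pct(0.1, seed=42)
      .label_const('a', classes=['background', 'text']))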
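The pattern in this update (a placeholder target that is discarded and replaced at batch time) can be illustrated without fastai. In this hypothetical NumPy analogue of custom_collate, each item is (edited, original, placeholder); as the update suggests, the placeholder appears to matter only so the pipeline sees a target of the right shape and classes, and the collate function recomputes the real mask from the image pair.

```python
import numpy as np

def collate_with_masks(batch):
    """Hypothetical NumPy analogue of custom_collate.

    Each item is (edited, original, placeholder), with images of shape
    (C, H, W). The placeholder label is ignored; the target mask is
    recomputed as 1 wherever any channel of `edited` differs.
    """
    xs = np.stack([edited for edited, _, _ in batch])
    ys = np.stack([
        (edited != original).any(axis=0)[None].astype(np.int64)
        for edited, original, _ in batch
    ])
    return xs, ys  # shapes (B, C, H, W) and (B, 1, H, W)

rng = np.random.default_rng(0)
batch = []
for _ in range(4):
    original = rng.random((3, 8, 8)).astype(np.float32)  # values in [0, 1)
    edited = original.copy()
    edited[:, 2:4, 1:6] = 1.0                # a 2x5 "text" region
    placeholder = np.zeros((1, 8, 8), dtype=np.int64)
    batch.append((edited, original, placeholder))

xs, ys = collate_with_masks(batch)
```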

And now everything seems to work! I think unet_learner needed to know the size of the segmentation mask to be built properly. @sgugger is this correct? I'm also not sure whether I need any other change to do progressive resizing.