One Hot Encoded segmentation (Multi-label)

TL;DR: What’s the best way to output one-hot segmentation masks for a multi-class/multi-label problem?

A quick few definitions just to make sure we are on the same page!

I have a multi-class (more than 1 class: I have 3 classes + background), multi-label (each pixel can be in multiple classes at once) segmentation problem.

The images are RGB and the masks are one-hot encoded PNGs.

The problem:

Passing the mask through the MaskBlock flattens it to a single channel, so the batch comes out with shape (BS, H, W).

That works when each pixel has only one possible class, but since mine can have several, the overlaps get resolved by adding new classes (e.g. a pixel in classes 1 and 2 comes out as a new class 4, etc.).
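A minimal sketch of what is probably going on, assuming the masks are RGB PNGs with 0/255 per channel and that the default PILMask.create opens them in mode 'L': Pillow’s `convert("L")` collapses R, G and B into one weighted luma value, so a pixel that belongs to two classes ends up with a value that neither class has on its own.

```python
import numpy as np
from PIL import Image

# Three pixels: class 0 only (R), class 1 only (G), and classes 0 and 1 together (R+G)
rgb = np.array([[[255, 0, 0], [0, 255, 0], [255, 255, 0]]], dtype=np.uint8)  # (H=1, W=3, C=3)

# 'L' conversion is a weighted sum of R, G and B, the same thing that happens
# when a mask PNG is opened in PILMask's default single-channel mode
gray = np.asarray(Image.fromarray(rgb).convert("L"))

print(gray)  # one value per pixel; the overlapping pixel gets its own "new class"
```

If the masks store 0/1 instead of 0/255, the weighted sum can also round some classes away entirely, which is worse than creating new values.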

I tried changing the MaskBlock so that PILMask.create opens the mask with mode ‘RGB’ instead of ‘L’:

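from fastai.vision.all import *

# Open the mask PNG in 'RGB' mode instead of PILMask's default single-channel 'L'
RGBmask = partial(PILMask.create, mode='RGB')

def MaskBlock2(codes=None):
    # Same as fastai's MaskBlock, but with the RGB loader as the type transform
    return TransformBlock(type_tfms=RGBmask, item_tfms=AddMaskCodes(codes=codes), batch_tfms=IntToFloatTensor)

dblock = DataBlock(blocks=(ImageBlock, MaskBlock2),
                   get_items=get_image_files,
                   get_y=get_y_fn)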

But this doesn’t seem to make any difference to the dls output.
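A likely reason, going by the fastai source (worth checking against your installed version): ToTensor has a dedicated encoder for PILMask that returns `image2tensor(o)[0]`, i.e. it keeps only the first channel of whatever was loaded. So opening the PNG as RGB produces a 3-channel image that gets cut back to one channel before it reaches the batch. The numpy sketch below only mimics that indexing to show what is lost; it does not call fastai.

```python
import numpy as np

# A 2x2 one-hot mask, channels-last as PIL/numpy give it: (H, W, C)
mask_hwc = np.zeros((2, 2, 3), dtype=np.uint8)
mask_hwc[0, 0] = [1, 1, 0]  # pixel in classes 0 and 1
mask_hwc[1, 1] = [0, 0, 1]  # pixel in class 2

# image2tensor-style reorder to channels-first: (C, H, W)
mask_chw = mask_hwc.transpose(2, 0, 1)

# Keeping only index 0 (what the PILMask encoder does) silently drops classes 1 and 2
kept = mask_chw[0]
print(kept.shape)          # (2, 2)
print(mask_chw[1:].sum())  # 2 labelled pixels that never reach the batch
```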

I’ve made a workaround using PyTorch datasets, but I’ve put that in another question.

My impression is that PILMasks REALLY don’t want to be one-hot encoded (or to have anything other than 1 channel).

The Question(s):

  1. Am I missing an easy way of one-hot encoding segmentation masks with DataBlocks?

  2. What is the best way of loading them so I can keep all the extra features of the fastai DataLoaders?

(sorry it’s a big question!)
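One workaround that might be worth trying (a sketch, not a built-in fastai feature): store each multi-label mask as a single-channel bitmask PNG, where bit c of a pixel is set if the pixel belongs to class c and background is simply "no bits set". An 8-bit channel holds up to 8 classes, so the file should load through the default ‘L’-mode PILMask.create without any channel mixing, and the one-hot channels can be recovered later, for example in a custom batch transform. `pack_bits` and `unpack_bits` are hypothetical helper names.

```python
import numpy as np

def pack_bits(onehot: np.ndarray) -> np.ndarray:
    """(C, H, W) binary mask -> (H, W) uint8 where bit c is set for class c (C <= 8)."""
    n_classes = onehot.shape[0]
    assert n_classes <= 8, "a uint8 channel holds at most 8 class bits"
    weights = (1 << np.arange(n_classes)).reshape(-1, 1, 1)  # 1, 2, 4, ...
    return (onehot.astype(np.uint8) * weights).sum(axis=0).astype(np.uint8)

def unpack_bits(packed: np.ndarray, n_classes: int) -> np.ndarray:
    """(H, W) bitmask -> (C, H, W) binary mask; overlapping classes survive."""
    shifts = np.arange(n_classes).reshape(-1, 1, 1)
    return ((packed[None] >> shifts) & 1).astype(np.uint8)

# 3 classes + background; pixel (0, 0) is in classes 0 and 2, pixel (1, 1) in class 1
onehot = np.zeros((3, 2, 2), dtype=np.uint8)
onehot[0, 0, 0] = 1
onehot[2, 0, 0] = 1
onehot[1, 1, 1] = 1

packed = pack_bits(onehot)  # packed[0, 0] == 5 (bits 0 and 2), packed[1, 1] == 2
```

Converting the existing RGB PNGs would be a one-off preprocessing step (read, threshold each channel, `pack_bits`, save as an ‘L’ PNG). Resizing and augmentation must keep nearest-neighbour interpolation on the masks, since any blending between values would corrupt the bits.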

Hello, I’m in much the same situation you were in, and I’ve resorted to a custom transform that loads the mask as an n-channel image.
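For reference, the loading half of such a transform can be quite small. The sketch below reads the PNG with Pillow without converting it to ‘L’, moves channels first, and binarises. It assumes one PNG colour channel per class (so at most 3 classes from an RGB file) with any nonzero value meaning "present"; `load_onehot_mask` is a hypothetical name, and wrapping the result for fastai (e.g. as a TensorMask inside a custom Transform) is not shown, so the example depends only on numpy and Pillow.

```python
import os
import tempfile

import numpy as np
from PIL import Image

def load_onehot_mask(path, n_classes=3):
    """Read a one-channel-per-class PNG as a (C, H, W) binary uint8 array."""
    # Keep the colour channels; converting to 'L' would mix them together
    arr = np.asarray(Image.open(path).convert("RGB"))  # (H, W, 3)
    arr = arr[..., :n_classes].transpose(2, 0, 1)      # (C, H, W), channels first
    return (arr > 0).astype(np.uint8)                  # works for 0/255 or 0/1 masks

# Tiny 2x2 mask: pixel (0, 0) is in classes 0 and 1, pixel (1, 1) in class 2
hwc = np.zeros((2, 2, 3), dtype=np.uint8)
hwc[0, 0] = [255, 255, 0]
hwc[1, 1] = [0, 0, 255]
path = os.path.join(tempfile.mkdtemp(), "mask.png")
Image.fromarray(hwc).save(path)  # PNG is lossless, so the values round-trip

mask = load_onehot_mask(path)
print(mask.shape)  # (3, 2, 2)
```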

I wonder if you found a different solution. This pipeline seems pretty good at loading the data, but it doesn’t let me use the native unet_learner function. I’ve also found a way to build my own Learner using DynamicUnet, but that learner seems to diverge pretty fast.
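One thing that might be worth checking with a hand-built DynamicUnet learner on multi-label masks is the loss. When classes can overlap, each channel is its own yes/no problem, so the usual fit is a per-pixel, per-channel sigmoid binary cross-entropy (what `torch.nn.BCEWithLogitsLoss` computes) on float one-hot targets of shape (N, C, H, W), rather than a softmax cross-entropy that assumes exactly one class per pixel. The numpy sketch below spells out that loss in its numerically stable form; it is an illustration of the technique, not a diagnosis of the divergence described above.

```python
import numpy as np

def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean sigmoid BCE over all elements; logits and targets are both (N, C, H, W)."""
    x = logits.astype(np.float64)
    y = targets.astype(np.float64)
    # Stable rewrite of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
    loss = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return float(loss.mean())

# Batch of 1, 3 classes, 1x1 image; the pixel is in classes 0 and 1 but not 2
targets = np.array([1, 1, 0], dtype=np.float32).reshape(1, 3, 1, 1)
good = np.array([4.0, 4.0, -4.0]).reshape(1, 3, 1, 1)  # confident and correct
bad = -good                                            # confident and wrong

print(bce_with_logits(good, targets) < bce_with_logits(bad, targets))  # True
```

Because each channel gets its own sigmoid, channels 0 and 1 can both be predicted "on" for the same pixel, which a softmax across channels cannot express.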

Hi, I’m also stuck on the same issue. I tried converting PyTorch dataloaders to fastai DataLoaders first, but then I lose all the fastai magic. Did you guys find a solution?