One-Hot Encoded Segmentation (Multi-label)

TL;DR: What is the best way to output one-hot segmentation masks for multi-class / multi-label segmentation?

A few quick definitions, just to make sure we're on the same page:

I have a multi-class (more than one class: I have 3 classes + background), multi-label (each pixel can belong to several classes at once) segmentation problem.

The images are RGB and the masks are one-hot encoded PNGs.
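To make the terminology concrete, here is a minimal NumPy sketch of what "multi-label one-hot mask" means here. It assumes one binary channel per class (the three classes stored in the R, G and B channels of the PNG) and treats background as "no class channel set"; the array size and class placements are made up for illustration.

```python
import numpy as np

# Hypothetical 4x4 example with 3 foreground classes; each class gets its own
# binary channel, so a pixel may be "on" in several channels at once.
H, W, N_CLASSES = 4, 4, 3
one_hot = np.zeros((H, W, N_CLASSES), dtype=np.uint8)
one_hot[0:2, 0:2, 0] = 1   # class 1 occupies the top-left 2x2 block
one_hot[1:3, 1:3, 1] = 1   # class 2 overlaps class 1 at pixel (1, 1)
one_hot[3, 3, 2] = 1       # class 3 at the bottom-right corner

# Background is implicit: a pixel is background when no class channel is set.
background = (one_hot.sum(axis=-1) == 0).astype(np.uint8)

# Saved as an RGB PNG, each channel would typically hold 0 or 255.
rgb_png_array = one_hot * 255

print(one_hot[1, 1])       # [1 1 0]  -> pixel belongs to classes 1 and 2
print(background.sum())    # 8
```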

The problem:

Passing the mask through `MaskBlock` flattens it to a single channel, so a batch of masks comes out with shape (BS, H, W).

That works when each pixel has only one possible value, but because my pixels can have several, the combinations get turned into new classes (e.g. a pixel that belongs to classes 1 and 2 comes out as a new class 4, etc.).
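One way to see why combinations turn into "new classes": `PILMask.create` opens masks in mode 'L', and Pillow documents its RGB to 'L' conversion as the ITU-R 601-2 luma transform, L = R * 299/1000 + G * 587/1000 + B * 114/1000. The NumPy sketch below applies that formula to every combination of the three class channels, assuming each channel is stored as 0 or 255 (with 0/1 storage most values would round to 0 or 1 and collide instead). It is a float approximation; Pillow's own integer rounding may differ slightly.

```python
import itertools
import numpy as np

# Luma weights Pillow documents for RGB -> "L" (ITU-R 601-2).
# This float version only approximates Pillow's integer arithmetic.
WEIGHTS = np.array([299, 587, 114]) / 1000

# Every on/off combination of the three class channels, assuming 0/255 storage.
combos = np.array(list(itertools.product([0, 255], repeat=3)))
gray = np.rint(combos @ WEIGHTS).astype(int)

# Each combination lands on its own gray level, i.e. looks like a separate class.
for rgb, g in zip(combos.tolist(), gray.tolist()):
    print(rgb, "->", g)
```

A pixel in, say, the R and G classes ends up with a single gray value that matches neither the R-only nor the G-only value, so the per-class information is lost once the mask is a single channel.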

I have tried changing the mask block so that `PILMask.create` opens the masks in mode 'RGB' instead of 'L':

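```python
from functools import partial
from fastai.vision.all import *

# Open masks as 3-channel RGB instead of PILMask's default single-channel 'L' mode
RGBmask = partial(PILMask.create, mode='RGB')

def MaskBlock2(codes=None):
    # Same as fastai's MaskBlock, but with the RGB-loading type transform
    return TransformBlock(type_tfms=RGBmask, item_tfms=AddMaskCodes(codes=codes), batch_tfms=IntToFloatTensor)

# get_y_fn and batch_tfms are defined elsewhere in my code
dblock = DataBlock(blocks=(ImageBlock, MaskBlock2),
                   get_items=get_image_files,
                   get_y=get_y_fn,
                   batch_tfms=batch_tfms)
```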

But this doesn't seem to make any difference to the `dls` output.
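A possible explanation, if I am reading fastai's `vision/core.py` correctly: `ToTensor` converts a `PILMask` with `image2tensor(o)[0]`, i.e. it keeps only the first channel, so loading in 'RGB' mode would still end in a single-channel `TensorMask`. Whatever the wiring ends up being (a custom type transform or a plain PyTorch `Dataset`), the conversion actually needed is roughly the NumPy sketch below. `rgb_mask_to_channels` is a hypothetical helper, and the threshold of 127 assumes the PNG channels hold 0 or 255.

```python
import numpy as np

def rgb_mask_to_channels(rgb: np.ndarray, threshold: int = 127) -> np.ndarray:
    """Turn an H x W x 3 one-hot RGB mask into a (1 + 3) x H x W float32 array.

    Hypothetical helper: channel 0 is background (no class set), channels 1-3
    are the R, G and B class channels, binarised with `threshold`.
    """
    classes = (rgb > threshold).transpose(2, 0, 1)      # 3 x H x W booleans
    background = ~classes.any(axis=0, keepdims=True)    # 1 x H x W
    return np.concatenate([background, classes]).astype(np.float32)

# Usage: a 1 x 2 image whose second pixel belongs to classes 1 and 2.
rgb = np.array([[[0, 0, 0], [255, 255, 0]]], dtype=np.uint8)
channels = rgb_mask_to_channels(rgb)
print(channels.shape)      # (4, 1, 2)
print(channels[:, 0, 1])   # [0. 1. 1. 0.]
```

Keeping background as an explicit channel is a design choice; with a per-channel sigmoid loss it can also be dropped and inferred as "no class active".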

I have a workaround using PyTorch datasets, but I have put that in a separate question.

My impression is that `PILMask` REALLY doesn't want to be one-hot encoded (or to have anything other than one channel).

The Question(s):

  1. Am I missing an easy way to one-hot encode segmentation masks with DataBlocks?

  2. What is the best way to load them so that I keep all the extra features of the fastai DataLoaders?
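For question 2, one possible workaround (a sketch, not something fastai provides, and untested with the DataBlock API) is to avoid multi-channel masks entirely: store each pixel's set of labels as a bitmask in a single-channel PNG, which the standard `MaskBlock` can load in 'L' mode unchanged, and unpack the bits into one binary channel per class only in the loss or a final step. The bit layout and the `pack_bits` / `unpack_bits` helpers are hypothetical. With 3 classes the values stay in 0..7, and the scheme only survives augmentation if every spatial transform resamples masks with nearest-neighbour interpolation, which should be checked.

```python
import numpy as np

N_CLASSES = 3  # three foreground classes, as in this problem

def pack_bits(one_hot: np.ndarray) -> np.ndarray:
    """H x W x C binary mask -> H x W uint8, with class c stored in bit c."""
    bits = (1 << np.arange(one_hot.shape[-1])).astype(np.uint8)
    return (one_hot.astype(np.uint8) * bits).sum(axis=-1).astype(np.uint8)

def unpack_bits(packed: np.ndarray, n_classes: int = N_CLASSES) -> np.ndarray:
    """H x W integer mask -> C x H x W float32 multi-label target."""
    bits = 1 << np.arange(n_classes)
    return ((packed[None] & bits[:, None, None]) > 0).astype(np.float32)

# Usage: a pixel in classes 1 and 2, and a pixel in class 3 only.
one_hot = np.zeros((2, 2, N_CLASSES), dtype=np.uint8)
one_hot[0, 0] = [1, 1, 0]
one_hot[1, 1] = [0, 0, 1]
packed = pack_bits(one_hot)
print(packed)              # [[3 0]
                           #  [0 4]]
restored = unpack_bits(packed)
print(np.array_equal(restored, one_hot.transpose(2, 0, 1)))   # True
```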

(sorry it’s a big question!)