I don’t think fastai supports multi-label segmentation by default. SegmentationItemList (and SegmentationLabelList) assumes a single label per pixel, so a single channel with each pixel encoded as a value from 0..n_classes-1 rather than n_classes channels with pixels being 0 or 1.
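To make the two encodings concrete, here is a minimal sketch in plain PyTorch (no fastai needed). The tiny 2x2 mask and the names index_mask and channel_mask are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Single-label encoding: one channel, each pixel holds a class index in 0..n_classes-1.
n_classes = 3
index_mask = torch.tensor([[0, 1],
                           [2, 1]])            # shape (H, W), dtype torch.long

# Multi-channel encoding: one binary channel per class, shape (n_classes, H, W).
# one_hot puts the class axis last, so move it to the front.
channel_mask = F.one_hot(index_mask, num_classes=n_classes).permute(2, 0, 1)

# Multi-label masks may switch on several channels at the same pixel,
# which a single class index per pixel cannot express.
channel_mask[2, 0, 1] = 1   # pixel (0, 1) now belongs to class 1 and class 2
```

Note the one-hot version still has a channel for index 0, whereas the multi-label setup described here simply leaves all channels at 0 for background pixels.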
I’ve just created a custom subclass which seems to work for training (not fully verified due to the issues noted below, but the loss decreases while a modified dice metric increases, as expected). I’m using RLE-encoded masks, but the key code, adapted off the top of my head for image masks so errors are likely, is:
    from fastai.vision import *  # assumes fastai v1; also brings in F and Iterator

    def bce_logits_floatify(input, target, reduction='mean'):
        # BCE-with-logits wants a float target; the masks arrive as torch.long
        return F.binary_cross_entropy_with_logits(input, target.float(), reduction=reduction)

    class SegmentationMultiLabelList(SegmentationLabelList):
        def __init__(self, items:Iterator, classes, **kwargs):
            super().__init__(items, classes, **kwargs)
            # Replace the default cross-entropy loss set by SegmentationLabelList
            self.loss_func = bce_logits_floatify

    src = (SegmentationItemList
           .from_df(image_df, path_img)
           .split_by_rand_pct(valid_pct=0.1, seed=42)
           .label_from_func(get_y_fn,
                            label_cls=SegmentationMultiLabelList,
                            classes=['1','2','3','4']))
(I’m assuming your class ‘0’ was for background rather than an actual class. If it was an actual class, add it back in, but note that you don’t have a background class with this approach.)
Note you may not need bce_logits_floatify (you could use the torch function directly instead), or you may need to play around with data types. My RLE masks were being produced as torch.uint8, but were then converted through augmentation and by ImageSegment, ending up as torch.long (I want to look at avoiding long, as this increases the size quite a lot). I suspect your masks should also end up as ImageSegments (through the SegmentationLabelList.open that the multi-label one should inherit), but I could be wrong.
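A quick way to see the data type points above in plain PyTorch. This is only a sketch with random tensors standing in for real model outputs and masks, and the shapes are assumptions:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 8, 8)   # (batch, n_classes, H, W) raw model outputs
mask_u8 = torch.randint(0, 2, (1, 4, 8, 8), dtype=torch.uint8)
mask_long = mask_u8.long()

# Bytes per element: uint8 takes 1 byte, long (int64) takes 8.
print(mask_u8.element_size(), mask_long.element_size())

# binary_cross_entropy_with_logits expects a floating-point target,
# so integer masks are cast at the point of use.
loss = F.binary_cross_entropy_with_logits(logits, mask_u8.float())
```

Casting inside the loss, as bce_logits_floatify does, means the masks can stay in a compact integer type until the last moment.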
While this seems to work for training, it will cause issues for evaluation, as some parts of the display code don’t like the multi-channel outputs produced. I’m about to look at this. Also note that, as a sigmoid is applied in the loss function, you likely want to apply one to the outputs for evaluation. An alternative approach would be to add a sigmoid at the end of your model. The y_range option to the fastai Unet can do this (though it adds a little extra overhead over a simple sigmoid, as it also does some extra operations for the custom range).
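As a rough sketch of the evaluation side, this is how the sigmoid could be applied to raw outputs and thresholded into per-class masks. predict_masks and scaled_sigmoid are hypothetical helper names, not fastai functions, the 0.5 threshold is an assumption, and scaled_sigmoid only reflects my understanding of the kind of rescaling a custom range involves:

```python
import torch

def predict_masks(logits, threshold=0.5):
    """Turn raw per-class logits into independent binary masks."""
    probs = torch.sigmoid(logits)   # the same sigmoid BCE-with-logits applies internally
    return (probs > threshold).to(torch.uint8)

def scaled_sigmoid(logits, low, high):
    """Sigmoid rescaled to (low, high); with (0, 1) it is just a plain sigmoid."""
    return torch.sigmoid(logits) * (high - low) + low

logits = torch.tensor([[[[-2.0, 3.0]], [[0.5, -0.5]]]])   # (batch=1, n_classes=2, H=1, W=2)
masks = predict_masks(logits)
```

If the sigmoid is moved into the model instead, the loss would need to change to a version that takes probabilities rather than logits, otherwise the sigmoid gets applied twice.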
Also, as alluded to above, the fastai dice metric won’t work with the multi-channel outputs, so you need to use a modified version (if interested I can provide mine, but I added separate positive and negative dice scores so needed a callback, which is a little messier than I’d like and could do with more testing).
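For anyone wanting a starting point, here is one possible multi-channel dice in plain PyTorch. To be clear, this is not the callback version mentioned above (it has no separate positive and negative scores). The name multilabel_dice, the 0.5 threshold and the eps smoothing are all assumptions:

```python
import torch

def multilabel_dice(logits, target, threshold=0.5, eps=1e-8):
    """Mean per-channel dice for (batch, n_classes, H, W) logits and 0/1 targets."""
    pred = (torch.sigmoid(logits) > threshold).float()
    target = target.float()
    dims = (0, 2, 3)   # sum over batch and spatial dims, keep the class axis
    intersection = (pred * target).sum(dim=dims)
    total = pred.sum(dim=dims) + target.sum(dim=dims)
    # eps makes a channel that is empty in both pred and target score 1
    dice_per_class = (2 * intersection + eps) / (total + eps)
    return dice_per_class.mean()
```

Because it takes (input, target) and returns a scalar tensor, a function like this should be usable as a plain metric function, though I haven’t checked that against the display issues noted above.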