Unet - Segmentation Label Size

Hi,
For a multi-class, multi-label segmentation problem, I find that the size of the label y is bs x 1 x h x w. The labels contain several kinds of masks corresponding to their categories.
Shouldn't the size of y be bs x nc x h x w, matching the output size of bs x nc x h x w, just like in plain (non-segmentation) multi-label classification, where the label is bs x nc, the same as the output?

def no_tfms(self, x, **kwargs): return x  # skip transforms for EmptyLabel (the test set has no masks)
EmptyLabel.apply_tfms = no_tfms

src = (SegmentationItemList.from_df(image_df, path_img)
       .split_by_rand_pct(valid_pct=0.1, seed=42)
       .label_from_func(get_y_fn, classes=['0', '1', '2', '3', '4'])
       .add_test_folder('../test_images'))

data = (src.transform(get_transforms(flip_vert=True, max_rotate=360, p_affine=0.80),
                      size=size, tfm_y=True, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=16)
        .normalize(imagenet_stats))

I don’t think fastai supports multi-label segmentation by default. SegmentationItemList (and SegmentationLabelList) assumes a single label per pixel: a single channel with each pixel encoded as a value in 0..n_classes-1, rather than n_classes channels with pixels being 0 or 1.
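To make the contrast concrete, here is a minimal sketch of the two encodings (the shapes and values are invented for illustration):

import torch

# single-label encoding (fastai default): one channel, pixel value = class index
y_single = torch.tensor([[[0, 1],
                          [2, 2]]])  # 1 x h x w, values in 0..n_classes-1

# multi-label encoding: one 0/1 channel per class
y_multi = torch.stack([(y_single[0] == c).long() for c in range(3)])  # n_classes x h x w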

I’ve just created a custom subclass which seems to work for training (not fully verified due to the issues noted below, but the loss decreases while a modified dice metric increases, as expected). I’m using RLE-encoded masks, but the key code, adapted off the top of my head for image masks (so errors are likely), is:

import torch.nn.functional as F

def bce_logits_floatify(input, target, reduction='mean'):
    # BCE with logits needs float targets; mask pixels often arrive as integer types
    return F.binary_cross_entropy_with_logits(input, target.float(), reduction=reduction)

class SegmentationMultiLabelList(SegmentationLabelList):
    def __init__(self, items:Iterator, classes, **kwargs):
        super().__init__(items, classes, **kwargs)
        self.loss_func = bce_logits_floatify  # default loss for learners built on this label list

src = (SegmentationItemList
       .from_df(image_df, path_img)
       .split_by_rand_pct(valid_pct=0.1, seed=42)
       .label_from_func(get_y_fn,
                        label_cls=SegmentationMultiLabelList,
                        classes=['1', '2', '3', '4']))

(Here I’m assuming your class ‘0’ was for background rather than an actual class. If it was an actual class, add it back in; but note you don’t have a background class with this approach.)

Note you may not need bce_logits_floatify (you could use the torch function directly), or you may need to play around with data types. My RLE masks were being produced as torch.uint8, though they were then converted through augmentation and by ImageSegment, ending up as torch.long (I want to look at avoiding long, as it increases the size quite a lot). I suspect your masks should also end up as ImageSegments (through the SegmentationLabelList.open that the multi-label one should inherit), but I could be wrong.

While this seems to work for training, it will cause issues for evaluation, as some parts of the display code don’t like the multi-channel outputs produced. I’m about to look at this. Also note that since a sigmoid is applied inside the loss function, you likely want to apply one to the outputs for evaluation. An alternative approach is to add a sigmoid at the end of your model; the y_range option of the fastai Unet can do this (though it adds a little overhead over a plain sigmoid, as it also does some extra operations for a custom range).
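For instance, something along these lines should bake the sigmoid into the model (a sketch, assuming a resnet34 backbone; with y_range=(0., 1.) the range scaling reduces to a plain sigmoid, modulo the extra multiply/add mentioned above):

learn = unet_learner(data, models.resnet34, y_range=(0., 1.))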

Also, as alluded to above, the fastai dice metric won’t work with multi-channel outputs, so you need to use a modified version (if interested I can provide mine, but I added separate positive and negative dice scores, which needed a callback; it’s a little messier than I’d like and could do with more testing).


Hi Tom, thank you…
Please help me understand the points below:

  1. What purpose does a loss function serve inside a label class that, if I’m not wrong, is merely used for reading in the labels? I mean, I’m unable to understand how this would help produce an nc x h x w label output.

  2. The mask labels I have are stacked, meaning that if an image had two categories of masks, those were stacked together and, when viewed, appear as a single picture…

  3. I tried to build a modified dice function for the multi-label case, but it is numerically unstable and doesn’t yield good results. What is the mistake below?

     def dice_loss(input, target):
         n, c = target.shape[0], input.shape[1]
         input = torch.sigmoid(input)
         smooth = 1.0
         iflat = input.view(n, -1)
         tflat = target.repeat(1, c, 1, 1).view(n, -1).float()
         intersection = (iflat * tflat).sum()
         dice = ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)).mean()
         return (1 - dice).float()

The label class defines the default loss function for the learner (which can be overridden on the learner). I gather this is because loss functions are generally tied to label classes (more than to inputs, which are the same across many tasks).
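So, as a sketch (the learner construction here is assumed), you can either rely on the label-class default or set it explicitly:

learn = unet_learner(data, models.resnet34)
learn.loss_func = bce_logits_floatify  # override whatever default the label class provided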

I presume your get_y_fn returns the file name of the mask image for a given input. These masks need to be multi-channel images for this method to work. The loss function expects a bs x nc x h x w target mask, as you specified. If that is not what is being produced, you would need to look at overriding SegmentationLabelList.open. This is what opens the mask and creates an ImageSegment. It uses open_mask, so if that returns nc x h x w images for your masks, this method should work.
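If, hypothetically, your masks were stored as RGB images with one class per channel, a minimal sketch of such an override could look like this (the class name and the convert_mode/div settings are assumptions about your files):

class RGBMaskLabelList(SegmentationLabelList):
    def open(self, fn):
        # convert_mode='RGB' keeps three channels; div=True maps 0/255 pixels to 0/1
        return open_mask(fn, div=True, convert_mode='RGB')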

Your modified dice seems to expect targets of size bs x 1 x w x h, which you then repeat, so you’d need to resolve that first to do multi-label (at least in the way my code does it). The code I have for dice (modified here from a more complex use, so it may need debugging) is:

import numpy as np
import torch

def dice(input, target, threshold=0.5):
    # input and target should both be BxCxHxW
    probs = input.sigmoid()
    probs = probs.view(*probs.shape[:2], -1)  # BxCxHW
    truth = target.view(*target.shape[:2], -1)  # BxCxHW
    assert probs.shape == truth.shape

    p = (probs > threshold).float()
    t = (truth > 0.5).float()
    dice_pos = 2 * (p * t).sum(-1) / ((p + t).sum(-1))
    # For empty targets score 1 if prediction empty else 0
    dice_neg = (p.sum(-1) == 0).float()
    pos_index = t.sum(-1) > 0
    dice_pos = torch.masked_select(dice_pos, pos_index)
    dice_neg = torch.masked_select(dice_neg, ~pos_index)
    dice = torch.cat([dice_pos, dice_neg]).mean()
    return np.nan_to_num(dice.item(), nan=0)

The nan_to_num check on return might not be needed; I was returning positive and negative dice separately, so they could be empty and give nan, but this shouldn’t apply to the combined dice.
Also note that, following fairly standard practice, this handles empty (negative) masks specially rather than just applying the standard formula, as noted in the comment.

  1. I see the challenge in building the labels to match the output format: ImageSegment will just stack the labels. This is how I was doing it, but I was still getting the same single-channel label. What change do I need to make here?

     class MultiClassSegList(SegmentationLabelList):
         def open(self, id_rles):
             image_id, rles = id_rles[0], id_rles[1:]
             shape = open_image(self.path/image_id).shape[-2:]
             final_mask = torch.zeros((1, *shape))
             for k, rle in enumerate(rles):
                 if isinstance(rle, str):
                     mask = open_mask_rle(rle, shape).px.permute(0, 2, 1)
                     final_mask += (k + 1) * mask
             return ImageSegment(final_mask)

     def load_data(path, df):
         train_list = (SegmentationItemList
                       .from_df(df, path=path/"train_images")
                       .split_by_rand_pct(valid_pct=0.1, seed=42)
                       .label_from_df(cols=list(range(5)), label_cls=MultiClassSegList,
                                      classes=[0, 1, 2, 3, 4])
                       .add_test(testfolder.ls(), label=None, tfm_y=False)
                       .transform(get_transforms(flip_vert=True, p_affine=0.8, max_rotate=360),
                                  size=128, resize_method=ResizeMethod.SQUISH, tfm_y=True))
         return train_list

df is as below (dataframe screenshot not shown; it has an image_id column followed by the RLE columns).

What purpose is cols serving here?

  2. Secondly, below is the version of dice I was using, but I see the following issues:
     a) Will it work in the case of a multi-label mask for a given image?
     b) Is my understanding correct that input==i & targs==i may never be true for i > 1?

     def dice_x(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8) -> Tensor:
         n, c = targs.shape[0], input.shape[1]
         input = input.argmax(dim=1).view(n, -1)
         targs = targs.view(n, -1)
         intersect, union = [], []
         for i in range(1, c):
             intersect.append(((input == i) & (targs == i)).sum(-1).float())
             union.append(((input == i).sum(-1) + (targs == i).sum(-1)).float())
         intersect = torch.stack(intersect)
         union = torch.stack(union)
         if not iou: return ((2.0 * intersect + eps) / (union + eps)).mean()
         else: return ((intersect + eps) / (union - intersect + eps)).mean()

Your open is creating a single-channel output through final_mask += (k + 1) * mask. You’d want to replace the for loop with something like:

masks = [open_mask_rle(rle, shape).px.permute(0, 2, 1)
         if isinstance(rle, str)
         else torch.zeros(1, *shape, dtype=torch.uint8)  # empty channel for missing classes
         for rle in rles]
final_mask = torch.cat(masks)  # nc x h x w, one 0/1 channel per class

That dice is for single-label prediction: the argmax on the channel dimension selects the single highest channel for each pixel and uses that as the prediction, so it produces Bx1xHxW predictions, not BxCxHxW. For multi-label you need something like the sigmoid and threshold used in the version I gave, e.g. pred = (torch.sigmoid(input) > threshold).float(). The sigmoid converts the input to the range (0, 1); then any channel above the threshold is predicted as 1 and the others as 0. This can then be compared to the multi-channel targets.
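A small sketch contrasting the two prediction styles (the tensor shapes and the 0.5 threshold are made up):

import torch

logits = torch.randn(2, 4, 8, 8)                    # BxCxHxW raw model output
# single-label: pick the highest-scoring channel per pixel
pred_single = logits.argmax(dim=1, keepdim=True)    # Bx1xHxW, values 0..C-1
# multi-label: independent 0/1 decision per channel
pred_multi = (torch.sigmoid(logits) > 0.5).float()  # BxCxHxW, values 0/1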


Thanks… yes, I just noticed that while discussing with you :slight_smile:
You must be seeing the df now…

  1. How does cols work here? With range(5) in label_from_df, is it taking all the column names from the corresponding indices?
  2. When should I pass div=True and when False?

The cols parameter selects the columns of the dataframe to pass to the label function. So that is passing the list [0, 1, 2, 3, 4], which selects all five columns in the dataframe. The id_rles in open will then get a list of values for all the columns of a single item (a row in the df).
The div option divides values by 255. It is used if your masks come from images with values of 0 and 255 rather than 0 and 1 (or, more commonly for images, to convert 8-bit values to float). As you’re using RLE-encoded values, open_mask_rle should already be producing 0s and 1s.
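For illustration, a hedged sketch of the div behaviour (assuming a PNG mask whose pixels are 0 or 255; the file name is made up):

from fastai.vision import open_mask

seg = open_mask('mask.png', div=True)  # values divided by 255, so pixels become 0/1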

Thanks Tom… this is the first time I’m using this version of fastai for segmentation, hence some basic questions; earlier I used an older version during the Airbus competition.
Currently I’m working on the steel defect competition, having joined late after finishing APTOS at rank 90 official and 50 unofficial :slight_smile:
Are you also working on the steel competition? If yes, would you like to team up? To date I’ve worked solo on all past competitions.

For you and anyone else working on the steel comp, I would note it is not actually multi-label. If you look at the masks, there is no overlap between classes, so the basic fastai single-label/single-channel setup should work fine. Though you can of course also apply the multi-label approach to it.

What do you mean by overlap?
There are images in it which have more than one mask defined.
Predictions are to be made in this format:
Image1 mask1, Image1 mask2

Yes, but no single pixel belongs to multiple classes. So you can just use a single channel for the model output/target, with pixel values of [0, 1, 2, 3, 4] (0 being background, i.e. no class). Then split the resulting single-channel output into multiple masks with something like pred_masks = [(pred == i) for i in range(1, 5)], where pred is model output of shape Bx1xHxW; the result is 4 separate single-channel masks with values [0, 1].
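For instance (stand-in tensor; the four defect classes of the steel comp assumed):

import torch

pred = torch.randint(0, 5, (2, 1, 256, 256))             # Bx1xHxW argmax'd output
pred_masks = [(pred == i).float() for i in range(1, 5)]  # four Bx1xHxW binary masks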

Okay…

  1. So in the output label I assign 1, 2, 3, 4 to the 1/0 pixels of the masks…
  2. For pred==i to be true, should we first do the argmax, as I did previously, before making this comparison?

Yes, sorry, the output of your network will actually be BxCxHxW; you then do an argmax to get the Bx1xHxW. This is the approach the fastai segmentation code assumes. Be aware, if using the interpretation functionality in fastai, that it applies the argmax for you in some places, so look out for this.
The code you posted was at least generally correct for that approach and is a workable way to do it. Just be aware that some kernels use the multi-label approach, and you can’t mix the two.

Yes…
In a single-label approach:

  1. A single HxW mask will have pixels from all the classes?

  2. If 1 is true, do I then have to multiply class codes into the binary pixels to differentiate them, since a single image can belong to multiple classes with non-overlapping pixels…
     So this was the earlier case:

     class MultiClassSegList(SegmentationLabelList):
         def open(self, id_rles):
             image_id, rles = id_rles[0], id_rles[1:]
             shape = open_image(self.path/image_id).shape[-2:]
             final_mask = torch.zeros((1, *shape))
             for k, rle in enumerate(rles):
                 if isinstance(rle, str):
                     mask = open_mask_rle(rle, shape).px.permute(0, 2, 1)
                     final_mask += (k + 1) * mask
             return ImageSegment(final_mask)

If the above is the right way for the single-label approach, then suppose there are two 2x2 masks, the first [[1, 0], [1, 0]] and the second [[0, 1], [0, 1]]; would the above generate the single mask [[1, 2], [1, 2]]?

Yes, that’s correct. You then need to specify an explicit background class if your classes aren’t exhaustive, as in this case where not every pixel has a specific class. So you have classes=['BG', '1', '2', '3', '4'] (noting these are labels; the choice of numbers here doesn’t affect anything).
Fastai will then generate a model with a 5-channel output and use argmax to create a single-channel prediction with values from [0, 1, 2, 3, 4], in line with targets produced like that.
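A quick check of the 2x2 example discussed above:

import torch

m1 = torch.tensor([[1, 0], [1, 0]])
m2 = torch.tensor([[0, 1], [0, 1]])
combined = 1 * m1 + 2 * m2  # tensor([[1, 2], [1, 2]]), matching the summed (k + 1) * mask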

Thanks Tom… discussing this with you helped me join the pieces of this jigsaw puzzle :slight_smile:

Another interesting thing in segmentation problems:

  1. Could random rotations at many angles move the object of interest out of the view frame, provided the resize method is squish?
  2. What is the difference between rotations and dihedral/horizontal flips? I presume both are a subset of rotations, say, if rotation is 360…

Thanks! Your code helped me build a working multi-label image segmentation on top of fastai. The code is in this Kaggle kernel if you want to have a look at it. The MultiLabelSegmentationLabelList.open() method is derived from this example.


Is this application of the Dice coefficient intended for multi-label (i.e. classes 0..n) in a single mask?

It handles multi-label masks as separate channels, each valued 0/1, not as a single channel: the sort of output that works with torch.nn.BCEWithLogitsLoss.
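For example (shapes made up):

import torch

loss_fn = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(2, 4, 8, 8)                    # BxCxHxW raw outputs
target = torch.randint(0, 2, (2, 4, 8, 8)).float()  # per-channel 0/1 masks
loss = loss_fn(logits, target)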
