Multiple outputs for one input in a segmentation problem

Hello!
I don’t know if I’m posting this in the right category, but here is my problem:

My inputs are 4-channel images: aerial imagery in the first three channels, and a raster of a generated, displaced polygon as the 4th channel.
I want to predict both an affine transformation matrix and the segmentation mask.
I plugged a Spatial Transformer Network into a regular U-Net, so the network outputs both the mask and the affine transformation matrix.
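To make the two-output architecture concrete, here is a minimal PyTorch sketch of a head that could sit on top of the U-Net decoder features. The class name `MaskAndAffineHead`, the `warp` helper, the global-pooling localisation branch, and the assumption of 3 mask channels are all hypothetical illustrations, not the poster's actual model. It uses the usual STN trick of initialising the affine branch to the identity transform, and `F.affine_grid`/`F.grid_sample` to apply a predicted matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskAndAffineHead(nn.Module):
    """Hypothetical two-output head placed on top of U-Net decoder features.

    Returns per-pixel mask logits and a (2, 3) affine matrix per sample.
    """

    def __init__(self, in_ch: int, n_mask_ch: int = 3):
        super().__init__()
        # 1x1 convolution: decoder features -> mask logits
        self.mask_head = nn.Conv2d(in_ch, n_mask_ch, kernel_size=1)
        # Localisation branch: global average pool, then 6 affine parameters
        self.loc = nn.Linear(in_ch, 6)
        # Common STN initialisation: start from the identity transform
        nn.init.zeros_(self.loc.weight)
        with torch.no_grad():
            self.loc.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, feats: torch.Tensor):
        mask_logits = self.mask_head(feats)
        pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)
        theta = self.loc(pooled).view(-1, 2, 3)
        return mask_logits, theta


def warp(img: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Resample img (N, C, H, W) with theta (N, 2, 3), as an STN does."""
    grid = F.affine_grid(theta, list(img.size()), align_corners=False)
    return F.grid_sample(img, grid, align_corners=False)
```

For example, `warp(polygon_channel, theta)` would resample the polygon raster with the predicted matrix; with the identity initialisation the output starts out equal to the input, which tends to make early training more stable.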

I have already looked into a few ways to use multiple losses in fastai so that both outputs are optimized.

But I don’t know how to have two labels/targets: one for the segmentation (of shape (256,256,3)) and one for the affine transformation (of shape (2,3)). I’m guessing I should look into the data block API, but I don’t really know which direction to go…

You can’t really have two labels, but you can have a custom label, which can be whatever you want, including a tuple. Custom labels are created by subclassing the ItemList used for labels (SegmentationLabelList is one such subclass), and you pass that class as the label_cls parameter to the various label_from_* functions (otherwise the default specified by the ItemList you use is taken). In that subclass you can set self.loss_func, which is used to calculate losses: it is passed the outputs of your network and whatever labels your class returns. For segmentation, the default label class is SegmentationLabelList.
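As an illustration of such a loss, here is a minimal PyTorch sketch of a combined loss for a model that returns `(mask_logits, theta)`. The class name `MaskAffineLoss`, the choice of BCE-with-logits for the multi-channel mask, MSE for the affine matrix, and the `affine_weight` balance factor are all assumptions. It also assumes the target arrives as two separate tensors, a mask of shape (N, 3, H, W) and an affine matrix of shape (N, 2, 3); how the target actually reaches the loss depends on your collate function and training loop, so treat the `forward` signature as hypothetical. Note that a (256,256,3) channels-last mask would need to be permuted to (3,256,256) first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskAffineLoss(nn.Module):
    """Hypothetical combined loss for a model returning (mask_logits, theta).

    Assumes the target reaches the loss as two tensors: a multi-channel
    mask (N, C, H, W) with 0/1 values and an affine matrix (N, 2, 3).
    """

    def __init__(self, affine_weight: float = 1.0):
        super().__init__()
        self.affine_weight = affine_weight  # balances the two terms

    def forward(self, output, mask_target, theta_target):
        mask_logits, theta = output
        # One independent binary channel per class -> BCE on logits
        seg_loss = F.binary_cross_entropy_with_logits(mask_logits, mask_target.float())
        # Regress the six affine parameters directly
        affine_loss = F.mse_loss(theta, theta_target)
        return seg_loss + self.affine_weight * affine_loss
```

The relative weight of the two terms usually needs tuning, since the segmentation and affine losses can live on quite different scales.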
You can return a tuple as a label, but I haven’t tried this, and you may have to provide a custom collate function to make it work. Alternatively, since you create the affine transform from the raster image, you could return the 4-channel image itself as the label and compute the affine in your loss function (everything in ItemList.open is re-run every time an item is loaded anyway, so performance should be about the same).
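If you do go the tuple-label route, a custom collate function could look like the sketch below. The name `mask_affine_collate` and the item layout `(image, (mask, theta))` are hypothetical; it simply stacks images, masks and affine matrices separately so the batch target stays a `(masks, thetas)` pair. It can be passed as `collate_fn` to a plain PyTorch `DataLoader`; I believe fastai v1's `databunch(...)` also accepts a `collate_fn` argument, but check the docs for your version.

```python
from typing import List, Tuple

import torch

# One dataset item: (4-channel image, (mask, affine matrix))
Sample = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def mask_affine_collate(batch: List[Sample]):
    """Hypothetical collate_fn: stack images, masks and affine matrices
    separately so the batch target stays a (masks, thetas) pair."""
    images = torch.stack([x for x, _ in batch])
    masks = torch.stack([y[0] for _, y in batch])
    thetas = torch.stack([y[1] for _, y in batch])
    return images, (masks, thetas)
```

PyTorch's default collate should already cope with nested tuples of tensors; writing it out explicitly mainly gives you one place to add custom handling (e.g. converting or validating the affine targets).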
Note that SegmentationLabelList expects masks to be a single channel of integer class indices (0 to n_classes-1), so you’ll likely need to override that. See this thread for a custom SegmentationLabelList subclass I wrote to handle multi-channel segmentation masks with a custom loss function. It should give you most of the pieces you need on the fastai side; you just need to add your own open handling and loss function (and a custom show if you want show_batch to work).
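Assuming your (256,256,3) mask is one-hot, with one channel per class (the question doesn't say how it is encoded), there are two options, sketched below with hypothetical helper names: collapse it to a single-channel index mask of the kind SegmentationLabelList expects, or keep all channels and just move them first into the (C, H, W) layout PyTorch uses, then pair that with a multi-channel loss.

```python
import torch


def onehot_hwc_to_index(mask_hwc: torch.Tensor) -> torch.Tensor:
    """(H, W, C) one-hot mask -> (1, H, W) integer class-index mask.

    Assumes exactly one active channel per pixel; a pixel with no active
    channel would silently map to class 0.
    """
    # argmax over the channel axis gives the class index at each pixel
    return mask_hwc.argmax(dim=-1).unsqueeze(0).long()


def hwc_to_chw(mask_hwc: torch.Tensor) -> torch.Tensor:
    """Keep all channels but move them first, the (C, H, W) layout PyTorch expects."""
    return mask_hwc.permute(2, 0, 1).contiguous()
```

The index form works with a standard cross-entropy loss; the multi-channel form is what you would keep if classes can overlap or if your loss treats each channel as a separate binary mask.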
Also note that, by default, fastai applies transforms to SegmentationLabelList items. This could interfere with your affine targets, so you might want to disable transforms, or at least limit them to flips.
