DataBlock API and multi-task learning


I’m making a foray into multi-task learning and struggling to put all the pieces (DataBlock, transforms, model, and loss) together. Specifically, my input is an image, and the prediction is a (mask, image) pair, where mask is a segmentation mask and the predicted image is derived from the mask through a deterministic function. At a high level, I wonder how much of the existing fastai v2 abstractions support multi-task learning of this shape, and how much of it would be an uphill battle of bending abstractions not designed for it?

Other than that, here are the specific questions that have popped up for me so far:

  1. Does it make sense to start with (ImageBlock, MaskBlock) as blocks and expand the mask into a (mask, image) pair in a custom transform? Can y be a tuple, or is it likely to blow up somewhere in the training loop?
  2. Should the model return two tensors (a tuple of tensors), or should I concatenate the mask and image into one tensor and tease them apart in the loss function?

The closest thing to the shape of my problem I’ve found is the Siamese tutorial[0], but with a major difference: the Siamese model has a complex input (a pair of images) but a straightforward output.


It seems to me that the best approach might be to create your own data block. Then you can use the data block API with blocks=(ImageBlock, MultiTaskBlock). I would start from the ImageTuple and ImageTupleBlock demonstrated in the tutorial you linked, and incorporate PILMask instead of just PILImage.

You can have the model return as many tensors as you want as long as the loss function can deal with it…
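For instance, a composite loss is just a function that unpacks the tuple the model returns. A minimal sketch with PyTorch (the cross-entropy/L1 combination and the weight w are assumptions, and the exact signature depends on how fastai passes your targets to the loss):

```python
import torch
import torch.nn.functional as F

def composite_loss(preds, mask_targ, img_targ, w=1.0):
    # The model returns a tuple: segmentation logits and a predicted image.
    mask_pred, img_pred = preds
    seg = F.cross_entropy(mask_pred, mask_targ)  # per-pixel classification
    rec = F.l1_loss(img_pred, img_targ)          # image reconstruction
    return seg + w * rec
```

The relative weight w usually needs tuning so one task doesn't dominate the gradients.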

Thanks for the response!

I ended up taking a slightly different route. As a first cut, I had blocks=(ImageBlock, MaskBlock) and a transform that expands the PILMask into a (PILMask, PILImage) pair. This worked fine until I wanted to get show_batch to work. fastai uses multiple dispatch as its extension mechanism, but you can’t dispatch on the inner types of a tuple, so I created a wrapper type for the pair. After all transforms, I have (PILImage, MyWrapper(PILMask, PILImage)), which then gets converted to tensors. It’s not clear to me whether fastai prefers flat tuples or whether nested (possibly named, like MyWrapper) ones are fine.
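The dispatch limitation can be sketched with the stdlib’s functools.singledispatch, which behaves like fastai’s TypeDispatch in this respect: registration keys on the outer type only, so a plain tuple of (mask, image) falls through to the generic case, while a named wrapper subclass can be registered. All class names here are stand-ins, not fastai code:

```python
from functools import singledispatch

class Mask: ...    # stand-in for PILMask
class Image: ...   # stand-in for PILImage

class MaskImagePair(tuple):
    "Tuple subclass so dispatch can key on the pair itself."

@singledispatch
def show(o):
    return "generic"

@show.register
def _(o: MaskImagePair):
    # Without the wrapper this pair would be a plain tuple and fall
    # through to the generic case: dispatch never looks inside it.
    return "mask+image pair"

show((Mask(), Image()))                  # plain tuple -> "generic"
show(MaskImagePair((Mask(), Image())))   # wrapper -> "mask+image pair"
```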

I went with nested ones and hit many hiccups implementing a composite loss. For example, fastai would unwrap MyWrapper unexpectedly (e.g. in show_results), leading to tricky-to-chase-down bugs. It feels like composite types for inputs are better supported and tested than composite types for targets, and my experiment has been an uphill battle. This makes me wonder whether I’m bending fastai too far.