Multi-input model training

I am currently building a model that takes 2 inputs and produces 2 outputs.
My model takes as input a tuple of tuples:
input: ((img1, img2, img3), (val1, val2, val3))
output: ((img4, img5, img6), (val4, val5, val6))
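To make that nested signature concrete, here is a toy PyTorch module that accepts ((img1, img2, img3), (val1, val2, val3)) and returns ((logits4, logits5, logits6), (val4, val5, val6)). It is a hypothetical stand-in, not the Informer/TimeSformer model: the class name ToyNestedModel, the tensor shapes and the layers are all assumptions chosen only to show how the tuples are unpacked and rebuilt.

```python
import torch
from torch import nn


class ToyNestedModel(nn.Module):
    """Hypothetical stand-in for the real encoder-decoder.

    Input:  ((img1, img2, img3), (val1, val2, val3))
            each img: [B, C, H, W], each val: [B, 1]
    Output: ((logits4, logits5, logits6), (val4, val5, val6))
            each logits: [B, n_classes, H, W], each val: [B, 1]
    """

    def __init__(self, in_ch=3, n_classes=4, n_out=3):
        super().__init__()
        self.n_out = n_out
        # 1x1 conv maps the three stacked input images to per-pixel logits for n_out images
        self.img_head = nn.Conv2d(in_ch * 3, n_classes * n_out, kernel_size=1)
        # linear layer maps the three input values to n_out predicted values
        self.val_head = nn.Linear(3, n_out)

    def forward(self, x):
        imgs, vals = x                                        # unpack the outer tuple
        logits = self.img_head(torch.cat(imgs, dim=1))        # [B, n_classes*n_out, H, W]
        img_preds = tuple(logits.chunk(self.n_out, dim=1))    # n_out x [B, n_classes, H, W]
        val_out = self.val_head(torch.cat(vals, dim=1))       # [B, n_out]
        val_preds = tuple(val_out.chunk(self.n_out, dim=1))   # n_out x [B, 1]
        return img_preds, val_preds
```

Because it takes a single nested argument, this model corresponds to n_inp=1 in fastai: the Learner calls the model as model(*xb), and xb then holds just the one nested tuple.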

Just for context, the model is an Informer-type encoder-decoder, with a vision transformer (TimeSformer) as the encoder and a regular transformer as the decoder.

For the last year I have been working with models of this type in fastai v2 without issues, but now that my targets are multimodal, I am resorting to very complicated tricks to make the Learner work. I am probably missing something.

My model is trained with two losses: cross-entropy (CE) for the images and MSE for the values. I initially tried simply writing a custom loss function, but that did not work.

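I ended up writing an ugly callback:

class DualLoss(Callback):
    "Adds a second loss on the image predictions to the Learner's loss."
    def __init__(self, other_loss=CrossEntropyLossFlat()):
        store_attr()
    def before_batch(self):
        # split the nested target: images stay in the callback, values go to the Learner
        self.yb_images, self.learn.yb = self.yb[0][0], (self.yb[0][1], )

    def after_pred(self):
        # split the nested prediction the same way
        self.pred_images, self.learn.pred = self.pred

    def after_loss(self):
        # fastai versions that back-propagate learn.loss_grad (see before_backward in the
        # traceback below) only log learn.loss, so the extra term must go into loss_grad
        extra = self.other_loss(self.pred_images, self.yb_images)
        self.learn.loss_grad = self.loss_grad + extra
        self.learn.loss = self.learn.loss_grad.clone()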

As you can see, I hack the Learner: I store the predicted images away and later add their loss to the Learner's own loss (the MSE over the values).
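Stripped of the callback machinery, what DualLoss achieves on each batch can be written as one plain PyTorch training step. This is a minimal sketch that assumes, for simplicity, a single tensor per modality (per-pixel class logits for the images, a [B, N] tensor for the values); TwoHeadModel and train_step are hypothetical names, not fastai API.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TwoHeadModel(nn.Module):
    """Toy stand-in: returns (per-pixel class logits, regression values)."""

    def __init__(self, n_classes=5, n_values=2):
        super().__init__()
        self.seg = nn.Conv2d(3, n_classes, kernel_size=1)
        self.reg = nn.Linear(3, n_values)

    def forward(self, x):
        return self.seg(x), self.reg(x.mean(dim=(2, 3)))


def train_step(model, opt, xb, yb):
    """One optimisation step with MSE on the values plus CE on the images."""
    y_images, y_values = yb                               # split targets (before_batch)
    pred_images, pred_values = model(xb)                  # split predictions (after_pred)
    loss = F.mse_loss(pred_values, y_values)              # the Learner's own loss
    loss = loss + F.cross_entropy(pred_images, y_images)  # term added in after_loss
    opt.zero_grad()
    loss.backward()                                       # gradients flow through both terms
    opt.step()
    return loss.item()
```

The detail that matters is that the extra CE term is added to the same tensor that backward() is called on; adding it only to a logged copy of the loss would change the reported number but not the gradients.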

Here’s an example where I did a bit of this:

Relevant parts I modified/wrote: DataLoader, model, loss function, metric, and Callback.

For the loss (search for “Crop_Loss” in my document), I didn’t subclass Callback or any other class (so in your case, just class DualLoss:), but it does need a __call__() method. I patterned this after some of the loss functions in fastai. I also wrote a Callback for sending predictions to WandB.
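As an illustration of that pattern (not the Crop_Loss code itself), here is a plain class with no base class. fastai only needs the loss to be callable; its Learner also looks up optional activation and decodes methods on the loss (used by get_preds and predict), as fastai's own CrossEntropyLossFlat provides. The class name ImageValueLoss, the softmax/argmax choices and the two-target signature (matching how fastai calls loss_func(pred, *yb)) are assumptions for this thread's image/value case.

```python
import torch.nn.functional as F


class ImageValueLoss:
    """Plain callable loss (no base class).

    Hypothetical layout: pred = (image_logits [B, K, H, W], values [B, N]),
    targets passed separately as image_classes [B, H, W] and values [B, N].
    """

    def __call__(self, pred, y_images, y_values):
        p_images, p_values = pred
        return F.cross_entropy(p_images, y_images) + F.mse_loss(p_values, y_values)

    # Optional hooks that fastai's Learner uses if the loss defines them.
    def activation(self, pred):
        p_images, p_values = pred
        return p_images.softmax(dim=1), p_values      # class probabilities per pixel

    def decodes(self, pred):
        p_images, p_values = pred
        return p_images.argmax(dim=1), p_values       # predicted class per pixel
```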


I defined my own tuple type that supports this mixed multimodal input, so the inputs are not the issue; the problem is the targets. fastai does not like targets of multiple types.

I recently did a project with multiple inputs and outputs, and while it took some coding to set up, fastai works fine with multiple inputs, multiple targets, or both. I didn’t use a callback, so I’m not sure whether that would make things easier or harder.

I’ll snip the task-specific code and show how I split the data.

My DataBlock was set up like this, with three inputs and three targets (two of them in the RegressionBlock).

block = DataBlock(blocks=(ImageBlock, 
                          MaskBlock, 
                          MaskBlock,
                          MaskBlock,
                          RegressionBlock),
                  # getters, item_tfms, batch_tfms, etc here
                  n_inp=3)
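In fastai, n_inp decides how each flat batch tuple is split: the first n_inp items become the inputs xb, passed to the model as model(*xb), and the remaining items become the targets yb, passed to the loss as loss_func(pred, *yb). The pure-Python sketch below mimics that flow with hypothetical stand-ins (split_batch, toy_model, toy_loss) and strings in place of tensors; it is a simplification, not fastai's code.

```python
def split_batch(batch, n_inp):
    """Split a flat batch tuple into (inputs, targets), as n_inp does in fastai."""
    return batch[:n_inp], batch[n_inp:]


def toy_model(img, mask_a, mask_b):
    # Receives the n_inp inputs as separate positional arguments.
    return ("mask_pred", "reg_pred")


def toy_loss(pred, *targets):
    # Receives the model output plus every target as separate arguments.
    mask_pred, reg_pred = pred
    mask_target, reg_target = targets
    return f"{mask_pred}~{mask_target} + {reg_pred}~{reg_target}"


# One batch from a DataBlock with blocks (Image, Mask, Mask, Mask, Regression), n_inp=3
batch = ("img", "mask1", "mask2", "mask3", "reg")
xb, yb = split_batch(batch, n_inp=3)
pred = toy_model(*xb)
print(toy_loss(pred, *yb))  # mask_pred~mask3 + reg_pred~reg
```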

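The model’s forward splits the three inputs and then processes them.

def forward(self, *x):
    img, mask_a, mask_b = x  # the three inputs (n_inp=3), passed as model(*xb)
    # ... pass through the model (task-specific code snipped) ...
    # mask_out: mask prediction; reg_out: regression output of size [b, 2]
    return (mask_out, reg_out)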
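A runnable toy version of that forward, under stated assumptions: the image is [B, 3, H, W], the two input masks are integer [B, H, W] tensors (as a MaskBlock produces), and a tiny conv net stands in for the snipped task-specific code. ThreeInputModel and its layers are hypothetical.

```python
import torch
from torch import nn


class ThreeInputModel(nn.Module):
    """Toy model: (image, mask, mask) in, (mask logits, regression [B, 2]) out."""

    def __init__(self, n_classes=2):
        super().__init__()
        # image (3 channels) + two masks (1 channel each) stacked along channels
        self.body = nn.Conv2d(5, 8, kernel_size=3, padding=1)
        self.mask_head = nn.Conv2d(8, n_classes, kernel_size=1)
        self.reg_head = nn.Linear(8, 2)

    def forward(self, *x):
        img, mask_a, mask_b = x  # fastai calls model(*xb) with n_inp=3
        feats = torch.cat([img, mask_a.unsqueeze(1).float(), mask_b.unsqueeze(1).float()], dim=1)
        feats = torch.relu(self.body(feats))              # [B, 8, H, W]
        mask_out = self.mask_head(feats)                  # [B, n_classes, H, W]
        reg_out = self.reg_head(feats.mean(dim=(2, 3)))   # [B, 2]
        return mask_out, reg_out
```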

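The loss’s forward then splits the tuple returned by the model and the targets from the DataLoader (each holding a mask tensor and a regression tensor).

def forward(self, x, *y):
    x1, x2 = x  # model output: mask prediction, regression prediction
    y1, y2 = y  # targets: mask, regression values
    # compute the appropriate loss for each pair, then return a single combined loss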
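The "iterate through appropriate losses" step could look like the following minimal sketch, assuming one loss function per model output, listed in the same order as the targets. The class name MultiTargetLoss is hypothetical; the (pred, *targets) signature relies on fastai calling the loss as loss_func(pred, *yb).

```python
import torch.nn.functional as F
from torch import nn


class MultiTargetLoss(nn.Module):
    """Sum of per-output losses: loss_funcs[i](pred[i], targets[i])."""

    def __init__(self, loss_funcs):
        super().__init__()
        self.loss_funcs = list(loss_funcs)

    def forward(self, pred, *targets):
        assert len(pred) == len(targets) == len(self.loss_funcs)
        total = 0.0
        for loss_func, p, t in zip(self.loss_funcs, pred, targets):
            total = total + loss_func(p, t)   # one term per output/target pair
        return total                          # a single scalar tensor for backward()


# e.g. CE for the mask head, MSE for the regression head
loss_func = MultiTargetLoss([F.cross_entropy, F.mse_loss])
```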

Metrics similarly need to split the predictions and the targets.
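A minimal sketch of such metrics, assuming the same (mask logits, regression) output and (mask, regression) targets: fastai wraps a plain function metric so that it is called as func(pred, *yb), so each function below picks out only the head it measures. Both function names are hypothetical.

```python
def mask_accuracy(pred, *targets):
    """Pixel accuracy of the mask head only."""
    mask_pred, _ = pred
    mask_target, _ = targets
    return (mask_pred.argmax(dim=1) == mask_target).float().mean()


def reg_mae(pred, *targets):
    """Mean absolute error of the regression head only."""
    _, reg_pred = pred
    _, reg_target = targets
    return (reg_pred - reg_target).abs().mean()
```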

Hope that helps.

This multi-block DataBlock is pretty cool. My case is more complex, though, as the inputs are tuples of inputs.


As you can see, my input and output are:

  • input: ((Image, Image, Image, Image), Tensor)
  • output: ((Mask, Mask, Mask, Mask, Mask, Mask), Tensor)

For some reason, fastai does not like this. The mixed-precision callback appears to be the problem: when I train in 32-bit precision, it works!

~/Apps/fastai/fastai/callback/fp16.py in after_pred(self)
     20     def before_batch(self): self.autocast.__enter__()
     21     def after_pred(self):
---> 22         if listify(self.pred)[0].dtype==torch.float16: self.learn.pred = to_float(self.pred)
     23     def after_loss(self): self.autocast.__exit__(None, None, None)
     24     def before_backward(self): self.learn.loss_grad = self.scaler.scale(self.loss_grad)

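AttributeError: 'tuple' object has no attribute 'dtype'

The check in after_pred assumes that listify(self.pred)[0] is a tensor, but with this nested output it is the inner tuple of masks, which has no dtype.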
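One possible direction, sketched under assumptions (this is not fastai's code): make the float16 check look for the first tensor anywhere in the nested prediction, and cast recursively. fastai's own to_float may already handle nesting; the key part of this sketch is the recursive dtype check. first_tensor and nested_to_float are hypothetical helpers.

```python
import torch


def first_tensor(x):
    """Return the first tensor found in an arbitrarily nested list/tuple, or None."""
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, (list, tuple)):
        for item in x:
            t = first_tensor(item)
            if t is not None:
                return t
    return None


def nested_to_float(x):
    """Cast every float16 tensor in nested lists/tuples to float32, keeping the nesting."""
    if isinstance(x, torch.Tensor):
        return x.float() if x.dtype == torch.float16 else x
    if isinstance(x, (list, tuple)):
        return type(x)(nested_to_float(item) for item in x)
    return x


# Inside a hypothetical MixedPrecision subclass, after_pred could then read:
#     t = first_tensor(self.pred)
#     if t is not None and t.dtype == torch.float16:
#         self.learn.pred = nested_to_float(self.pred)
```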

I am using this as the loss:

class DualLoss:
    def __init__(self, loss1, loss2): store_attr()
    def __call__(self, x, y):
        x1, x2 = x
        y1, y2 = y
        return self.loss1(x1,y1) + self.loss2(x2,y2)
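With the nested structure ((Mask, Mask, Mask, Mask, Mask, Mask), Tensor), x1 and y1 inside DualLoss are tuples of six masks, so loss1 would presumably need to accept tuples, while a single-tensor loss such as CrossEntropyLossFlat expects one tensor. A minimal sketch using a hypothetical tuple_loss wrapper that applies a per-tensor loss pairwise and averages the results:

```python
import torch
import torch.nn.functional as F


def tuple_loss(loss_func):
    """Wrap a per-tensor loss so it also accepts equal-length tuples of tensors."""
    def wrapped(pred, targ):
        if isinstance(pred, (list, tuple)):
            assert len(pred) == len(targ), "prediction/target tuples differ in length"
            losses = [loss_func(p, t) for p, t in zip(pred, targ)]
            return torch.stack(losses).mean()   # average over the tuple members
        return loss_func(pred, targ)            # plain tensors pass straight through
    return wrapped


# e.g. CE applied to each of the six masks, then averaged
mask_loss = tuple_loss(F.cross_entropy)
```

It could then be plugged in as, for example, DualLoss(tuple_loss(CrossEntropyLossFlat()), MSELossFlat()). Averaging rather than summing over the six masks is a design choice that keeps the image term on the same scale as a single-mask loss.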

I will propose a fix to this.

Proposed fix: