Multi-task learning in fastai v2 (joint regression and categorization from images)

I am working on a multi-task problem where I feed in an image (in this case multi-spectral, but that's not super important) and receive both regression and categorization outputs. However, I'm hitting roadblocks and am curious whether anyone has had success. I think this is very similar to the multi-task problem that @yang-zhang solved for fastai v1, but I haven't seen a fastai v2 solution.

To do this, I created a DataBlock that expects an image-like input and then, thanks to n_inp=1, yields two targets: a RegressionBlock and a CategoryBlock:

our_datablocks = (MSTensorImage(x = 12), RegressionBlock, CategoryBlock)

db = DataBlock(blocks = our_datablocks,
               get_items = find_image_zip_files,
               get_x = get_x_fn,
               get_y = (get_y_fn_age, get_y_fn_disease),
               n_inp=1) # Declare that only the first block is the input
dls = db.dataloaders(source = path/'', bs = 2)
Collecting items ...
Found 20 items
2 datasets of sizes 16,4
Setting up Pipeline: get_x_fn
Setting up Pipeline: get_y_fn_age -> RegressionSetup -- {'c': None}
Setting up Pipeline: get_y_fn_disease -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}

So far, so good.
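With n_inp=1, the first element of each sample is the model input and the remaining two are targets. fastai's Learner splits each batch that way (inputs b[:n_inp], targets b[n_inp:]) and then calls the loss as loss_func(pred, *yb), so any multi-task loss has to accept one argument per target. The pure-Python mock-up below only illustrates that calling convention; split_batch and fake_loss are hypothetical names, not fastai functions.

```python
# Hypothetical mock-up of how a batch is split using n_inp; it mirrors the
# b[:n_inp], b[n_inp:] split but is not fastai's actual code.
def split_batch(batch, n_inp=1):
    "Split a batch tuple into model inputs (xb) and targets (yb)."
    return batch[:n_inp], batch[n_inp:]

def fake_loss(pred, *targets):
    "Stand-in loss: fastai calls loss_func(pred, *yb), so each target is its own argument."
    return len(targets)

batch = ("image", 42.0, "diseased")   # (x, age, disease) placeholders
xb, yb = split_batch(batch, n_inp=1)
print(xb)                       # ('image',)
print(yb)                       # (42.0, 'diseased')
print(fake_loss("pred", *yb))   # 2
```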

But when creating the model, the problem becomes apparent:

learn = cnn_learner(dls = dls, 
                     arch = resnet18, 
                     n_in=12,
                     loss_func = (MSELossFlat(), CrossEntropyLossFlat()),
                     normalize = False,
                     pretrained = True)

Clearly, unless fastai v2 knows what I mean by that list of loss functions and does a lot of work behind the scenes in cnn_learner, that’s not going to work. And sure enough, it does not:

TypeError                                 Traceback (most recent call last)
<ipython-input-91-e1fdee3ccda7> in <module>
----> 1 learn = cnn_learner(dls = dls, 
      2                      arch = resnet18,
      3                      n_in=12,
      4                      loss_func = (MSELossFlat(), CrossEntropyLossFlat()),
      5                      normalize = False,

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/fastai/vision/ in cnn_learner(dls, arch, normalize, n_out, pretrained, config, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
    178     if n_out is None: n_out = get_c(dls)
    179     assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
--> 180     model = create_cnn_model(arch, n_out, pretrained=pretrained, **kwargs)
    182     splitter=ifnone(splitter, meta['split'])

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/fastai/vision/ in create_cnn_model(arch, n_out, pretrained, cut, n_in, init, custom_head, concat_pool, **kwargs)
    144     if custom_head is None:
    145         nf = num_features_model(nn.Sequential(*body.children()))
--> 146         head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)
    147     else: head = custom_head
    148     model = nn.Sequential(body, head)

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/fastai/vision/ in create_head(nf, n_out, lin_ftrs, ps, concat_pool, first_bn, bn_final, lin_first, y_range)
     87     if lin_first: layers.append(nn.Dropout(ps.pop(0)))
     88     for ni,no,bn,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):
---> 89         layers += LinBnDrop(ni, no, bn=bn, p=p, act=actn, lin_first=lin_first)
     90     if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))
     91     if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/fastai/ in __init__(self, n_in, n_out, bn, p, act, lin_first)
    169         layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []
    170         if p != 0: layers.append(nn.Dropout(p))
--> 171         lin = [nn.Linear(n_in, n_out, bias=not bn)]
    172         if act is not None: lin.append(act)
    173         layers = lin+layers if lin_first else layers+lin

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torch/nn/modules/ in __init__(self, in_features, out_features, bias)
     76         self.in_features = in_features
     77         self.out_features = out_features
---> 78         self.weight = Parameter(torch.Tensor(out_features, in_features))
     79         if bias:
     80             self.bias = Parameter(torch.Tensor(out_features))

TypeError: new() received an invalid combination of arguments - got (L, int), but expected one of:
 * (*, torch.device device)
      didn't match because some of the arguments have invalid types: (L, int)
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
 * (object data, *, torch.device device)

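Looking at the traceback, the failure actually happens before the loss function is ever used: n_out was inferred from the DataLoaders as an L (note the "got (L, int)") instead of an integer, so the final nn.Linear of the default head could not be built. So I think the next step is to create a custom head and a custom loss function. However, I'm curious whether this is already a solved problem and someone has succeeded here.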
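A minimal PyTorch sketch of such a custom head, assuming the model should return a (regression, logits) tuple. MultiTaskHead, MultiTaskModel and the small convolutional body are hypothetical stand-ins; in practice the body would be a pretrained resnet18 body adapted to 12 input channels (for example via fastai's create_body, checked against your version). The resulting model could then be passed to a plain Learner(dls, model, loss_func=...) together with a loss that accepts one argument per target.

```python
import torch
import torch.nn as nn

class MultiTaskHead(nn.Module):
    "Shared pooled features feeding a regression head and a classification head."
    def __init__(self, nf, n_classes):
        super().__init__()
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.reg = nn.Linear(nf, 1)           # e.g. age
        self.cls = nn.Linear(nf, n_classes)   # e.g. disease logits

    def forward(self, x):
        feats = self.pool(x)
        return self.reg(feats).squeeze(-1), self.cls(feats)

class MultiTaskModel(nn.Module):
    "Body (feature extractor) followed by the two-output head."
    def __init__(self, body, nf, n_classes):
        super().__init__()
        self.body = body
        self.head = MultiTaskHead(nf, n_classes)

    def forward(self, x):
        return self.head(self.body(x))

# Hypothetical stand-in for a CNN body that accepts 12-channel input
body = nn.Sequential(nn.Conv2d(12, 32, 3, padding=1), nn.ReLU(),
                     nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU())
model = MultiTaskModel(body, nf=64, n_classes=3)
reg_out, cls_out = model(torch.randn(2, 12, 32, 32))
print(reg_out.shape, cls_out.shape)  # torch.Size([2]) torch.Size([2, 3])
```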

@jamesp did you find any solution for the above problem? I am stuck on the exact same problem.

Unfortunately, I did not. My attention got pulled away by other projects.

@vaibhavgupta did you find any solution for this? I'm also stuck on the exact same problem.

PyTorch, and thus fastai, needs a single loss value to call backward() on. You cannot pass two loss functions in a list or tuple to a fastai Learner; you need to create a loss function that wraps, calls, and optionally weights multiple loss functions.
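The same idea in functional form, as a sketch assuming the model returns a (regression, logits) tuple and the targets arrive as separate arguments (matching fastai's loss_func(pred, *yb) call). The name multitask_loss and the weights w_reg and w_cls are hypothetical.

```python
import torch
import torch.nn.functional as F

def multitask_loss(preds, age_targ, disease_targ, w_reg=1.0, w_cls=1.0):
    "Combine a regression loss and a classification loss into one scalar."
    reg_pred, cls_pred = preds                # model returns (regression, logits)
    reg_loss = F.mse_loss(reg_pred.view(-1), age_targ.float().view(-1))
    cls_loss = F.cross_entropy(cls_pred, disease_targ.long())
    return w_reg * reg_loss + w_cls * cls_loss  # 0-dim tensor, ready for .backward()

reg_pred = torch.randn(4, requires_grad=True)
cls_pred = torch.randn(4, 3, requires_grad=True)
loss = multitask_loss((reg_pred, cls_pred), torch.rand(4), torch.tensor([0, 2, 1, 0]))
loss.backward()   # one backward pass updates gradients for both heads
print(loss.dim())  # 0
```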

This untested mockup shows the logic (I'd use one of the full-featured solutions linked below).

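import torch.nn as nn

class MLoss(nn.Module):
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = losses    # initialized loss functions, one per output/target pair
        self.weights = weights  # one float weight per loss function

    def forward(self, preds, *targets):
        # fastai calls loss_func(pred, *yb): preds is the model's tuple of outputs,
        # and targets holds one tensor per target block
        loss = 0.
        for loss_fn, w, pred, targ in zip(self.losses, self.weights, preds, targets):
            loss = loss + w * loss_fn(pred, targ)
        return loss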

where losses is a list of initialized loss functions, and weights is a list of floats. When training in mixed precision, it’s usually a good idea to keep the individual loss function outputs at the same order of magnitude.
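If you would rather keep cnn_learner's standard single-tensor head, one alternative (not one of the solutions named in this thread) is to pass n_out explicitly, which cnn_learner accepts according to the signature in the traceback above, as 1 plus the number of classes, and then split the output columns inside the loss. Passing n_out also skips the inference that failed in the original post. SplitOutputLoss and the weight values below are hypothetical; the weights are also a convenient place to keep the two terms at a similar order of magnitude.

```python
import torch
import torch.nn.functional as F

class SplitOutputLoss:
    "Loss for a head with n_out = 1 + n_classes: column 0 is regression, the rest are logits."
    def __init__(self, w_reg=1.0, w_cls=1.0):
        self.w_reg, self.w_cls = w_reg, w_cls

    def __call__(self, pred, age_targ, disease_targ):
        reg_loss = F.mse_loss(pred[:, 0], age_targ.float())
        cls_loss = F.cross_entropy(pred[:, 1:], disease_targ.long())
        return self.w_reg * reg_loss + self.w_cls * cls_loss

n_classes = 3
pred = torch.randn(4, 1 + n_classes)    # what a single linear head would return
loss_func = SplitOutputLoss(w_reg=0.1)  # down-weight MSE if it dwarfs the CE term
loss = loss_func(pred, torch.rand(4) * 80, torch.tensor([0, 2, 1, 0]))
print(loss.dim())  # 0
```

In fastai this would presumably be combined as cnn_learner(..., n_out=1 + n_classes, loss_func=SplitOutputLoss(...)); any metrics would need to slice the prediction the same way.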

There are full-featured solutions, such as fastxtend's MultiLoss module, which provides a loss wrapper for multiple losses with a single label, a wrapper for multiple losses with multiple labels, and a callback to record the individual losses as metrics; and blurr's MultiTargetLoss, which handles multiple losses with multiple labels.