Fastai v2: pretrained UNet with different number of channels

One great thing about FastAI v2 is that the call to unet_learner() allows you to have a different number of input and output channels, instead of assuming 3-channel RGB images. However, calling it with a pretrained architecture and a number of input channels that differs from the pretrained model does not seem to work: I get errors consistent with a shape mismatch.

So far, I have been successful at training randomly initialized models. But I’d rather start off with some semblance of the pretrained weights if at all possible. Is there a canonical way in FastAI v2 to accommodate pretrained weights even in the case of shape mismatch (e.g., having more or fewer input channels than the pretrained architecture)?

I don’t think there is a lot you can do if you want to keep all pretrained weights because as soon as you start changing the model architecture, the saved weight matrix will not match anymore.

If you have fewer channels in your training data, you can adapt the data to match the expectations of your model. For example, if you have single-channel grayscale images, you can stack them three times to create a (pseudo) color image. This could easily be done by creating a custom RandTransform class and adding it to batch_tfms or item_tfms.
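As a rough illustration of the stacking idea, here is a minimal sketch in plain PyTorch. The helper name gray_to_rgb is made up for this example and is not part of fastai; it assumes the batch has shape (N, 1, H, W).

```python
import torch

def gray_to_rgb(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a (N, 1, H, W) batch into (N, 3, H, W) by copying the channel."""
    if x.ndim != 4 or x.shape[1] != 1:
        raise ValueError(f"expected shape (N, 1, H, W), got {tuple(x.shape)}")
    # repeat() copies the single channel three times along the channel axis
    return x.repeat(1, 3, 1, 1)

batch = torch.rand(4, 1, 32, 32)   # four fake grayscale images
rgb = gray_to_rgb(batch)
print(rgb.shape)
```

The same call could presumably be made from inside a custom transform passed to batch_tfms, as suggested above; that wiring is left out here.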

However, if you have more channels in your training data, it will be a challenge. You could load the U-Net encoder with pretrained weights and then swap the model’s stem for a custom-made stem. This way you would at least keep most of the pretrained weights.

Something like this might work:

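from fastai.vision.all import *  # provides resnet18, nn and unet_learner

def custom_model(*args, **kwargs):
    model = resnet18(*args, **kwargs)
    # Swap the 3-channel stem for a new, randomly initialized 6-channel one
    model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
    return model

# dls is your own DataLoaders object
unet_learner(dls, custom_model)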

Thanks for your reply. But it seems that the internals are expecting 3-channel inputs for the base models? E.g., when I run the following (the goal is a U-Net with 1 input channel images and 2 output channel images):

def custom_resnet18(*args, **kwargs):
    model = resnet18(*args, **kwargs)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    return model

learn = unet_learner(dls = dls, 
                     arch = custom_resnet18, 
                     pretrained = True)

I get the following error:

/opt/conda/lib/python3.7/site-packages/fastai/vision/ in unet_learner(dls, arch, normalize, n_out, pretrained, config, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
    219     img_size = dls.one_batch()[0].shape[-2:]
    220     assert img_size, "image size could not be inferred from data"
--> 221     model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)
    223     splitter=ifnone(splitter, meta['split'])

/opt/conda/lib/python3.7/site-packages/fastai/vision/ in create_unet_model(arch, n_out, img_size, pretrained, cut, n_in, **kwargs)
    194     "Create custom unet architecture"
    195     meta = model_meta.get(arch, _default_meta)
--> 196     body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))
    197     model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)
    198     return model

/opt/conda/lib/python3.7/site-packages/fastai/vision/ in create_body(arch, n_in, pretrained, cut)
     64     "Cut off the body of a typically pretrained `arch` as determined by `cut`"
     65     model = arch(pretrained=pretrained)
---> 66     _update_first_layer(model, n_in, pretrained)
     67     #cut = ifnone(cut, cnn_config(arch)['cut'])
     68     if cut is None:

/opt/conda/lib/python3.7/site-packages/fastai/vision/ in _update_first_layer(model, n_in, pretrained)
     51     first_layer, parent, name = _get_first_layer(model)
     52     assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'
---> 53     assert getattr(first_layer, 'in_channels') == 3, f'Unexpected number of input channels, found {getattr(first_layer, "in_channels")} while expecting 3'
     54     params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}
     55     params['bias'] = getattr(first_layer, 'bias') is not None

AssertionError: Unexpected number of input channels, found 1 while expecting 3

This appears to be an assertion in the _update_first_layer code, which expects the first convolutional layer to have a 3-channel input. So when I change the first convolutional layer in advance to expect a 1-channel input, that assertion fails.

This actually seems to imply that the fastai v2 U-Net is designed to adapt the first layer itself for an arbitrary number of input channels, so now I need to investigate whether my original error is actually unrelated to this functionality.

Aha. I think my problem was that I wasn’t passing normalize=False. When I do pass that, the unet is created without a problem. So, in my case (1 input channel, 2 output channels), this is working:

learn = unet_learner(dls = dls, 
                     arch = resnet18, 
                     pretrained = True,
                     normalize = False,
                     loss_func = fastai.losses.MSELossFlat())

Cool. I checked the source for the unet_learner. Apparently, with 1 input channel the pretrained first-layer weights are summed across the three RGB channels; with 2 input channels the weights of the first two channels are kept and scaled by 1.5; and with more than 3, the RGB weights are kept and the weights of the additional channels are initialized to 0. This is much better than just replacing the first layer of the model with a new, randomly initialized ConvLayer.

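def _load_pretrained_weights(new_layer, previous_layer):
    "Load pretrained weights based on number of input channels"
    n_in = getattr(new_layer, 'in_channels')
    if n_in==1:
        # we take the sum
        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
    elif n_in==2:
        # we take first 2 channels + 50%
        new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
    else:
        # keep 3 channels weights and set others to null
        new_layer.weight.data[:,:3] = previous_layer.weight.data
        new_layer.weight.data[:,3:].zero_()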
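To see why these choices are sensible, here is a small sanity check in plain PyTorch (a sketch, not fastai code). It uses a randomly initialized Conv2d as a stand-in for the pretrained stem, which is an assumption made so the example runs without downloading weights. With summed weights, a 1-channel layer gives the same output as the RGB layer fed a grayscale image copied to three channels; with zeroed extra channels, a 6-channel layer initially behaves like the RGB layer applied to the first three channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
kw = dict(kernel_size=7, stride=2, padding=3, bias=False)

# Stand-in for a pretrained 3-channel stem (random weights, not real ImageNet weights)
conv_rgb = nn.Conv2d(3, 64, **kw)

# Case n_in == 1: kernel is the sum over the three RGB input channels
conv_gray = nn.Conv2d(1, 64, **kw)
# Case n_in > 3: copy the RGB kernels, zero the kernels of the extra channels
conv_six = nn.Conv2d(6, 64, **kw)
with torch.no_grad():
    conv_gray.weight.copy_(conv_rgb.weight.sum(dim=1, keepdim=True))
    conv_six.weight[:, :3].copy_(conv_rgb.weight)
    conv_six.weight[:, 3:].zero_()

x1 = torch.rand(2, 1, 32, 32)   # fake grayscale batch
x6 = torch.rand(2, 6, 32, 32)   # fake 6-channel batch
with torch.no_grad():
    out_gray = conv_gray(x1)
    out_rgb = conv_rgb(x1.repeat(1, 3, 1, 1))   # same image copied to R, G and B
    out_six = conv_six(x6)
    out_rgb3 = conv_rgb(x6[:, :3])              # RGB stem on the first three channels

print(torch.allclose(out_gray, out_rgb, atol=1e-5))
print(torch.allclose(out_six, out_rgb3, atol=1e-5))
```

One consequence of the zero initialization is that the extra channels are ignored at the start of training and only begin to contribute once their weights have been updated.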

Nice - that’s useful, thank you for surfacing that info. I’m curious whether a better initialization for N > 3 channel weights would be to (e.g.) average the 3 color channel weights instead of leaving the new channels with weight = 0. But I don’t currently have an N > 3 channel problem.
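For anyone who wants to try the averaging idea, a minimal sketch in plain PyTorch might look like the following. This is not what fastai does, and I have not checked whether it trains better; the helper name expand_stem_with_mean is made up for the example, and the stem is a randomly initialized stand-in rather than real pretrained weights.

```python
import torch
import torch.nn as nn

def expand_stem_with_mean(conv_rgb: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Hypothetical helper: n_in-channel conv whose extra channels start at the mean RGB kernel."""
    if n_in <= 3:
        raise ValueError("this sketch only covers n_in > 3")
    new_conv = nn.Conv2d(n_in, conv_rgb.out_channels,
                         kernel_size=conv_rgb.kernel_size, stride=conv_rgb.stride,
                         padding=conv_rgb.padding, bias=conv_rgb.bias is not None)
    with torch.no_grad():
        w = conv_rgb.weight                        # shape (out, 3, kH, kW)
        new_conv.weight[:, :3].copy_(w)            # keep the pretrained RGB kernels
        # mean over the RGB axis, broadcast across all extra input channels
        new_conv.weight[:, 3:].copy_(w.mean(dim=1, keepdim=True))
        if conv_rgb.bias is not None:
            new_conv.bias.copy_(conv_rgb.bias)
    return new_conv

# Stand-in for a pretrained stem (random weights here)
stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
stem6 = expand_stem_with_mean(stem, n_in=6)
print(stem6.weight.shape)
```

Unlike the zero initialization, this changes the scale of the stem's output from the first step, since the extra channels contribute immediately, so some rescaling of the weights might be worth experimenting with.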