Fastai v2: pretrained UNet with a different number of channels

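Cool. I checked the source for the unet_learner. Apparently, the pretrained weights of the first layer are reused rather than thrown away. With 1 input channel, the weights of the three RGB channels are summed. With 2 channels, the first two are kept and scaled by 1.5. With more than 3 channels, the first three keep the pretrained weights and the additional channels are initialized with 0. This is much better than just replacing the first layer of the model with a new, randomly initialized ConvLayer, because the pretrained filters are preserved.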

def _load_pretrained_weights(new_layer, previous_layer):
    "Load pretrained weights based on number of input channels"
    n_in = getattr(new_layer, 'in_channels')
    if n_in==1:
        # we take the sum
        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
    elif n_in==2:
        # we take first 2 channels + 50%
        new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
    else:
        # keep 3 channels weights and set others to null
        new_layer.weight.data[:,:3] = previous_layer.weight.data
        new_layer.weight.data[:,3:].zero_()
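To see why reusing the weights beats a random init, here is a small NumPy sketch. It is my own illustration, not fastai code: `adapt_first_conv_weights` and `response_at` are hypothetical helpers, and the random array only stands in for real pretrained weights. It applies the same channel rules to a weight array of shape (out_channels, in_channels, kh, kw) and checks two properties at a single output position. First, summing the kernel over RGB reproduces the pretrained response for a grayscale image copied into all three channels. Second, zero-initialized extra channels leave the pretrained response unchanged.

```python
import numpy as np


def adapt_first_conv_weights(w_pretrained: np.ndarray, n_in: int) -> np.ndarray:
    """Hypothetical helper mirroring the channel rules of the fastai function above.

    w_pretrained: weights of an RGB-pretrained conv layer, shape
    (out_channels, 3, kh, kw), the layout PyTorch uses for Conv2d.
    Returns weights of shape (out_channels, n_in, kh, kw).
    """
    out_c, in_c, kh, kw = w_pretrained.shape
    if in_c != 3:
        raise ValueError("expected weights from a 3-channel (RGB) layer")
    if n_in == 1:
        # Sum over the RGB input channels.
        return w_pretrained.sum(axis=1, keepdims=True)
    if n_in == 2:
        # Keep the first two channels and scale them by 1.5 (= 3/2).
        return w_pretrained[:, :2] * 1.5
    # Copy the RGB weights; any extra input channels start at zero.
    w_new = np.zeros((out_c, n_in, kh, kw), dtype=w_pretrained.dtype)
    w_new[:, :3] = w_pretrained
    return w_new


def response_at(w: np.ndarray, patch: np.ndarray) -> np.ndarray:
    """Output of every filter at one spatial position: sum over (in, kh, kw)."""
    return np.einsum("oikl,ikl->o", w, patch)


rng = np.random.default_rng(0)
w_rgb = rng.normal(size=(8, 3, 3, 3))  # stand-in for pretrained RGB weights

# 1 channel: a grayscale patch copied into R, G and B gives the same
# response as the summed kernel applied to the single channel.
gray = rng.normal(size=(1, 3, 3))
w_gray = adapt_first_conv_weights(w_rgb, 1)
gray_matches = np.allclose(
    response_at(w_rgb, np.repeat(gray, 3, axis=0)),
    response_at(w_gray, gray),
)

# 4 channels: with the extra channel zero-initialized, the adapted layer
# initially ignores it and behaves exactly like the pretrained RGB layer.
rgba = rng.normal(size=(4, 3, 3))
w_rgba = adapt_first_conv_weights(w_rgb, 4)
rgba_matches = np.allclose(response_at(w_rgba, rgba), response_at(w_rgb, rgba[:3]))

print(gray_matches, rgba_matches)  # True True
```

The 2-channel rule has no such exact identity. My reading is that the 1.5 factor (3/2) compensates for dropping one of the three input channels, but the source comment only says "+ 50%".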