Cool. I checked the source for the `unet_learner`. Apparently, if the number of input channels is smaller than 3, the pretrained weights are combined (summed across channels for 1 input channel, the first two channels scaled for 2); otherwise the first three channels keep the pretrained weights and the weights of the additional channels are initialized with 0. This is much better than just replacing the first layer of the model with a new, randomly initialized ConvLayer.
```python
def _load_pretrained_weights(new_layer, previous_layer):
    "Load pretrained weights based on number of input channels"
    n_in = getattr(new_layer, 'in_channels')
    if n_in == 1:
        # we take the sum
        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
    elif n_in == 2:
        # we take first 2 channels + 50%
        new_layer.weight.data = previous_layer.weight.data[:, :2] * 1.5
    else:
        # keep 3 channels weights and set others to null
        new_layer.weight.data[:, :3] = previous_layer.weight.data
        new_layer.weight.data[:, 3:].zero_()
```
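To see the remapping in action, here is a small sketch using plain `nn.Conv2d` layers as stand-ins for the pretrained stem conv (the layer names and kernel size below are just illustrative, not from the fastai source):

```python
import torch
import torch.nn as nn

def _load_pretrained_weights(new_layer, previous_layer):
    "Load pretrained weights based on number of input channels"
    n_in = getattr(new_layer, 'in_channels')
    if n_in == 1:
        # 1 input channel: sum the pretrained RGB filters
        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
    elif n_in == 2:
        # 2 input channels: keep the first two filters, scaled by 1.5
        new_layer.weight.data = previous_layer.weight.data[:, :2] * 1.5
    else:
        # >= 3 channels: copy the pretrained filters, zero the extras
        new_layer.weight.data[:, :3] = previous_layer.weight.data
        new_layer.weight.data[:, 3:].zero_()

pretrained = nn.Conv2d(3, 64, kernel_size=7)  # stands in for a pretrained stem conv
one_ch = nn.Conv2d(1, 64, kernel_size=7)      # e.g. grayscale input
four_ch = nn.Conv2d(4, 64, kernel_size=7)     # e.g. RGB plus one extra band

_load_pretrained_weights(one_ch, pretrained)
_load_pretrained_weights(four_ch, pretrained)
```

After the calls, `one_ch` holds the channel-summed pretrained filters, and `four_ch` holds the pretrained filters in its first three channels with the fourth zeroed, so at initialization the extra channel contributes nothing and the model starts from the pretrained behavior.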