Modify _load_pretrained_weights for multispectral images

In the _load_pretrained_weights method in fastai v2, there are special cases for adapting the pretrained first-layer weights when the number of input channels is less than 3:

    if n_in==1:
        # we take the sum
        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
    elif n_in==2:
        # we take first 2 channels + 50%
        new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
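
The n_in==1 rule works because a convolution is linear in its input channels. A grayscale image repeated into R, G and B and passed through the RGB kernels gives the same output (up to floating-point rounding) as the single channel passed through the kernels summed over dim=1. The following is a quick check of that identity. It uses a randomly initialised nn.Conv2d as a stand-in for the pretrained first layer, which is an assumption for illustration and not fastai code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pretrained RGB stem (random weights, illustration only)
rgb_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)

# 1-channel layer built with the same rule as the n_in==1 branch
gray_conv = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))

gray = torch.randn(2, 1, 16, 16)          # batch of grayscale images
gray_as_rgb = gray.repeat(1, 3, 1, 1)     # same image copied into R, G and B

with torch.no_grad():
    out_rgb = rgb_conv(gray_as_rgb)       # RGB kernels on replicated gray
    out_gray = gray_conv(gray)            # summed kernels on single channel

print(torch.allclose(out_rgb, out_gray, atol=1e-5))  # True
```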

However, if the number of channels is greater than 3, only the first 3 channels get pretrained weights and all the others are zeroed:

    else:
        # keep 3 channels weights and set others to null
        new_layer.weight.data[:,:3] = previous_layer.weight.data
        new_layer.weight.data[:,3:].zero_()
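
In this default branch the extra channels start out with no influence. Their kernels are zero, so the layer's initial output depends only on the first 3 channels until training updates those weights. The following is a small check. It uses a random stand-in for the pretrained layer and a hypothetical 5-channel input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_in = 5  # hypothetical multispectral input: RGB + 2 extra bands
previous_layer = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
new_layer = nn.Conv2d(n_in, 8, kernel_size=3, padding=1, bias=False)

with torch.no_grad():
    # Same as the default branch: copy the RGB kernels, zero the rest
    new_layer.weight[:, :3] = previous_layer.weight
    new_layer.weight[:, 3:].zero_()

x = torch.randn(2, n_in, 16, 16)
x_other = x.clone()
x_other[:, 3:] = torch.randn(2, n_in - 3, 16, 16)  # change only the extra bands

with torch.no_grad():
    same = torch.allclose(new_layer(x), new_layer(x_other))
print(same)  # True: the extra bands do not affect the initial output
```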

My intuition is that we might gain some benefit by treating the remaining channels as if they were a grayscale image, as in the 1-channel case. Conceptually, it would look something like the code below. The summed (grayscale) weights have shape (out, 1, k, k), and broadcasting in the slice assignment replicates them across all the remaining channels:

    else:
        # keep 3 channels weights and set others to a grayscale equivalent
        new_layer.weight.data[:,:3] = previous_layer.weight.data
        # sum over RGB has shape (out, 1, k, k); broadcasting copies it into every extra channel
        new_layer.weight.data[:,3:] = previous_layer.weight.data.sum(dim=1, keepdim=True)
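
The following is a minimal, self-contained sketch of the proposed initialisation. It is written as a hypothetical helper, adapt_first_conv (not part of fastai), that builds a new first conv from a 3-channel one. It assumes a plain Conv2d stem with groups=1, and the stem in the usage lines is randomly initialised rather than actually pretrained:

```python
import torch
import torch.nn as nn


def adapt_first_conv(previous_layer: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Hypothetical helper: build an n_in-channel copy of a 3-channel conv.

    Channels 0-2 reuse the RGB kernels; every extra channel gets the
    grayscale-equivalent kernel (RGB kernels summed over dim=1).
    """
    assert previous_layer.in_channels == 3 and n_in > 3
    new_layer = nn.Conv2d(
        n_in,
        previous_layer.out_channels,
        kernel_size=previous_layer.kernel_size,
        stride=previous_layer.stride,
        padding=previous_layer.padding,
        dilation=previous_layer.dilation,
        bias=previous_layer.bias is not None,
    )
    w = previous_layer.weight.detach()
    with torch.no_grad():
        new_layer.weight[:, :3] = w
        # (out, 1, kH, kW) broadcasts across all n_in - 3 extra channels
        new_layer.weight[:, 3:] = w.sum(dim=1, keepdim=True)
        if previous_layer.bias is not None:
            new_layer.bias.copy_(previous_layer.bias)
    return new_layer


# Usage with a randomly initialised stand-in for a pretrained stem
torch.manual_seed(0)
stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_stem = adapt_first_conv(stem, n_in=6)
print(new_stem.weight.shape)  # torch.Size([64, 6, 7, 7])
print(torch.allclose(new_stem.weight[:, 4], stem.weight.sum(dim=1)))  # True
```

One design choice this leaves open is scale. Each extra band contributes a full grayscale-sized response, so with many extra bands the first layer's initial activations can grow noticeably. Whether to rescale the extra kernels (for example, by dividing by the number of extra channels) is a separate choice worth testing.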

Has anyone attempted this? Did it help with transfer learning to multispectral data?

Actually, it looks like @Nickelberry has solved this problem with a similar approach and discussed it in this thread. My main question at this point is whether that ought to be the default, i.e., whether it's worth a PR for fastai.
