Train a model with 4 or more input channels using pretrained weights

Hi! I’m trying to train a U-Net with unet_learner on 5-channel input, starting from a model pretrained on 3 input channels.
I only want to change conv1 to:
Conv2d(5, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
and initialize the conv1 weights.
Does anyone know exactly how that is done with fastai?

I have not tried increasing the number of input channels with unet_learner yet, but this works for create_cnn and should give you an idea of how to approach it. I used it extensively in the Kaggle human protein identification challenge, where all the images were RGBY. We are just copying the weights of one of the pretrained channels and applying them to the new channel(s). You can also combine the pretrained channel weights however you like; for example, some people were setting Y.weight to (R.weight + G.weight) / 2.
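To make the weight-copying idea concrete, here is a minimal sketch in plain PyTorch. The helper name expand_conv_in_channels and its source parameter are made up for illustration (they are not part of fastai or PyTorch), and the conv1 below is a randomly initialized stand-in for a pretrained 3-channel first layer. It assumes an ordinary convolution with groups=1.

```python
import torch
import torch.nn as nn


def expand_conv_in_channels(conv, n_in, source=(0,)):
    """Return a new Conv2d with n_in input channels (hypothetical helper).

    The original channels keep their weights. Every extra channel is filled
    with the mean of the channels listed in `source`: (0,) copies R,
    (0, 1) gives (R + G) / 2. Assumes groups=1.
    """
    old_w = conv.weight.detach()              # shape: (out, old_in, kH, kW)
    n_extra = n_in - old_w.shape[1]
    if n_extra < 0:
        raise ValueError("n_in must be >= the original number of input channels")

    # Average the chosen pretrained channels, then repeat for each new channel.
    extra = old_w[:, list(source)].mean(dim=1, keepdim=True)
    extra = extra.repeat(1, n_extra, 1, 1)

    new_conv = nn.Conv2d(n_in, conv.out_channels,
                         kernel_size=conv.kernel_size, stride=conv.stride,
                         padding=conv.padding, bias=conv.bias is not None)
    new_conv.weight = nn.Parameter(torch.cat((old_w, extra), dim=1))
    if conv.bias is not None:
        new_conv.bias = nn.Parameter(conv.bias.detach().clone())
    return new_conv


# Stand-in for a pretrained 3-channel conv1 (random weights here).
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

conv1_5ch = expand_conv_in_channels(conv1, n_in=5)                 # copy R twice
conv1_rgby = expand_conv_in_channels(conv1, n_in=4, source=(0, 1))  # Y = (R + G) / 2

out = conv1_5ch(torch.randn(2, 5, 64, 64))
```

With a real backbone you would pass its first conv layer instead of the stand-in and assign the result back to the same attribute.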

Here is an example with bninception as the backbone. Any backbone will work; just make sure the layer names are correct when you fetch their weights.

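import torch
import torch.nn as nn
from pytorchcv.model_provider import get_model as ptcv_get_model

class BNInception4D(nn.Module):

    def __init__(self, arg=None, num_classes=28):
        super().__init__()
        self.num_classes = num_classes

        enc = ptcv_get_model('bninception', pretrained=True)
        # Pretrained weights of the first conv layer (3 input channels)
        w = enc.features.init_block.conv1.conv.weight

        self.features = enc.features
        # Swap in a first conv layer that takes 4 input channels
        self.features.init_block.conv1.conv = nn.Conv2d(4, 64, kernel_size=7, stride=1, padding=3, bias=False)
        # Keep the RGB weights and reuse the first (R) channel for the 4th channel
        self.features.init_block.conv1.conv.weight = nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
        self.features.final_pool = nn.AvgPool2d(kernel_size=7, stride=1, padding=2)
        self.output = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        x = self.features(x)
        # Flatten to (batch, features) before the linear head
        x = x.view(x.size(0), -1)
        x = self.output(x)
        return x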

I don’t think you can pass the class directly to create_cnn (or, probably, unet_learner), so we make a small callable that returns the model. Something like this:

def bn4d(arg):
    return BNInception4D(num_classes=28)

and give that to the learner:

create_cnn(data, bn4d, etc...)
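If you need several variants (different channel counts or class counts), the bn4d wrapper can be generalized. This is only a sketch: make_arch and TinyNet are hypothetical names, TinyNet is a tiny stand-in for a model like BNInception4D, and it assumes, as bn4d does, that the learner calls the architecture with a single positional argument.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a custom model such as BNInception4D."""

    def __init__(self, n_in=4, num_classes=28):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(n_in, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.output = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.output(x)


def make_arch(model_cls, **kwargs):
    """Return a callable that ignores its one argument and builds the model."""
    def arch(arg=None):
        return model_cls(**kwargs)
    return arch


# Same role as bn4d above, but configurable.
tiny5 = make_arch(TinyNet, n_in=5, num_classes=28)

model = tiny5(True)   # in practice the learner makes this call for you
logits = model(torch.randn(2, 5, 32, 32))
```

You would then hand tiny5 (or make_arch(BNInception4D, num_classes=28)) to the learner in place of bn4d.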

Check out this answer