How to do transfer learning with different inputs

Hi everyone,

I’d like to do transfer learning with a ResNet on my dataset, but I’ve hit an issue: my images have 4 channels and are smaller than 224x224, while most pretrained models I found expect 224x224x3 inputs. Do you know how I could adapt such a pretrained model to my dataset while keeping training time as low as possible?

I was wondering whether I could, for instance, modify the first convolutional layer to accept 4 channels, but I’m a bit lost on how to do that, or even whether it’s a good idea. I’m using PyTorch and fast.ai.

Thanks a lot.


This is really tricky, and I’m not sure there is a single right answer. A few strategies:

  • add a new conv layer in front of the network that maps the 4 input channels to the 3 channels the pretrained model expects
  • replace the first conv layer of the network with one that takes 4 channels (throwing away that layer’s pretrained weights)

In both cases, make sure the new first layer is in the same layer group as the head of the model, since you don’t want it to be frozen.


Another option is to use PCA to project the 4 channels down to 3 and then train a ResNet on the result. I heard about it here: https://twimlai.com/twiml-talk-173-ml-for-understanding-satellite-imagery-at-scale-with-kyle-story/
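A minimal NumPy sketch of that idea, assuming each pixel’s 4 channel values are treated as one sample: fit PCA on the pixels of the training set, then project every image to 3 channels. The function names are hypothetical, and this is not necessarily the exact procedure described in the talk. The 3 resulting channels will not have ImageNet statistics, so they still need normalizing before going into a pretrained ResNet.

```python
import numpy as np


def fit_channel_pca(images: np.ndarray, n_components: int = 3):
    """Fit PCA over the channel axis of images shaped (N, C, H, W)."""
    c = images.shape[1]
    pixels = images.transpose(0, 2, 3, 1).reshape(-1, c)  # one row per pixel
    mean = pixels.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing variance
    _, _, vt = np.linalg.svd(pixels - mean, full_matrices=False)
    return mean, vt[:n_components]


def apply_channel_pca(images: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project (N, C, H, W) images onto the fitted components -> (N, k, H, W)."""
    n, c, h, w = images.shape
    pixels = images.transpose(0, 2, 3, 1).reshape(-1, c)
    projected = (pixels - mean) @ components.T
    return projected.reshape(n, h, w, -1).transpose(0, 3, 1, 2)


rng = np.random.default_rng(0)
train_images = rng.random((8, 4, 32, 32))  # stand-in for a real 4-channel dataset
mean, components = fit_channel_pca(train_images)
three_channel = apply_channel_pca(train_images, mean, components)
```

Fitting on the training set only and reusing `mean`/`components` for validation and test images keeps the projection consistent across splits.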


Check out https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-0-460-public-lb . The author is using fastai 0.7 and defines a custom ConvNet whose first layer is the pretrained ResNet’s first conv extended with weights for an additional input channel, initialized with zeros. You might want to do something similar.


I might be wildly wrong here, but can’t you do something like this:

```python
import torch
import torch.nn as nn
import torchvision


class ResNet4Channel(nn.Module):
    def __init__(self, encoder_depth, num_classes, dropout_2d=0.2, pretrained=False):
        super().__init__()
        self.dropout_2d = dropout_2d  # stored but not used in forward()

        if encoder_depth == 34:
            self.rn = torchvision.models.resnet34(pretrained=pretrained)
            self.bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.rn = torchvision.models.resnet101(pretrained=pretrained)
            self.bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.rn = torchvision.models.resnet152(pretrained=pretrained)
            self.bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only the 34, 101 and 152 versions of ResNet are implemented')

        # New, trainable layer mapping the 4 input channels to the 3 the pretrained stem expects
        self.input_4_3 = ConvBnRelu(4, 3)
        # Pretrained stem (the ResNet's max-pool layer is not used here)
        self.input_adjust = nn.Sequential(self.rn.conv1,
                                          self.rn.bn1,
                                          self.rn.relu)
        self.conv1 = self.rn.layer1
        self.conv2 = self.rn.layer2
        self.conv3 = self.rn.layer3
        self.conv4 = self.rn.layer4
        # Adaptive pooling works for any input size; a fixed AvgPool2d(kernel_size=16)
        # only gives a 1x1 output for 256x256 inputs with the stem above
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.logit_image = nn.Linear(self.bottom_channel_nr, num_classes)

    def forward(self, x):
        batch_size, C, H, W = x.shape
        input_4_3 = self.input_4_3(x)
        input_adjust = self.input_adjust(input_4_3)
        conv1 = self.conv1(input_adjust)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        center = self.conv4(conv3)
        pool = self.pool(center).view(batch_size, self.bottom_channel_nr)
        out = self.logit_image(pool)
        return out


class ConvBnRelu(nn.Module):
    """3x3 conv + BatchNorm + ReLU, used above to map 4 channels to 3."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
```

So you do need to write the forward pass yourself, but that’s easy. Apologies if I’ve got the wrong end of the stick.

Mark

Maybe a silly idea, but you could train 4 ResNet models independently, each on a different set of 3 channels.

Let A, B, C, D be your 4 channels.

Model 1 - Train on A, B, C
Model 2 - Train on B, C, D
Model 3 - Train on C, D, A
Model 4 - Train on D, A, B

For inference, run each input through all 4 models, add up the class probabilities, then renormalise them to get the final list of probabilities.
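A minimal sketch of that inference step, assuming four already-trained models that each take a 3-channel input and return class logits. `CHANNEL_SUBSETS` encodes the A/B/C/D assignment above as channel indices 0 to 3, and the tiny stand-in models are hypothetical placeholders for the trained ResNets. Summing and renormalising is the same as averaging the four probability vectors.

```python
import torch
import torch.nn as nn

# A, B, C, D = channels 0, 1, 2, 3, matching Models 1-4 above
CHANNEL_SUBSETS = [(0, 1, 2), (1, 2, 3), (2, 3, 0), (3, 0, 1)]


@torch.no_grad()
def ensemble_predict(models, x: torch.Tensor) -> torch.Tensor:
    """x: (N, 4, H, W). Returns (N, num_classes) probabilities whose rows sum to 1."""
    total = None
    for model, subset in zip(models, CHANNEL_SUBSETS):
        model.eval()
        probs = torch.softmax(model(x[:, list(subset)]), dim=1)
        total = probs if total is None else total + probs
    return total / total.sum(dim=1, keepdim=True)  # renormalise the summed probabilities


def tiny_model(num_classes: int = 5) -> nn.Module:
    """Hypothetical stand-in for one trained 3-channel ResNet."""
    return nn.Sequential(nn.Conv2d(3, num_classes, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten())


models = [tiny_model() for _ in range(4)]
probs = ensemble_predict(models, torch.randn(2, 4, 16, 16))
```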


Or, even better, remove the heads of these 4 ResNets and combine their outputs with either a CNN or a fully connected network.

See here for a discussion of how to do this kind of stacking of NN models.
