voliv
October 28, 2018, 11:27pm
1
Hi everyone,
I’d like to do transfer learning for a resnet on a dataset, but I hit an issue: my images have 4 channels and are smaller than 224x224. Most pretrained models I found expect 224x224x3 inputs. Do you know how I could adapt such a pretrained model for my dataset so I can minimize training time as much as possible?
I was wondering if I could, for instance, resize the first convolutional layer to accept 4 channels and so on… but I’m a bit lost on how to do that or even if it’s a good idea. I’m using pytorch and fast.ai.
Thanks a lot.
3 Likes
This is really tricky and I’m not sure there is any right answer to that. Several strategies:
add a first conv layer that maps your 4-channel images down to the 3 channels the pretrained model expects
replace the first layer of the convnet with something that takes 4 channels (and throw away that layer’s pretrained weights)
In both those cases, you’ll need to make sure the first layer is in the same layer group as the head of the model, as you don’t want it to be frozen. Rough sketches of both are below.
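Roughly, both could look like this in plain PyTorch (an untested sketch with resnet34 shapes; fastai’s layer-group handling is not shown):

import torch.nn as nn
import torchvision

# Strategy 1: prepend a conv that maps the 4-channel input down to the
# 3 channels the pretrained network expects; the whole resnet keeps its weights
adapter_model = nn.Sequential(nn.Conv2d(4, 3, kernel_size=3, padding=1),
                              torchvision.models.resnet34(pretrained=True))

# Strategy 2: swap the stem conv for a fresh 4-channel one; that layer's
# pretrained weights are discarded but the rest of the network keeps them
model = torchvision.models.resnet34(pretrained=True)
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)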
2 Likes
aakashns
(Aakash N S)
October 29, 2018, 8:34am
4
Check out https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-0-460-public-lb . The author is using fastai 0.7, and defines a custom ConvNet whose first conv layer has an extra input channel added to the pretrained Resnet, initialized with zeros. You might want to do something similar.
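Roughly, that weight surgery might look like this in plain PyTorch (untested sketch; the zero init for the extra channel follows that kernel, everything else is illustrative):

import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet34(pretrained=True)
pretrained_w = model.conv1.weight.clone()      # (64, 3, 7, 7) RGB filters
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight[:, :3] = pretrained_w   # keep the pretrained RGB weights
    model.conv1.weight[:, 3] = 0.              # extra channel initialized with zeros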
4 Likes
maw501
(Mark Worrall)
October 29, 2018, 8:41pm
5
I might be wildly wrong here but can’t you do something like:
import torch
import torch.nn as nn
import torchvision

class ResNet4Channel(nn.Module):
    def __init__(self, encoder_depth, channels_out, dropout_2d=0.2, pretrained=False):
        super().__init__()
        self.dropout_2d = dropout_2d
        if encoder_depth == 34:
            self.rn = torchvision.models.resnet34(pretrained=pretrained)
            self.bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.rn = torchvision.models.resnet101(pretrained=pretrained)
            self.bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.rn = torchvision.models.resnet152(pretrained=pretrained)
            self.bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 versions of Resnet are implemented')
        # 3x3 conv that maps the 4-channel input down to the 3 channels
        # the pretrained stem expects
        self.input_4_3 = ConvBnRelu(4, 3)
        # reuse the pretrained stem (the resnet's maxpool is deliberately left out)
        self.input_adjust = nn.Sequential(self.rn.conv1,
                                          self.rn.bn1,
                                          self.rn.relu)
        self.conv1 = self.rn.layer1
        self.conv2 = self.rn.layer2
        self.conv3 = self.rn.layer3
        self.conv4 = self.rn.layer4
        # assumes a 16x16 feature map here (e.g. 256x256 inputs with the maxpool omitted)
        self.pool = nn.AvgPool2d(kernel_size=16, stride=1, padding=0)
        self.logit_image = nn.Linear(self.bottom_channel_nr, channels_out)

    def forward(self, x):
        batch_size, C, H, W = x.shape
        input_4_3 = self.input_4_3(x)
        input_adjust = self.input_adjust(input_4_3)
        conv1 = self.conv1(input_adjust)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        center = self.conv4(conv3)
        pool = self.pool(center).view(batch_size, self.bottom_channel_nr)
        out = self.logit_image(pool)
        return out

class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
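And a quick (untested) usage sketch; 256x256 inputs and 28 classes are just placeholder assumptions that fit the hard-coded 16x16 pooling:

model = ResNet4Channel(encoder_depth=34, channels_out=28, pretrained=True)
x = torch.randn(2, 4, 256, 256)   # batch of two 4-channel images
print(model(x).shape)             # torch.Size([2, 28])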
So you’ll need to write the forward pass yourself, but that’s easy. Apologies if I’ve got the wrong end of the stick.
Mark
1 Like
maral
November 7, 2018, 7:16am
6
Maybe a silly idea but you could train 4 resnet models independently.
Let A, B, C, D be your 4 channels.
Model 1 - Train on A, B, C
Model 2 - Train on B, C, D
Model 3 - Train on C, D, A
Model 4 - Train on D, A, B
For inference, run your input through all 4 models, add up the class probabilities, then normalise to produce a final list of probabilities.
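Untested sketch of what that inference could look like (the channel orderings and softmax averaging are my assumptions):

import torch.nn.functional as F

# channel triplets seen by the 4 models: (A,B,C), (B,C,D), (C,D,A), (D,A,B)
CHANNEL_SETS = [(0, 1, 2), (1, 2, 3), (2, 3, 0), (3, 0, 1)]

def ensemble_predict(models, x):
    # x: (batch, 4, H, W); models trained on the triplets above, in order
    probs = 0.
    for model, chans in zip(models, CHANNEL_SETS):
        logits = model(x[:, list(chans)])        # select this model's 3 channels
        probs = probs + F.softmax(logits, dim=1)
    return probs / len(models)                   # normalise back to probabilities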
2 Likes
hwasiti
(Haider Alwasiti)
November 16, 2018, 2:02am
7
maral:
Maybe a silly idea but you could train 4 resnet models independently. […]
Or, even better: remove the heads of these 4 resnets and combine their features with either a CNN or a fully connected network.
See here for a discussion of how to do this type of stacking of NN models.
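A minimal sketch of that stacking idea (assuming four headless resnet34 bodies giving 512-d features each; all names are illustrative):

import torch
import torch.nn as nn
import torchvision

class StackedResnets(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # four backbones with their final fc heads removed
        self.bodies = nn.ModuleList(
            nn.Sequential(*list(torchvision.models.resnet34(pretrained=True).children())[:-1])
            for _ in range(4))
        # channel triplets seen by each body, as in the post above
        self.triplets = [(0, 1, 2), (1, 2, 3), (2, 3, 0), (3, 0, 1)]
        # fully connected head over the concatenated 4 x 512 features
        self.head = nn.Linear(4 * 512, num_classes)

    def forward(self, x):
        # x: (batch, 4, H, W); each body sees a different channel triplet
        feats = [body(x[:, list(c)]).flatten(1)
                 for body, c in zip(self.bodies, self.triplets)]
        return self.head(torch.cat(feats, dim=1))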
3 Likes