Unet: how to skip max pool layer in resnet34 encoder (ValueError)

Hi,

TL;DR: I want to use a resnet34 encoder as part of a UNet model but skip its 4th layer, a MaxPool2d near the start of the model, on the downward pass (i.e. without using self.sfs hooks to save the encoder's activations). This keeps the resnet34 channel counts intact but means the spatial size at the bottom of the encoder is now 8x8 rather than 4x4.

i.e. I want to grab the resnet34 layers and define a manual forward pass through the model.
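
To make the shapes concrete, here's a quick sanity check (just a sketch, assuming torchvision's resnet34 and a 128x128 input): index 3 of the flattened children is the MaxPool2d I want to drop, and dropping it doubles the spatial size at the bottom of the encoder.

import torch
import torch.nn as nn
from torchvision.models import resnet34

layers = list(resnet34().children())[:8]   # conv1, bn1, relu, maxpool, layer1..layer4
print(type(layers[3]))                     # <class 'torch.nn.modules.pooling.MaxPool2d'>

x = torch.randn(1, 3, 128, 128)
with_pool = nn.Sequential(*layers)
without_pool = nn.Sequential(*(layers[:3] + layers[4:]))  # same layers, max pool skipped
print(with_pool(x).shape)     # torch.Size([1, 512, 4, 4])
print(without_pool(x).shape)  # torch.Size([1, 512, 8, 8])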

In fastai I am doing:


import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetResNetExample(nn.Module):
    def __init__(self, rn, p=0.2):
        super().__init__()
        bottom_channel_nr, num_filters = 512, 16
        self.rn = rn  # <class 'torch.nn.modules.container.Sequential'>
        self.dropout_2d = p  # used by F.dropout2d in forward
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        # Start grabbing resnet34 layers:
        self.input_adjust = nn.Sequential(self.rn[0],
                                          self.rn[1],
                                          self.rn[2])  # here we want to skip self.rn[3] which is a max pooling layer

        self.conv1 = self.rn[4]
        self.conv2 = self.rn[5]
        self.conv3 = self.rn[6]
        self.conv4 = self.rn[7]

        # Decoder is my own decoder block (defined elsewhere)
        self.dec4 = Decoder(bottom_channel_nr,                        num_filters * 8 * 2, num_filters * 8)
        self.dec3 = Decoder(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec2 = Decoder(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
        self.dec1 = Decoder(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2)
        self.final = nn.Conv2d(num_filters * 2 * 2, 1, kernel_size=1)

    def forward(self, x):
        input_adjust = self.input_adjust(x)
        conv1 = self.conv1(input_adjust)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        center = self.conv4(conv3)
        dec4 = self.dec4(center)
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d)
        return self.final(dec1)

    def close(self):
        # guard: no self.sfs hooks are registered in this version
        for sf in getattr(self, 'sfs', []): sf.remove()

However, by doing this I am getting the following error, which I've been unable to get to the bottom of:


Traceback (most recent call last):

  File "<ipython-input-42-cac343dfa8c8>", line 44, in <module>
    use_clr=None, callbacks=callbacks_list)

  File "/home/maw501/ML/fastai/fastai/learner.py", line 303, in fit
    layer_opt = self.get_layer_opt(lrs, wds)

  File "/home/maw501/ML/fastai/fastai/learner.py", line 275, in get_layer_opt
    return LayerOptimizer(self.opt_fn, self.get_layer_groups(), lrs, wds)

  File "/home/maw501/ML/fastai/fastai/layer_optimizer.py", line 15, in __init__
    self.opt = opt_fn(self.opt_params())

  File "/home/maw501/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/optim/adam.py", line 29, in __init__
    super(Adam, self).__init__(params, defaults)

  File "/home/maw501/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/optim/optimizer.py", line 39, in __init__
    self.add_param_group(param_group)

  File "/home/maw501/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/optim/optimizer.py", line 169, in add_param_group
    raise ValueError("some parameters appear in more than one parameter group")

ValueError: some parameters appear in more than one parameter group
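
For context, this ValueError is easy to reproduce outside fastai: PyTorch's optimizers raise it whenever the same parameter ends up in more than one parameter group. I suspect that's what happens here when the layer groups are built, since the children of self.rn are also stored as self.conv1 to self.conv4, so the same weights can land in two groups. A minimal sketch of the error on its own (plain PyTorch, nothing from my model):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
# passing the same parameters to the optimizer in two groups triggers the error
torch.optim.Adam([
    {'params': layer.parameters()},
    {'params': layer.parameters()},  # same params again
], lr=1e-3)
# ValueError: some parameters appear in more than one parameter group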

Any help is greatly appreciated, and if clarification is required, please just ask.

UPDATE: I have found a workaround for this which I’ll post later.

Thanks,

Mark