Unet with 4-channel input

I would like to use a Unet for images with more than 3 input channels, for example 4.
Currently, I do it like this:

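from fastai.vision import *  # fastai v1; also provides torch, nn, models and unet_learner

input_channel = 4
resnet = models.resnet34()
# Replace the stem so it accepts `input_channel` channels instead of 3
resnet.conv1 = nn.Conv2d(input_channel, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# unet_learner calls arch(pretrained); the lambda ignores that flag and returns the modified resnet
learn = unet_learner(data, lambda x: resnet, wd=wd, last_cross=False)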
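A quick sanity check for the replaced stem, in plain PyTorch and independent of fastai: build the same conv with 4 input channels and push a dummy batch through it. The batch size and the 224×224 image size are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

input_channel = 4
# Same layer assigned to resnet.conv1 above, just built on its own
conv1 = nn.Conv2d(input_channel, 64, kernel_size=(7, 7), stride=(2, 2),
                  padding=(3, 3), bias=False)

x = torch.randn(2, input_channel, 224, 224)  # dummy batch of 4-channel images
out = conv1(x)
print(out.shape)  # torch.Size([2, 64, 112, 112])
```

With stride 2 and padding 3, the spatial size is halved, exactly as with the original 3-channel stem, so nothing downstream of conv1 needs to change.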

Is there any way to make the last cross-connection layer (last_cross=True) accept more channels?


I’m trying to do the same thing you are, except I changed the first layer after calling unet_learner instead of before:

learn.model[0][0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

You might have to run the following afterwards:

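learn.layer_groups[0][0] = learn.model[0][0]  # keep the layer group (used to build the optimizer) pointing at the new layer

learn.model.cuda();  # the new layer is created on the CPU, so move the model back to the GPU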

I’m not sure whether this has other side effects.

Back to your question: if you want to use last_cross=True, you could try changing the corresponding layer to accept the right number of channels.
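One way to "change the corresponding layer" is to rebuild the conv with the new channel counts and copy over whatever old weights still fit. This is a minimal plain-PyTorch sketch: `resize_conv` is a hypothetical helper, not a fastai function, and it assumes an ungrouped Conv2d (groups=1). The 99 and 104 channel counts are taken from the printout and error message later in this thread.

```python
import torch
import torch.nn as nn


def resize_conv(conv: nn.Conv2d, in_ch: int, out_ch: int) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d with the same kernel/stride/padding
    as `conv` but new channel counts. The overlapping slice of the old weights
    (and bias) is copied; extra filters keep PyTorch's default init.
    Assumes groups=1."""
    new = nn.Conv2d(in_ch, out_ch, kernel_size=conv.kernel_size,
                    stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, bias=conv.bias is not None)
    o = min(out_ch, conv.out_channels)
    i = min(in_ch, conv.in_channels)
    with torch.no_grad():  # in-place writes into parameters must not be tracked
        new.weight[:o, :i] = conv.weight[:o, :i]
        if conv.bias is not None:
            new.bias[:o] = conv.bias[:o]
    return new.to(conv.weight.device)  # keep it on the same device as the old layer


# Example: widen a 99 -> 99 conv to 104 -> 104
old = nn.Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
new = resize_conv(old, 104, 104)
y = new(torch.randn(1, 104, 32, 32))
print(y.shape)  # torch.Size([1, 104, 32, 32])
```

Copying the overlapping weights is a design choice: the widened layer starts from the trained filters for the channels it already knew about instead of from scratch.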

Hi Seb,

I’m trying to add an 8-channel input to my unet. I’ve replaced the first Conv2d the same way you did, and I’m getting the following error:

RuntimeError: Given groups=1, weight of size 99 99 3 3, expected input[16, 104, 256, 256] to have 99 channels, but got 104 channels instead

It looks like the error happens at this layer of the network:

(10): SequentialEx(
      (layers): ModuleList(
        (0): Sequential(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
        )
        (1): Sequential(
          (0): Conv2d(99, 99, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace)
        )
        (2): MergeLayer()
      )
    )

Any hints?

Thanks


@kmartyn Did you get your 4-channel input to work?

I figured it out; here’s what I did. I’m not entirely sure why layers 10 and 11 needed to be updated, though.

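learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics,
                     blur=False, norm_type=NormType.Weight)
# Change the input layer of the learner to take 8 channels rather than the normal 3
unet_input_conv = learn.model[0][0]
new_input = nn.Conv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Writing into a parameter in place must happen under torch.no_grad(); otherwise PyTorch
# raises a RuntimeError because the weight is a leaf tensor that requires grad
with torch.no_grad():
    for i in range(3):
        new_input.weight[:, i] = unet_input_conv.weight[:, i]  # reuse the pretrained RGB filters
    for i in range(3, 8):
        new_input.weight[:, i] = unet_input_conv.weight[:, 2]  # extra channels copy the last RGB filter

learn.model[0][0] = new_input
learn.layer_groups[0][0] = learn.model[0][0]
# The last cross connection concatenates the 8-channel input, so 99 + 5 = 104 channels arrive here
# (these new modules may also need to be put into learn.layer_groups so the optimizer sees them)
learn.model[10][0][0] = nn.Conv2d(104, 104, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
learn.model[10][1][0] = nn.Conv2d(104, 104, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
learn.model[11][0] = nn.Conv2d(104, 2, kernel_size=(1, 1), stride=(1, 1))
learn.model.cuda();  # the replaced layers were created on the CPU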

Unet has long skip connections; see Fig. 1 in the original paper (https://arxiv.org/pdf/1505.04597.pdf).

So if you add input channels, you also get extra channels that are concatenated into one of the final layers. With 5 extra input channels, layer 10 seems to receive 104 channels instead of 99.
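The arithmetic can be checked with random tensors: the top skip connection concatenates the raw input with the decoder output along the channel dimension, so every extra input channel shows up in the merged tensor. In this sketch the 96-channel decoder output is an assumption inferred from the printout (99 - 3), not read from the fastai source, and the 64×64 size is arbitrary.

```python
import torch

# Assumption: 96 decoder channels, inferred from 99 total with a 3-channel input
decoder_channels = 99 - 3

merged_channels = {}
for n_in in (3, 8):
    x = torch.randn(1, n_in, 64, 64)               # network input, kept for the last cross
    up = torch.randn(1, decoder_channels, 64, 64)  # decoder output at full resolution
    merged = torch.cat([up, x], dim=1)             # skip connection concatenates channels
    merged_channels[n_in] = merged.shape[1]

print(merged_channels)  # {3: 99, 8: 104}
```

This is why both the convs in layer 10 and the final 1×1 conv in layer 11 had to be rebuilt for 104 input channels.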

You could rework the architecture as you did, or use last_cross=False to remove that top skip connection.

If you want to train from scratch, I also have an implementation with a 4-channel input in my repo: https://github.com/sdoria/SID


Also, if you are overwriting layers, double check the new layers are still frozen or not, and that they are properly initialized.

Hey Seb, what do you mean by “double check the new layers are still frozen or not”? How can I do that?
I’m extending the Unet to 6 input channels and trying to make sure all the layers are trainable.

My understanding is that when you create a unet learner with a pretrained architecture, part of that arch will be frozen, and you only train the last several layers (unless you unfreeze).

Now if you overwrite some of the earlier layers, I am guessing that those overwritten layers might not be frozen by default.
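A small plain-PyTorch illustration of that guess: a freshly constructed layer has requires_grad=True, so it stays trainable even when the layer it replaces was frozen. The toy `nn.Sequential` is a hypothetical stand-in for the pretrained body, and "freezing" here just means turning gradients off. In fastai v1 the optimizer is, as far as I can tell, built from `learn.layer_groups`, which is another reason to keep those in sync with replaced modules.

```python
import torch.nn as nn

# Toy stand-in for a pretrained body
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
)

# "Freeze" everything by turning off gradients
for p in model.parameters():
    p.requires_grad_(False)

# Overwrite the first layer, as in the posts above
model[0] = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

# List what is trainable now
for name, p in model.named_parameters():
    print(name, p.requires_grad)
# 0.weight True
# 2.weight False
# 2.bias False
```

Only the replaced layer is trainable; calling the freeze/unfreeze methods again afterwards is one way to bring everything back to a consistent state.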

It looks like learn.summary() will tell you which layers are trainable:
https://docs.fast.ai/basic_train.html#Learner.freeze

If needed, learn.freeze() should refreeze every layer group except the last one, and learn.unfreeze() will unfreeze everything.

Hey @brian, can you please explain this line? Shouldn’t it read unet_input_conv.weight[:,i] instead of unet_input_conv.weight[:,2]?

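The unet_input_conv only has 3 input channels, and I needed 8 for my task. So I created a new conv layer with 8 input channels and copied in the weights from the original. For the 5 new channels, I just reused the weights of the last original channel (index 2); unet_input_conv.weight[:,i] would fail for i ≥ 3 because the original layer has no channel 3 or higher.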
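The weight copy described above can be written as a standalone helper. This is a plain-PyTorch sketch with a hypothetical name, `expand_input_conv`: the original input channels keep their pretrained filters and each extra channel gets a copy of the last one. That matches the choice described in this post; averaging the RGB filters is another option people use.

```python
import torch
import torch.nn as nn


def expand_input_conv(old: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Hypothetical helper: build a Conv2d with `n_in` input channels from `old`.
    Channels 0..old.in_channels-1 copy the old filters; each extra channel
    repeats the old layer's last input channel."""
    new = nn.Conv2d(n_in, old.out_channels, kernel_size=old.kernel_size,
                    stride=old.stride, padding=old.padding,
                    bias=old.bias is not None)
    n_old = old.in_channels
    with torch.no_grad():
        new.weight[:, :n_old] = old.weight
        # Slice n_old-1:n_old keeps a size-1 channel dim that broadcasts over the extras
        new.weight[:, n_old:] = old.weight[:, n_old - 1:n_old]
        if old.bias is not None:
            new.bias.copy_(old.bias)
    return new


old = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
new = expand_input_conv(old, 8)
print(new.weight.shape)  # torch.Size([64, 8, 7, 7])
```

Because the copy runs under torch.no_grad(), the new weight stays an ordinary trainable parameter and no extra nn.Parameter wrapping is needed.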

A cleaner implementation: link