So if you add channels to the input, those extra channels get concatenated to one of the final layers through the top skip connection. With 5 extra input channels, you get 104 instead of 99 in layer 10, it seems.
You could rework the architecture like you did, or maybe pass last_cross=False to remove that top skip connection.
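In fastai v1 that would look something like this (a sketch, assuming data is your DataBunch with the extra-channel inputs):

from fastai.vision import *

# Sketch (fastai v1): last_cross=False drops the top skip connection,
# so the raw input channels are no longer concatenated near the output.
learn = unet_learner(data, models.resnet34, last_cross=False)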
If you want to train from scratch, I also have an implementation with a 4-channel input in my repo: https://github.com/sdoria/SID
Hey Seb, what do you mean by “double-check whether the new layers are still frozen or not”? How can I do that?
I’m working on extending the U-Net to 6 channels and trying to make sure all the layers are trainable.
My understanding is that when you create a unet learner with a pretrained architecture, part of that arch will be frozen, and you only train the last several layers (unless you unfreeze).
Now if you overwrite some of the earlier layers, I’m guessing the replacement layers might not be frozen by default.
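One quick way to check is to print requires_grad for each parameter; anything showing False is frozen (a minimal sketch, assuming learn is your unet learner):

# Parameters with requires_grad=False are frozen and won't be updated.
for name, p in learn.model.named_parameters():
    print(name, p.requires_grad)

After swapping in new layers, you can confirm they show up as trainable here, and then call learn.freeze() or learn.unfreeze() to get the state you want.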
The unet_input_conv only has 3 input channels, but I needed 8 for my task. So I created a new conv with 8 input channels and copied the weights in from the original. For the 5 new channels, I just copied the weights of the last original channel.
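In code, that weight copy looks roughly like this (a sketch with hypothetical indexing; learn.model[0][0] is assumed to be the 3-channel input conv of a resnet-based unet, so adjust for your arch):

import torch
import torch.nn as nn

first_conv = learn.model[0][0]  # assumed location of the original 3-channel conv
new_conv = nn.Conv2d(8, first_conv.out_channels,
                     kernel_size=first_conv.kernel_size,
                     stride=first_conv.stride,
                     padding=first_conv.padding,
                     bias=first_conv.bias is not None)
with torch.no_grad():
    new_conv.weight[:, :3] = first_conv.weight           # reuse the original 3 channels
    for i in range(3, 8):
        new_conv.weight[:, i] = first_conv.weight[:, 2]  # copy the last original channel into each new one
    if first_conv.bias is not None:
        new_conv.bias.copy_(first_conv.bias)
learn.model[0][0] = new_conv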
I’m not sure if anyone is still wondering about this, but it is very easy in the new version of fastai (I’m on version 2.7.7). When you call unet_learner, you can simply pass the argument n_in, so in this case you would add:
n_in = 4
But if you want your script to be flexible about the number of input channels, you could use something like
n_in=dsts[0][0].shape[0]
so that the model trains with however many channels are in your dataset.
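Putting it together, the call would look something like this (a sketch for fastai 2.7.x; dls and dsts are assumed to be your DataLoaders and the datasets they were built from):

from fastai.vision.all import *

n_channels = dsts[0][0].shape[0]          # channel count of the first input tensor
learn = unet_learner(dls, resnet34, n_in=n_channels)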