UNet with N-Channeled input

I have created a GitHub repo with a modified unet_learner function that creates a U-Net taking n-channel input; here is the link to that repository. Check it out: there is a notebook explaining how I achieved it. It's pretty simple :smile:.
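The repository's code isn't reproduced in this thread, but the core idea can be sketched in plain PyTorch: swap the encoder's first convolution for one that accepts n channels while keeping its other hyperparameters. This is a minimal sketch that assumes the encoder starts with an nn.Conv2d (as torchvision ResNets do). `adapt_first_conv`, the toy `encoder` and the 5-channel input are hypothetical stand-ins, not the repo's unet_learner code.

```python
import torch
import torch.nn as nn


def adapt_first_conv(conv: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Hypothetical helper: return a fresh Conv2d taking `n_in` input channels.

    Output channels, kernel size, stride, padding, dilation and bias are
    copied from `conv`; the weights get PyTorch's default (random) Conv2d init.
    """
    return nn.Conv2d(
        n_in,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )


# Toy stand-in for a pretrained encoder stem (ResNet-like 7x7 conv, stride 2).
encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

n_input_channels = 5  # assumption: e.g. RGB plus two extra bands
encoder[0] = adapt_first_conv(encoder[0], n_input_channels)

x = torch.randn(2, n_input_channels, 64, 64)
out = encoder(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```

The replacement layer starts with random weights, so it carries no pretrained information until it is trained or initialized from the old layer's weights.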

@jeremy I think it might be useful to include this in fastai-v2.

Shouldn’t the line:

if pretrained and n_input_channels != 3: learn.freeze()

be:

if pretrained and n_input_channels == 3: learn.freeze()

i.e., it should only be frozen if the number of channels is unchanged, as you describe in the docs.
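The corrected condition can be expressed as a tiny standalone check. `should_freeze` is a hypothetical helper for illustration only; the original code inlines the condition in front of learn.freeze(). The reasoning in the comments assumes, as with fastai's freeze(), that freezing leaves the encoder (including its first conv) untrainable.

```python
def should_freeze(pretrained: bool, n_input_channels: int) -> bool:
    # Hypothetical helper mirroring the corrected condition.
    # With 3 channels the encoder's first conv keeps its pretrained weights,
    # so freezing the encoder and training the rest first is reasonable.
    # With any other channel count the first conv is newly created (random),
    # and freezing the encoder would keep that layer from being trained.
    return pretrained and n_input_channels == 3


for pretrained, n in [(True, 3), (True, 4), (False, 3)]:
    print(f"pretrained={pretrained}, channels={n} -> freeze={should_freeze(pretrained, n)}")
```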
You could probably also copy over the first channel of the existing weights; I've done this successfully in other cases of adapting channel weights. Something like new_conv.weight.data = old_conv.weight.data[:,0:1,...] (the weight shape is (out_channels, in_channels, *kernel_size)). Otherwise you probably at least want to initialise them with something like apply_init(body[0], nn.init.kaiming_normal_).
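A sketch of the weight-transfer idea in plain PyTorch. The conv attribute is `weight` (singular), and because it is an nn.Parameter the copy is done in place under torch.no_grad() rather than by reassigning it. `init_new_conv` is a hypothetical helper; repeating channel 0 for every new input channel is one reasonable reading of "copy over the first channel", and nn.init.kaiming_normal_ is applied directly to the weight as a stand-in for fastai's apply_init.

```python
import torch
import torch.nn as nn


def init_new_conv(new_conv: nn.Conv2d, old_conv: nn.Conv2d,
                  copy_first_channel: bool = True) -> None:
    """Hypothetical helper: initialise `new_conv` (n input channels) from a
    pretrained `old_conv`. Weight shape is (out_channels, in_channels, *kernel_size)."""
    with torch.no_grad():
        if copy_first_channel:
            n_in = new_conv.in_channels
            # Pretrained filters for input channel 0: shape (out, 1, kh, kw),
            # repeated along the channel axis to (out, n_in, kh, kw).
            new_conv.weight.copy_(old_conv.weight[:, 0:1].repeat(1, n_in, 1, 1))
        else:
            # Fallback: Kaiming (He) normal initialisation.
            nn.init.kaiming_normal_(new_conv.weight)
        if old_conv.bias is not None and new_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)


old = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # "pretrained" stem
new = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 5-channel replacement
init_new_conv(new, old)
```

After the call, every input channel of `new` holds a copy of the filters `old` used for its first input channel, and the weight is still a trainable Parameter.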

Yeah, fixed it. Thanks!