Image Channels

At the end of lesson 3, Jeremy says that pre-trained ImageNet models expect 3-channel (RGB) images. But I have used resnet34 with 1-channel datasets (MNIST, for example) and got above 99% accuracy. So I can’t understand this contradiction.

Does fastai automatically create more channels?
Or have I got something wrong here?

How were you building its DataBunch? Were you normalizing with ImageNet stats? That will turn the images into 3 channels. (You can also modify resnet’s input layer to take one channel instead of 3.)
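For anyone wondering how normalization could add channels: a minimal sketch in plain PyTorch, assuming the per-channel stats are broadcast against the image batch. This illustrates the broadcasting mechanism only and is not a copy of fastai's internals; the batch `x` is a made-up example.

```python
import torch

# ImageNet per-channel mean and std (the values fastai exposes as imagenet_stats).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Hypothetical batch of single-channel images: (batch, channels, height, width).
x = torch.rand(4, 1, 28, 28)

# Shape (4, 1, 28, 28) against shape (3, 1, 1) broadcasts to (4, 3, 28, 28),
# so the single channel is repeated once per set of stats.
x_norm = (x - mean[:, None, None]) / std[:, None, None]
out_shape = tuple(x_norm.shape)
print(out_shape)  # (4, 3, 28, 28)
```

Each of the three output channels is the same image shifted and scaled with a different mean and std, so a 3-channel pretrained model accepts it without complaint.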

Yes. I used imagenet_stats to normalize the databunch.

How can I modify it?

Thank You.

The simplest way is to do something like so:

body = create_body(resnet34, pretrained=True)
l = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2),
                    padding=(3,3), bias=False)
l.weight = nn.Parameter(l.weight.sum(dim=1, keepdim=True))
body[0] = l

So essentially we make the first layer a single-channel Conv2d and adjust its weights so that we can still use the pretrained weights.

Thank You

Hi Zachary. IMHO, your verbal description is exactly right, but the code does not do what you describe.

Yes, I believe I’m missing the use of the weights from the body itself, IIRC? (Sorry, I’m jumping all over the place, so it was a quick answer. Also, tell me if you think that’s what I missed :slight_smile: )

Right. You need to take the sum across channels of the pretrained first layer weights and initialize those values into the new first layer. If my understanding is right, this will have the same effect as sending the same b/w image to all three channels of the original resnet.
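That equivalence can be checked numerically. Below is a minimal sketch, assuming a randomly initialized 3-channel conv as a stand-in for the pretrained first layer (no real ImageNet weights are loaded) and a made-up grayscale batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the pretrained first layer (random weights, NOT real ImageNet weights).
conv3 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Single-channel layer whose filters are the sum of the 3-channel filters
# over the input-channel dimension: (64, 3, 7, 7) -> (64, 1, 7, 7).
conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    conv1.weight.copy_(conv3.weight.sum(dim=1, keepdim=True))

gray = torch.randn(2, 1, 28, 28)    # hypothetical grayscale batch
rgb = gray.expand(-1, 3, -1, -1)    # the same image copied into all three channels

with torch.no_grad():
    out1 = conv1(gray)
    out3 = conv3(rgb)

# Convolution is linear in its weights, so the two outputs agree up to float error.
same = torch.allclose(out1, out3, atol=1e-4)
print(tuple(out1.shape), same)
```

Note that the match is exact only when all three input channels are identical. If the channels are normalized with different per-channel statistics first, the outputs will differ somewhat.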

Thanks for solving so many problems presented on these forums. But I won’t attempt to correct the code because I’d probably make a mistake!

That’s okay! We all make them :slight_smile: Let me try to correct the above. I believe it should be something like so:

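w = body[0]  # pretrained first layer; check that it is a Conv2d with 3 input and 64 output channels
l = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2),
                    padding=(3,3), bias=False)
# Sum the pretrained filters over the input-channel dim: (64, 3, 7, 7) -> (64, 1, 7, 7)
l.weight = nn.Parameter(w.weight.sum(dim=1, keepdim=True))
body[0] = l  # swap the single-channel layer into the body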
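If it helps, the same idea can be wrapped in a small helper that copies the layer's hyperparameters instead of hard-coding them. This is only a sketch: `to_single_channel` is a hypothetical name, not a fastai or PyTorch function, it assumes default dilation and groups, and the demo layer uses random weights as a stand-in for `body[0]`.

```python
import torch
import torch.nn as nn

def to_single_channel(conv: nn.Conv2d) -> nn.Conv2d:
    """Return a 1-input-channel copy of `conv` with channel-summed filters.

    Hypothetical helper; assumes default dilation and groups.
    """
    new_conv = nn.Conv2d(1, conv.out_channels,
                         kernel_size=conv.kernel_size, stride=conv.stride,
                         padding=conv.padding, bias=conv.bias is not None)
    with torch.no_grad():
        # Collapse the input-channel dimension, keeping it as size 1.
        new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# Stand-in for body[0]; in the thread this would be the pretrained resnet34 first layer.
first = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
single = to_single_channel(first)
weight_shape = tuple(single.weight.shape)
print(weight_shape)  # (64, 1, 7, 7)
```

With a real body, the swap would presumably be `body[0] = to_single_channel(body[0])`, assuming `body[0]` is the first conv layer as in the snippets above.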