What is 'n_in' actually doing?

Hi All :wave:

Recently I have been playing around with 6-band/channel imagery for image classification. I have my dataloader and augmentations working as expected and I’m able to train a model with pretrained weights. My learner looks like this:

learn = cnn_learner(dl, resnet18, n_in=6, pretrained=True, metrics=error_rate).to_fp16()

Somewhat frustratingly, I’m not getting any better performance than when I was just using 3-band imagery, which got me wondering: what is ‘n_in’ actually doing? Is it duplicating the weights from one of the pretrained bands (RGB), or is it initialised with random weights?

I have done some searching but so far I haven’t found much useful info on ‘n_in’.

I think this is important to know, because if it’s starting from random weights I imagine I will need to unfreeze and fit the entire model, right?

For context, my 6-band imagery is two different satellite RGB images of the same location merged into one image/tensor.

Any help would be appreciated; I’m pretty lost here.

Cheers

Hi Nick. It’s me again. In a case like this the best course (IMHO) is to read the code. You can do this from the fastai docs. Even better is to put a debugger, such as PyCharm or VSCode, on cnn_learner and trace to see exactly what is happening.

A quick glance shows that the extra three channel weights are set to zeros. So yes you would need to train them. But you certainly do not have to do it this way. A valid alternative is to initialize the new 3 to the pretrained RGB values and divide all weights by 2. You can let fastai create the Learner, and simply alter the model to your liking inside your notebook, using the fastai code as a template.
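For example, a quick check along these lines should confirm it (assuming the usual fastai layout, where the first Conv2d of the body sits at learn.model[0][0]):

first_conv = learn.model[0][0]                # first Conv2d of the pretrained body
print(first_conv.weight.shape)                # torch.Size([64, 6, 7, 7]) for resnet18 with n_in=6
print(first_conv.weight[:, 3:].abs().sum())   # expect tensor(0.) if the extra channels start out as zeros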

But my point is that the best way to answer questions about what n_in actually does is to look at the code!
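In a notebook, the ?? suffix will show you the source of any function, which is usually the quickest way in:

from fastai.vision.all import *
cnn_learner??   # prints the source; you can drill into the helpers it calls the same way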

HTH, :slightly_smiling_face:

P.S. I understand that getting into the fastai code is intimidating. I can only follow parts of it even after a couple of years. But given the level of the questions you want to answer, there does not seem to be any better way.

2 Likes

Hi Malcolm,
As always thanks for chiming in! I have taken a look at the code and I think I have a solution by modifying ‘_load_pretrained_weights’ but I’m not sure about the best approach to implement my alteration. Could I just ‘@patch’ the function?

Or am I thinking about this wrong? You seem to be suggesting I can edit the model in the notebook; how do I do that?

Thanks for the help :slight_smile:

Hi. There are many ways to go about it.

To be honest, I never use @patch. I have many times been thrown off when a well-known Python function behaves differently than documented because fastai has silently modified it behind the scenes. So I have a bias against monkey patching as a general practice, and personally consider it a questionable programming pattern. It makes code concise but at the cost of confusion for whoever follows. YMMV of course.

In this case, I’d create learn with cnn_learner. Then you can get to learn.model, find the first layer’s weights, and modify them. The first three channels will have the pretrained weights, and the rest will be zeros. Convolutions add up the channels, so you will need to scale the weights by half to use six channels instead of three. Adjust weights as you wish and forge ahead. This is just my own style - you may prefer a different approach.

BTW, there are several forum posts that demonstrate how to edit an existing model.

:slightly_smiling_face:

Thanks Malcolm,
Ok, that sounds reasonable; I wasn’t having much luck trying to use @patch anyway, haha :laughing:.

I have been having a look around this forum and the PyTorch forum for examples of editing models, but I haven’t had much luck finding exactly what I’m after. Most people seem to be interested in changing a model’s structure and not in duplicating weights.

However, I have made some progress. After you pointed me in the right direction I realised I could see the weights of the first layer with

list(learn.model.parameters())[0]

and if I zoom in further with

list(learn.model.parameters())[0][0]

I get this

tensor([[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  7.4841e-02,  5.6615e-02,
           1.7083e-02, -1.2694e-02],
         [ 1.1083e-02,  9.5276e-03, -1.0993e-01, -2.8050e-01, -2.7124e-01,
          -1.2907e-01,  3.7424e-03],
         [-6.9434e-03,  5.9089e-02,  2.9548e-01,  5.8720e-01,  5.1972e-01,
           2.5632e-01,  6.3573e-02],
         [ 3.0505e-02, -6.7018e-02, -2.9841e-01, -4.3868e-01, -2.7085e-01,
          -6.1282e-04,  5.7602e-02],
         [-2.7535e-02,  1.6045e-02,  7.2595e-02, -5.4102e-02, -3.3285e-01,
          -4.2058e-01, -2.5781e-01],
         [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  2.3897e-01,  4.1384e-01,
           3.9359e-01,  1.6606e-01],
         [-1.3736e-02, -3.6746e-03, -2.4084e-02, -6.5877e-02, -1.5070e-01,
          -8.2230e-02, -5.7828e-03]],

        [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  3.6812e-02,  3.2521e-02,
           6.6221e-04, -2.5743e-02],
         [ 4.5687e-02,  3.3603e-02, -1.0453e-01, -3.0885e-01, -3.1253e-01,
          -1.6051e-01, -1.2826e-03],
         [-8.3730e-04,  9.8420e-02,  4.0210e-01,  7.7035e-01,  7.0789e-01,
           3.6887e-01,  1.2455e-01],
         [-5.8427e-03, -1.2862e-01, -4.2071e-01, -5.9270e-01, -3.8285e-01,
          -4.2407e-02,  6.1568e-02],
         [-5.5926e-02, -5.2239e-03,  2.7081e-02, -1.5159e-01, -4.6178e-01,
          -5.7080e-01, -3.6552e-01],
         [ 3.2860e-02,  5.5574e-02,  9.9670e-02,  3.1815e-01,  5.4636e-01,
           4.8276e-01,  1.9867e-01],
         [ 5.3051e-03,  6.6938e-03, -1.7254e-02, -6.9806e-02, -1.4822e-01,
          -7.7248e-02,  7.2183e-04]],

        [[-2.0315e-03, -9.1617e-03,  2.1209e-02,  8.9755e-02,  8.9177e-02,
           3.3655e-02, -2.0102e-02],
         [ 1.5398e-02, -1.8648e-02, -1.2591e-01, -2.9553e-01, -2.5342e-01,
          -1.2980e-01, -2.7975e-02],
         [ 9.8454e-03,  4.9047e-02,  2.1699e-01,  4.3010e-01,  3.4872e-01,
           1.0433e-01,  1.8413e-02],
         [ 2.6426e-02, -2.5990e-02, -1.9699e-01, -2.6806e-01, -1.0524e-01,
           7.8577e-02,  1.2077e-01],
         [-2.8356e-02,  1.8404e-02,  9.8647e-02,  6.1242e-02, -1.1740e-01,
          -2.5760e-01, -1.5451e-01],
         [ 2.0766e-02, -2.6286e-03, -3.7825e-02,  5.7450e-02,  2.4141e-01,
           2.4345e-01,  1.1796e-01],
         [ 7.4684e-04,  7.7677e-04, -1.0050e-02, -5.5153e-02, -1.4865e-01,
          -1.1754e-01, -3.8350e-02]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])

I see we have 6 tensors (one for each input band): the first 3 are filled with numbers (I assume these are the pretrained weights) and the last 3 are all zeros. This is the case with all the neighbouring tensors as well (list(learn.model.parameters())[0][1], list(learn.model.parameters())[0][2], etc.).

So I figured I need to copy the values from the first 3 tensors into the last 3 tensors while also halving all the values. I believe I have done this with the code below.

# grab a copy of the weights
weights = list(learn.model.parameters())

# total number of input bands
band_count = 6

# the pretrained bands (RGB)
pretrained_bands = [0, 1, 2]

# the amount to divide each band by, to compensate for the additional bands
div_amount = band_count / len(pretrained_bands)

# all bands beyond RGB
untrained_bands = range(len(pretrained_bands), band_count)

# only edit index [0] of the weights; this is all we need to edit (I think...)
for band in weights[0]:

    # reduce all RGB bands by the divide amount
    for i in pretrained_bands:
        band[i] = band[i] / div_amount

    # loop over the 'untrained_bands', assigning them the values of R, then G, then B
    for pretrained, untrained in enumerate(untrained_bands):
        pre = pretrained_bands[pretrained % len(pretrained_bands)]
        band[untrained] = band[pre]

This appears to work as now when I call

list(learn.model.parameters())[0][0]

I get this

tensor([[[-5.2097e-03, -3.0678e-03, -9.0489e-04,  3.7421e-02,  2.8307e-02,
           8.5417e-03, -6.3469e-03],
         [ 5.5414e-03,  4.7638e-03, -5.4963e-02, -1.4025e-01, -1.3562e-01,
          -6.4537e-02,  1.8712e-03],
         [-3.4717e-03,  2.9544e-02,  1.4774e-01,  2.9360e-01,  2.5986e-01,
           1.2816e-01,  3.1786e-02],
         [ 1.5252e-02, -3.3509e-02, -1.4921e-01, -2.1934e-01, -1.3543e-01,
          -3.0641e-04,  2.8801e-02],
         [-1.3767e-02,  8.0225e-03,  3.6297e-02, -2.7051e-02, -1.6642e-01,
          -2.1029e-01, -1.2891e-01],
         [ 1.5306e-02,  2.0480e-02,  3.1425e-02,  1.1949e-01,  2.0692e-01,
           1.9679e-01,  8.3029e-02],
         [-6.8681e-03, -1.8373e-03, -1.2042e-02, -3.2939e-02, -7.5350e-02,
          -4.1115e-02, -2.8914e-03]],

        [[-5.6986e-03, -1.3310e-02, -1.7320e-02,  1.8406e-02,  1.6260e-02,
           3.3110e-04, -1.2872e-02],
         [ 2.2843e-02,  1.6802e-02, -5.2264e-02, -1.5443e-01, -1.5626e-01,
          -8.0253e-02, -6.4128e-04],
         [-4.1865e-04,  4.9210e-02,  2.0105e-01,  3.8517e-01,  3.5394e-01,
           1.8444e-01,  6.2277e-02],
         [-2.9214e-03, -6.4308e-02, -2.1035e-01, -2.9635e-01, -1.9142e-01,
          -2.1203e-02,  3.0784e-02],
         [-2.7963e-02, -2.6120e-03,  1.3540e-02, -7.5793e-02, -2.3089e-01,
          -2.8540e-01, -1.8276e-01],
         [ 1.6430e-02,  2.7787e-02,  4.9835e-02,  1.5907e-01,  2.7318e-01,
           2.4138e-01,  9.9337e-02],
         [ 2.6526e-03,  3.3469e-03, -8.6272e-03, -3.4903e-02, -7.4112e-02,
          -3.8624e-02,  3.6091e-04]],

        [[-1.0158e-03, -4.5808e-03,  1.0605e-02,  4.4878e-02,  4.4588e-02,
           1.6828e-02, -1.0051e-02],
         [ 7.6991e-03, -9.3240e-03, -6.2954e-02, -1.4776e-01, -1.2671e-01,
          -6.4900e-02, -1.3988e-02],
         [ 4.9227e-03,  2.4523e-02,  1.0850e-01,  2.1505e-01,  1.7436e-01,
           5.2166e-02,  9.2064e-03],
         [ 1.3213e-02, -1.2995e-02, -9.8493e-02, -1.3403e-01, -5.2620e-02,
           3.9289e-02,  6.0387e-02],
         [-1.4178e-02,  9.2021e-03,  4.9324e-02,  3.0621e-02, -5.8700e-02,
          -1.2880e-01, -7.7253e-02],
         [ 1.0383e-02, -1.3143e-03, -1.8912e-02,  2.8725e-02,  1.2071e-01,
           1.2172e-01,  5.8979e-02],
         [ 3.7342e-04,  3.8839e-04, -5.0251e-03, -2.7576e-02, -7.4325e-02,
          -5.8768e-02, -1.9175e-02]],

        [[-5.2097e-03, -3.0678e-03, -9.0489e-04,  3.7421e-02,  2.8307e-02,
           8.5417e-03, -6.3469e-03],
         [ 5.5414e-03,  4.7638e-03, -5.4963e-02, -1.4025e-01, -1.3562e-01,
          -6.4537e-02,  1.8712e-03],
         [-3.4717e-03,  2.9544e-02,  1.4774e-01,  2.9360e-01,  2.5986e-01,
           1.2816e-01,  3.1786e-02],
         [ 1.5252e-02, -3.3509e-02, -1.4921e-01, -2.1934e-01, -1.3543e-01,
          -3.0641e-04,  2.8801e-02],
         [-1.3767e-02,  8.0225e-03,  3.6297e-02, -2.7051e-02, -1.6642e-01,
          -2.1029e-01, -1.2891e-01],
         [ 1.5306e-02,  2.0480e-02,  3.1425e-02,  1.1949e-01,  2.0692e-01,
           1.9679e-01,  8.3029e-02],
         [-6.8681e-03, -1.8373e-03, -1.2042e-02, -3.2939e-02, -7.5350e-02,
          -4.1115e-02, -2.8914e-03]],

        [[-5.6986e-03, -1.3310e-02, -1.7320e-02,  1.8406e-02,  1.6260e-02,
           3.3110e-04, -1.2872e-02],
         [ 2.2843e-02,  1.6802e-02, -5.2264e-02, -1.5443e-01, -1.5626e-01,
          -8.0253e-02, -6.4128e-04],
         [-4.1865e-04,  4.9210e-02,  2.0105e-01,  3.8517e-01,  3.5394e-01,
           1.8444e-01,  6.2277e-02],
         [-2.9214e-03, -6.4308e-02, -2.1035e-01, -2.9635e-01, -1.9142e-01,
          -2.1203e-02,  3.0784e-02],
         [-2.7963e-02, -2.6120e-03,  1.3540e-02, -7.5793e-02, -2.3089e-01,
          -2.8540e-01, -1.8276e-01],
         [ 1.6430e-02,  2.7787e-02,  4.9835e-02,  1.5907e-01,  2.7318e-01,
           2.4138e-01,  9.9337e-02],
         [ 2.6526e-03,  3.3469e-03, -8.6272e-03, -3.4903e-02, -7.4112e-02,
          -3.8624e-02,  3.6091e-04]],

        [[-1.0158e-03, -4.5808e-03,  1.0605e-02,  4.4878e-02,  4.4588e-02,
           1.6828e-02, -1.0051e-02],
         [ 7.6991e-03, -9.3240e-03, -6.2954e-02, -1.4776e-01, -1.2671e-01,
          -6.4900e-02, -1.3988e-02],
         [ 4.9227e-03,  2.4523e-02,  1.0850e-01,  2.1505e-01,  1.7436e-01,
           5.2166e-02,  9.2064e-03],
         [ 1.3213e-02, -1.2995e-02, -9.8493e-02, -1.3403e-01, -5.2620e-02,
           3.9289e-02,  6.0387e-02],
         [-1.4178e-02,  9.2021e-03,  4.9324e-02,  3.0621e-02, -5.8700e-02,
          -1.2880e-01, -7.7253e-02],
         [ 1.0383e-02, -1.3143e-03, -1.8912e-02,  2.8725e-02,  1.2071e-01,
           1.2172e-01,  5.8979e-02],
         [ 3.7342e-04,  3.8839e-04, -5.0251e-03, -2.7576e-02, -7.4325e-02,
          -5.8768e-02, -1.9175e-02]]])

Which looks correct (I think), but I still have a question:

How is learn.model.parameters() getting a copy of my edited weights? I had expected I would need to pass my ‘weights’ variable back to learn.model somehow.

Any comments or suggestions would be appreciated.

Thanks :+1:

Update: so I tried feeding in some really high numbers into the weights and it totally destroyed the training, so it looks like the model is using my edited weights.

1 Like

Hi Nick,

Good job on copying the weights! IMHO, the only way to learn Python, PyTorch, and fastai thoroughly is to struggle with the internals - to get your hands dirty, do lots of experiments, and make lots of mistakes. You are well on your way.

How is learn.model.parameters() getting a copy of my edited weights? I had expected I would need to pass my ‘weights’ variable back to learn.model somehow.

learn.model.parameters() generates references to all the trainable tensors in learn.model. If you change a tensor listed by model.parameters(), it is the very same tensor as in the model.
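You can verify this in one line (assuming the first parameter the generator yields is the first Conv2d’s weight, which it is for this model):

w_list = list(learn.model.parameters())
print(w_list[0] is learn.model[0][0].weight)   # True: the very same tensor object, not a copy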

I’ll offer you some tips and improvements to your code. If you feel this is obnoxiously intrusive, just let me know.

  • Predict first in your mind what you expect to happen.

  • Try it!

  • Check your prediction against what actually happened. This way you build up an internal model of Python/PyTorch/fastai.

  • Lots of things happen invisibly in PyTorch and fastai when you do something that seems entirely unrelated. For example, tracking of model parameters and autograd when simply assigning a variable. This programming technique can be very confusing! Eventually you’ll learn what happens behind the scenes.

  • Whenever you start to write a loop in PyTorch, you can probably find a way to do it without a loop. Here, slicing and broadcasting are your tools (there is a short standalone demo after this list). Slicing selects a rectangular sub-tensor of a given tensor without building a new object in memory. Broadcasting expands a dimension of size 1 to match a dimension of any size, again without constructing a new object.

  • Check and verify every operation. It’s extremely hard to grasp and debug what goes on inside a gazillion parameter model being trained. When the model does not learn, is the problem that your model is not competent, or that your setup is wrong? So I have found it useful to make absolutely sure the inputs are correct and that the outputs at least make sense.
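Here is the short standalone demo of slicing and broadcasting mentioned above (plain PyTorch, nothing fastai-specific):

import torch

t = torch.zeros(4, 6, 3, 3)         # pretend conv weights: 4 filters, 6 input channels
t[:, :3] = torch.randn(4, 3, 3, 3)  # fill the first three channels
t[:, 3:] = t[:, :3]                 # slice assignment: copy channels 0-2 into channels 3-5
t *= 0.5                            # broadcasting: one scalar scales every element
print(torch.equal(t[:, 3:], t[:, :3]))  # True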

As for your code, learn.model.parameters() will give you all the model’s parameters in a long list, but how do you know which is which? You are just guessing that you have found the first layer’s conv2d. Better is to go directly to the model:

layer1 = learn.model[0][0]
print(layer1)
l1weights = layer1.weight
l1weights.shape

Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
torch.Size([64, 6, 7, 7])

Yes, that’s the one. Let’s make sure the weights are filled in as expected…

l1weights[:,:3] #The first three channels are all numbers
l1weights[:,3:] #The last three channels are all zeros

Now copy the first three channels into the second three, using slices and no loops:
l1weights[:,3:] = l1weights[:,:3]

Check it!

Finally, halve all the weights using broadcasting:
l1weights *= .5

Note that you are now obligated to provide data to the second three channels in RGB order, and that to continue training layer 1 you will eventually need to unfreeze.
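In fastai terms that would look roughly like this (a sketch only; pick learning rates with learn.lr_find()):

learn.unfreeze()                                    # make the body, including layer 1, trainable
learn.fit_one_cycle(5, lr_max=slice(1e-5, 1e-3))    # discriminative learning rates, smaller for early layers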

Update: so I tried feeding in some really high numbers into the weights and it totally destroyed the training, so it looks like the model is using my edited weights.

That’s a valuable sanity check.

Please let me know if this explanation is helpful. I like that you are willing to do experiments and learn from them.
:slightly_smiling_face:

2 Likes

Hi Malcolm,

Thanks for explaining how the model was still seeing the weights, that makes sense.

Also thanks for the tips, I will keep them in mind :+1:

The reason I was performing that operation in a loop was partly that I wanted it to work for an arbitrary band count, but after your example and some more digging I have found a tensor-friendly way.

layer1 = learn.model[0][0]
print(layer1)
l1_weights = layer1.weight
l1_weights.shape

# Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# torch.Size([64, 6, 7, 7])
import math
# we need to duplicate the RGB weights by an arbitrary count, but we can only increase it by 3 at a time with 'repeat'
# so we work out how many times we need to repeat it, round that number up, repeat the weights that many times
# then chop off any excess.

# total number of input bands (6 here, as defined earlier)
band_count = 6

# number of bands covered by the pretrained weights (RGB)
pretrained_band_count = 3

band_ratio = band_count / pretrained_band_count

# round up
repeat_count = math.ceil(band_ratio)

#            RGB weights                           repeat on 2nd axis      chop off any excess
l1_weights = l1_weights[:, :pretrained_band_count].repeat(1, repeat_count, 1, 1)[:, :band_count]

# rescale weights by band_ratio
l1_weights = l1_weights / band_ratio
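A quick check that the copy did what I intended (torch comes in with the fastai import; this works here because band_count is an exact multiple of 3):

# the duplicated channels should exactly match the rescaled RGB channels
print(torch.allclose(l1_weights[:, pretrained_band_count:], l1_weights[:, :pretrained_band_count]))  # True
print(l1_weights.shape)  # torch.Size([64, 6, 7, 7])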

Thanks again Malcolm, your explanations have been very clear and having someone to bounce this back and forth with has been very helpful.

Cheers :beers:

For some closure, I just completed an end-to-end walkthrough using this solution: Multispectral image classification with Transfer Learning

1 Like

I was wondering about exactly this problem and then discovered your post here. A question that I didn’t see discussed in your blog post: did you get better results after you started modifying your weights to be non-zero for your higher channels?

Also, I found that the weights were not updating for me unless I did the following:

# (... your code)

# rescale weights by band_ratio
l1_weights = l1_weights / band_ratio

# Make sure the layer1 weights actually get updated
layer1.weight.data = l1_weights.data

Basically, I tested whether the 4th channel was getting updated. If I didn’t set layer1.weight.data directly, I found that the 4th channel remained all zeroes. But by updating it with the line above, the 4th channel did copy the 1st channel’s weights.
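An equivalent write-back, using the usual torch.no_grad() idiom, should also do the trick (untested on my end):

# copy the computed weights into the layer without autograd tracking
with torch.no_grad():
    layer1.weight.copy_(l1_weights)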

# Print the 1st and 4th input-channel weights of one filter to compare them
print(learn.model[0][0].weight[1][0])
print(learn.model[0][0].weight[1][3])
1 Like