Thanks Malcolm,
Ok, that sounds reasonable, I wasn’t having much luck trying to use @patch anyway, haha.
I have been having a look around this forum and the PyTorch forum for examples of editing models, but I haven’t had much luck finding exactly what I’m after. Most people seem to be interested in changing a model’s structure, not in duplicating weights.
However, I have made some progress. After you pointed me in the right direction, I realised I could see the weights of the first layer with
list(learn.model.parameters())[0]
and if I zoom in further with
list(learn.model.parameters())[0][0]
I get this
tensor([[[-1.0419e-02, -6.1356e-03, -1.8098e-03, 7.4841e-02, 5.6615e-02,
1.7083e-02, -1.2694e-02],
[ 1.1083e-02, 9.5276e-03, -1.0993e-01, -2.8050e-01, -2.7124e-01,
-1.2907e-01, 3.7424e-03],
[-6.9434e-03, 5.9089e-02, 2.9548e-01, 5.8720e-01, 5.1972e-01,
2.5632e-01, 6.3573e-02],
[ 3.0505e-02, -6.7018e-02, -2.9841e-01, -4.3868e-01, -2.7085e-01,
-6.1282e-04, 5.7602e-02],
[-2.7535e-02, 1.6045e-02, 7.2595e-02, -5.4102e-02, -3.3285e-01,
-4.2058e-01, -2.5781e-01],
[ 3.0613e-02, 4.0960e-02, 6.2850e-02, 2.3897e-01, 4.1384e-01,
3.9359e-01, 1.6606e-01],
[-1.3736e-02, -3.6746e-03, -2.4084e-02, -6.5877e-02, -1.5070e-01,
-8.2230e-02, -5.7828e-03]],
[[-1.1397e-02, -2.6619e-02, -3.4641e-02, 3.6812e-02, 3.2521e-02,
6.6221e-04, -2.5743e-02],
[ 4.5687e-02, 3.3603e-02, -1.0453e-01, -3.0885e-01, -3.1253e-01,
-1.6051e-01, -1.2826e-03],
[-8.3730e-04, 9.8420e-02, 4.0210e-01, 7.7035e-01, 7.0789e-01,
3.6887e-01, 1.2455e-01],
[-5.8427e-03, -1.2862e-01, -4.2071e-01, -5.9270e-01, -3.8285e-01,
-4.2407e-02, 6.1568e-02],
[-5.5926e-02, -5.2239e-03, 2.7081e-02, -1.5159e-01, -4.6178e-01,
-5.7080e-01, -3.6552e-01],
[ 3.2860e-02, 5.5574e-02, 9.9670e-02, 3.1815e-01, 5.4636e-01,
4.8276e-01, 1.9867e-01],
[ 5.3051e-03, 6.6938e-03, -1.7254e-02, -6.9806e-02, -1.4822e-01,
-7.7248e-02, 7.2183e-04]],
[[-2.0315e-03, -9.1617e-03, 2.1209e-02, 8.9755e-02, 8.9177e-02,
3.3655e-02, -2.0102e-02],
[ 1.5398e-02, -1.8648e-02, -1.2591e-01, -2.9553e-01, -2.5342e-01,
-1.2980e-01, -2.7975e-02],
[ 9.8454e-03, 4.9047e-02, 2.1699e-01, 4.3010e-01, 3.4872e-01,
1.0433e-01, 1.8413e-02],
[ 2.6426e-02, -2.5990e-02, -1.9699e-01, -2.6806e-01, -1.0524e-01,
7.8577e-02, 1.2077e-01],
[-2.8356e-02, 1.8404e-02, 9.8647e-02, 6.1242e-02, -1.1740e-01,
-2.5760e-01, -1.5451e-01],
[ 2.0766e-02, -2.6286e-03, -3.7825e-02, 5.7450e-02, 2.4141e-01,
2.4345e-01, 1.1796e-01],
[ 7.4684e-04, 7.7677e-04, -1.0050e-02, -5.5153e-02, -1.4865e-01,
-1.1754e-01, -3.8350e-02]],
[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00]]])
I see we have 6 tensors (one for each input band); the first 3 are all loaded with numbers (I assume these are the pretrained weights) and the last 3 are all filled with 0 values. This is the case with all the neighbouring tensors as well (list(learn.model.parameters())[0][1], list(learn.model.parameters())[0][2], etc.).
So I figured I need to copy the values from the first 3 tensors into the last 3 tensors, while also scaling all the values down by 50%. I believe I have done this with the code below.
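To sanity-check what that indexing returns, here is a small sketch. The Conv2d below is a hypothetical stand-in for the real first layer of learn.model, but it shows why indexing [0][0] gives 6 per-band kernels: conv weights are stored as (out_channels, in_channels, kH, kW).

```python
import torch.nn as nn

# Hypothetical stand-in for the first layer of learn.model:
# a 7x7 conv taking 6 input bands and producing 64 filters.
conv = nn.Conv2d(6, 64, kernel_size=7)

# Conv weights have shape (out_channels, in_channels, kH, kW),
# so indexing [0] gives the first filter's 6 per-band 7x7 kernels.
print(conv.weight.shape)     # torch.Size([64, 6, 7, 7])
print(conv.weight[0].shape)  # torch.Size([6, 7, 7])
```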
# grab the weights
weights = list(learn.model.parameters())
# total number of input bands
band_count = 6
# the pretrained bands (RGB)
pretrained_bands = [0, 1, 2]
# the amount to divide each band by, to compensate for the additional bands
div_amount = band_count / len(pretrained_bands)
# all bands beyond RGB
untrained_bands = range(len(pretrained_bands), band_count)
# only edit index [0] of the weights, this is all we need to edit (I think...)
for band in weights[0]:
    # reduce all RGB bands by the divide amount
    for i in pretrained_bands:
        band[i] = band[i] / div_amount
    # loop over the 'untrained_bands', assigning them the values of R, then G, then B
    for pretrained, untrained in enumerate(untrained_bands):
        pre = pretrained_bands[pretrained % len(pretrained_bands)]
        band[untrained] = band[pre]
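As an aside, I think the same edit could be written more compactly with tensor slicing under torch.no_grad() (which also guards against autograd complaining about in-place edits on leaf tensors). This is only a sketch; the Conv2d here is a stand-in for the real first layer, and multiplying by 0.5 is equivalent to dividing by the div_amount of 2 above:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real first layer: a 7x7 conv with
# 6 input bands, where channels 0-2 hold the "pretrained" RGB weights.
conv = nn.Conv2d(6, 64, kernel_size=7)

band_count = 6
pretrained = 3                     # channels 0-2 are RGB
scale = pretrained / band_count    # 0.5 here

with torch.no_grad():  # in-place edits on a leaf Parameter need grad disabled
    w = conv.weight                        # shape (64, 6, 7, 7)
    w[:, :pretrained] *= scale             # scale the RGB channels down by 50%
    w[:, pretrained:] = w[:, :pretrained]  # copy the scaled RGB into the new bands
```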
This appears to work as now when I call
list(learn.model.parameters())[0][0]
I get this
tensor([[[-5.2097e-03, -3.0678e-03, -9.0489e-04, 3.7421e-02, 2.8307e-02,
8.5417e-03, -6.3469e-03],
[ 5.5414e-03, 4.7638e-03, -5.4963e-02, -1.4025e-01, -1.3562e-01,
-6.4537e-02, 1.8712e-03],
[-3.4717e-03, 2.9544e-02, 1.4774e-01, 2.9360e-01, 2.5986e-01,
1.2816e-01, 3.1786e-02],
[ 1.5252e-02, -3.3509e-02, -1.4921e-01, -2.1934e-01, -1.3543e-01,
-3.0641e-04, 2.8801e-02],
[-1.3767e-02, 8.0225e-03, 3.6297e-02, -2.7051e-02, -1.6642e-01,
-2.1029e-01, -1.2891e-01],
[ 1.5306e-02, 2.0480e-02, 3.1425e-02, 1.1949e-01, 2.0692e-01,
1.9679e-01, 8.3029e-02],
[-6.8681e-03, -1.8373e-03, -1.2042e-02, -3.2939e-02, -7.5350e-02,
-4.1115e-02, -2.8914e-03]],
[[-5.6986e-03, -1.3310e-02, -1.7320e-02, 1.8406e-02, 1.6260e-02,
3.3110e-04, -1.2872e-02],
[ 2.2843e-02, 1.6802e-02, -5.2264e-02, -1.5443e-01, -1.5626e-01,
-8.0253e-02, -6.4128e-04],
[-4.1865e-04, 4.9210e-02, 2.0105e-01, 3.8517e-01, 3.5394e-01,
1.8444e-01, 6.2277e-02],
[-2.9214e-03, -6.4308e-02, -2.1035e-01, -2.9635e-01, -1.9142e-01,
-2.1203e-02, 3.0784e-02],
[-2.7963e-02, -2.6120e-03, 1.3540e-02, -7.5793e-02, -2.3089e-01,
-2.8540e-01, -1.8276e-01],
[ 1.6430e-02, 2.7787e-02, 4.9835e-02, 1.5907e-01, 2.7318e-01,
2.4138e-01, 9.9337e-02],
[ 2.6526e-03, 3.3469e-03, -8.6272e-03, -3.4903e-02, -7.4112e-02,
-3.8624e-02, 3.6091e-04]],
[[-1.0158e-03, -4.5808e-03, 1.0605e-02, 4.4878e-02, 4.4588e-02,
1.6828e-02, -1.0051e-02],
[ 7.6991e-03, -9.3240e-03, -6.2954e-02, -1.4776e-01, -1.2671e-01,
-6.4900e-02, -1.3988e-02],
[ 4.9227e-03, 2.4523e-02, 1.0850e-01, 2.1505e-01, 1.7436e-01,
5.2166e-02, 9.2064e-03],
[ 1.3213e-02, -1.2995e-02, -9.8493e-02, -1.3403e-01, -5.2620e-02,
3.9289e-02, 6.0387e-02],
[-1.4178e-02, 9.2021e-03, 4.9324e-02, 3.0621e-02, -5.8700e-02,
-1.2880e-01, -7.7253e-02],
[ 1.0383e-02, -1.3143e-03, -1.8912e-02, 2.8725e-02, 1.2071e-01,
1.2172e-01, 5.8979e-02],
[ 3.7342e-04, 3.8839e-04, -5.0251e-03, -2.7576e-02, -7.4325e-02,
-5.8768e-02, -1.9175e-02]],
[[-5.2097e-03, -3.0678e-03, -9.0489e-04, 3.7421e-02, 2.8307e-02,
8.5417e-03, -6.3469e-03],
[ 5.5414e-03, 4.7638e-03, -5.4963e-02, -1.4025e-01, -1.3562e-01,
-6.4537e-02, 1.8712e-03],
[-3.4717e-03, 2.9544e-02, 1.4774e-01, 2.9360e-01, 2.5986e-01,
1.2816e-01, 3.1786e-02],
[ 1.5252e-02, -3.3509e-02, -1.4921e-01, -2.1934e-01, -1.3543e-01,
-3.0641e-04, 2.8801e-02],
[-1.3767e-02, 8.0225e-03, 3.6297e-02, -2.7051e-02, -1.6642e-01,
-2.1029e-01, -1.2891e-01],
[ 1.5306e-02, 2.0480e-02, 3.1425e-02, 1.1949e-01, 2.0692e-01,
1.9679e-01, 8.3029e-02],
[-6.8681e-03, -1.8373e-03, -1.2042e-02, -3.2939e-02, -7.5350e-02,
-4.1115e-02, -2.8914e-03]],
[[-5.6986e-03, -1.3310e-02, -1.7320e-02, 1.8406e-02, 1.6260e-02,
3.3110e-04, -1.2872e-02],
[ 2.2843e-02, 1.6802e-02, -5.2264e-02, -1.5443e-01, -1.5626e-01,
-8.0253e-02, -6.4128e-04],
[-4.1865e-04, 4.9210e-02, 2.0105e-01, 3.8517e-01, 3.5394e-01,
1.8444e-01, 6.2277e-02],
[-2.9214e-03, -6.4308e-02, -2.1035e-01, -2.9635e-01, -1.9142e-01,
-2.1203e-02, 3.0784e-02],
[-2.7963e-02, -2.6120e-03, 1.3540e-02, -7.5793e-02, -2.3089e-01,
-2.8540e-01, -1.8276e-01],
[ 1.6430e-02, 2.7787e-02, 4.9835e-02, 1.5907e-01, 2.7318e-01,
2.4138e-01, 9.9337e-02],
[ 2.6526e-03, 3.3469e-03, -8.6272e-03, -3.4903e-02, -7.4112e-02,
-3.8624e-02, 3.6091e-04]],
[[-1.0158e-03, -4.5808e-03, 1.0605e-02, 4.4878e-02, 4.4588e-02,
1.6828e-02, -1.0051e-02],
[ 7.6991e-03, -9.3240e-03, -6.2954e-02, -1.4776e-01, -1.2671e-01,
-6.4900e-02, -1.3988e-02],
[ 4.9227e-03, 2.4523e-02, 1.0850e-01, 2.1505e-01, 1.7436e-01,
5.2166e-02, 9.2064e-03],
[ 1.3213e-02, -1.2995e-02, -9.8493e-02, -1.3403e-01, -5.2620e-02,
3.9289e-02, 6.0387e-02],
[-1.4178e-02, 9.2021e-03, 4.9324e-02, 3.0621e-02, -5.8700e-02,
-1.2880e-01, -7.7253e-02],
[ 1.0383e-02, -1.3143e-03, -1.8912e-02, 2.8725e-02, 1.2071e-01,
1.2172e-01, 5.8979e-02],
[ 3.7342e-04, 3.8839e-04, -5.0251e-03, -2.7576e-02, -7.4325e-02,
-5.8768e-02, -1.9175e-02]]])
Which looks correct (I think), but I still have a question:
How is learn.model.parameters() picking up my edited weights? I had expected I would need to pass my ‘weights’ variable back to learn.model somehow.
Any comments or suggestions would be appreciated.
Thanks
Update: I tried feeding some really high numbers into the weights and it totally destroyed the training, so it looks like the model is using my edited weights.
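For what it’s worth, I think this works because list(learn.model.parameters()) builds a new list, but the entries in it are references to the model’s actual tensors, so editing them in place edits the model directly. A torch-free toy illustrating the same reference behaviour (TinyModel is obviously hypothetical):

```python
# Toy illustration: list() makes a new list, but the entries
# still point at the model's own objects.
class TinyModel:
    def __init__(self):
        # stand-in for a real weight tensor
        self.weight = [1.0, 2.0, 3.0]

    def parameters(self):
        yield self.weight

m = TinyModel()
params = list(m.parameters())  # new list, same underlying objects
params[0][0] = 99.0            # edit "the copy" in place...
print(m.weight[0])             # → 99.0: the model saw the edit
```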