I have trained a resnet50 on a database with 33 classes and now I want to load the weights to a problem of 3 classes.
When I try to load the weights with learn.load('my_model'), I get:
RuntimeError: Error(s) in loading state_dict for Sequential:
size mismatch for 1.8.weight: copying a param with shape torch.Size([33, 512]) from checkpoint, the shape in current model is torch.Size([3, 512]).
size mismatch for 1.8.bias: copying a param with shape torch.Size([33]) from checkpoint, the shape in current model is torch.Size([3]).
Shouldn’t strict = False handle this?
If I do
learn.load('my_model', strict=False, remove_module=True)
It works. But I don’t understand exactly what ‘remove_module=True’ does. I have looked at the documentation but I still don’t understand. Could someone please explain? Is this the right way of loading?
That’s the same question as here, but there’s no satisfactory answer.
remove_module just strips the module. prefix from every key in the state dict (so, for instance, module.0.0.weight becomes 0.0.weight). That prefix is added by nn.DataParallel when a model is wrapped for multi-GPU training, so remove_module=True only matters when the checkpoint was saved from a DataParallel-wrapped model; it has nothing to do with your shape mismatch.
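A quick sketch of that key renaming (assuming the fastai v1 behaviour of stripping the DataParallel prefix; the keys below are dummies):

```python
# A checkpoint saved from a DataParallel-wrapped model has every key
# prefixed with 'module.'; remove_module effectively strips that prefix.
state = {'module.0.0.weight': 1, 'module.1.8.bias': 2}  # dummy values
cleaned = {k[len('module.'):] if k.startswith('module.') else k: v
           for k, v in state.items()}
print(cleaned)  # {'0.0.weight': 1, '1.8.bias': 2}
```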
For your problem, what I'd do is save the old model's weights except the last layer (the one with 33 outputs), then load them into your new model without that layer. Something like:
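Here is a minimal, self-contained sketch in plain PyTorch (the nn.Linear layers are stand-ins for your real architecture, and the key names are assumptions based on the error message above):

```python
import torch
import torch.nn as nn

# Stand-ins for your models: same body, only the final classifier
# layer differs (33 outputs in the checkpoint vs 3 in the new model).
old_model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 33))
new_model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 3))

state = old_model.state_dict()
# Drop the last layer's params; in your checkpoint those were the
# '1.8.weight' / '1.8.bias' keys from the error message.
state = {k: v for k, v in state.items() if not k.startswith('1.')}

# strict=False tolerates the head params now missing from `state`,
# but it never reconciles mismatched shapes, which is why your
# original load failed even with strict=False.
missing, unexpected = new_model.load_state_dict(state, strict=False)
print(missing)      # the head params we skipped
print(unexpected)   # []
```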
Obviously it depends a lot on your architecture. If you give me the output of learn, I might be able to find the exact command. If you use cnn_learner, I think what I suggested works.
You can use a simple statement like the one below (note it's cnn_learner, not model, that builds the Learner):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
learn = cnn_learner(data, models.densenet121, callback_fns=ShowGraph)