Issue with transfer learning (loading models in PyTorch)

I trained my model on a dataset with 100+ classes, and now I'm trying to fine-tune it on a dataset with only 5 classes. When I load the weights from the first trained model into the same architecture for transfer learning, I get an error saying the number of classes doesn't match, so the model can't be loaded. Here is what it looks like:

RuntimeError: Error(s) in loading state_dict for hybrid_model:
size mismatch for classification_head.3.weight: copying a param with shape torch.Size([120, 256]) from checkpoint, the shape in current model is torch.Size([5, 256]).
size mismatch for classification_head.3.bias: copying a param with shape torch.Size([120]) from checkpoint, the shape in current model is torch.Size([5]).

While loading I'm using strict=False, so non-matching layers should be skipped, but it still doesn't work and keeps throwing the error above.

Any ideas on how to get around this issue?

Could you please provide all code that you are running?

Hello,
To resolve the class mismatch error, you can make the number of classes a constructor argument of your model, so the classification head is rebuilt with the right output size for each dataset, and then avoid copying the old head's weights when loading the checkpoint.

Here's a simplified approach:

  1. Modify your model’s classification head to handle a variable number of classes.
  2. Load the state_dict with strict=False to skip non-matching layers.
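A minimal sketch of the two steps above, using a hypothetical stand-in for the original hybrid_model (the real architecture wasn't posted; the head key names `classification_head.3.weight`/`.bias` are taken from the error message). Note that strict=False only tolerates *missing or unexpected* keys, not *size mismatches*, which is why the original load still failed; the mismatched keys have to be removed from the state_dict first:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for hybrid_model: a backbone plus a
# classification head whose final Linear sits at index 3,
# matching the key names in the error message.
class HybridModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(32, 256)
        self.classification_head = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, num_classes),  # index 3
        )

    def forward(self, x):
        return self.classification_head(self.backbone(x))

# Pretrained on 120 classes; fine-tune target has 5.
pretrained = HybridModel(num_classes=120)
state_dict = pretrained.state_dict()

# Drop the head parameters whose shapes depend on the class count.
# strict=False then tolerates the now-missing keys (it does NOT
# tolerate shape mismatches, which caused the original RuntimeError).
for key in ["classification_head.3.weight", "classification_head.3.bias"]:
    state_dict.pop(key)

model = HybridModel(num_classes=5)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing)     # the two head keys, left at their fresh initialization
print(unexpected)  # []
```

After loading, the backbone carries the pretrained weights while the new 5-class head is randomly initialized and ready for fine-tuning.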

Best Regards,
Sonia Lewis

How do I set a variable number of classes? Can you give a short example?