Training a ResNet model, then using it for transfer learning as if it were a built-in model with pretrained weights

I found a forum post with a very similar problem to mine:

Technically, someone does give a solution, and it does appear to work; however, I suspect it may not be the recommended solution for this particular problem (at least I hope not).

In my specific scenario, I am training a ResNet model with cnn_learner, starting from the pretrained weights, but my dataset consists of ultrasound images. Even though the pretrained model is a good starting point, as you can imagine, my final model takes a significant amount of training time to reach its best accuracy. What I would like to do is save that model somewhere and then use it with cnn_learner just like the pretrained ResNet I started from. Is there any way to save this model and use it as if it were any other built-in pretrained model?

So I figured this out for my issue.

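import torch
from torch import nn
from fastai.vision import *  # fastai v1, which provides cnn_learner

def my_resnet(pretrained=True):
    # `pretrained` is passed in by cnn_learner and ignored; assumes the model was saved with torch.save(learn.model, ...).
    # On PyTorch >= 2.6, loading a whole pickled model needs torch.load(..., weights_only=False).
    model = torch.load("my_trained_pytorch_resnet_model")
    all_layers = list(model.children())  # [body, head] for a cnn_learner model
    # Flatten the body to the top level; cnn_learner then cuts before the last child with a pooling layer (the old head).
    return nn.Sequential(*all_layers[0], *all_layers[1:])

learn = cnn_learner(data, my_resnet,
                    split_on=lambda m: (m[0][6], m[1]))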
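
The flattening step is what makes this work with cnn_learner. As I understand fastai v1, when no cut is given for a custom architecture, create_body cuts the model before the last top-level child that contains a pooling layer. The sketch below re-implements that rule with two hypothetical helpers of my own (contains_pool, find_cut) and applies it to a tiny stand-in network rather than a real ResNet, assuming the saved model has the Sequential(body, head) layout that cnn_learner produces.

```python
import re

import torch
from torch import nn


def contains_pool(m: nn.Module) -> bool:
    # Hypothetical helper: True if m, or any module nested inside it, is a *Pool1d/2d/3d layer.
    if re.search(r"Pool[123]d$", type(m).__name__):
        return True
    return any(contains_pool(c) for c in m.children())


def find_cut(model: nn.Module) -> int:
    # Hypothetical helper: index of the last top-level child that contains a pooling layer.
    children = list(model.children())
    return next(i for i in reversed(range(len(children))) if contains_pool(children[i]))


# Tiny stand-in for a trained fastai v1 CNN laid out as Sequential(body, head).
body = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 2))
trained = nn.Sequential(body, head)

# Same flattening as my_resnet: body layers become top-level children, old head stays last.
all_layers = list(trained.children())
flat = nn.Sequential(*all_layers[0], *all_layers[1:])

cut = find_cut(flat)  # lands on the old head, so only the trained body layers are kept
new_body = nn.Sequential(*list(flat.children())[:cut])
print(len(flat), cut)  # 7 6
```

Without the flattening, find_cut(trained) returns 1, so the whole old body would become a single nested child and an index like m[0][6] in split_on would no longer point at the intended layer.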

@rbunn80130 I’m trying to do something similar to this.

Did you get this to work, or do you have a notebook somewhere? I’m super curious about how your solution worked.

Other than what I posted above, not really.

@muellerzr - do you have any good approaches to this problem?

I do! See this notebook, about 3/4 of the way down:

It discusses using custom .pkl weights for your model. The function itself is transfer_learn.
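
The notebook's transfer_learn is not reproduced here. As a rough sketch of the same general idea in plain PyTorch, the hypothetical helper copy_matching_weights below loads a saved state dict and copies each tensor into the new model only when the key exists and the shapes match. The small make_model network is a stand-in for a real body + head CNN, and an in-memory buffer stands in for the weights file.

```python
import io

import torch
from torch import nn


def copy_matching_weights(model: nn.Module, src_state: dict) -> list:
    # Hypothetical helper: copy tensors whose key and shape both match; leave everything else as initialized.
    dst_state = model.state_dict()
    copied = []
    for key, tensor in src_state.items():
        if key in dst_state and dst_state[key].shape == tensor.shape:
            dst_state[key] = tensor
            copied.append(key)
    model.load_state_dict(dst_state)
    return copied


def make_model(n_classes: int) -> nn.Module:
    # Stand-in for a Sequential(body, head) CNN; only the final Linear depends on n_classes.
    body = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
    head = nn.Sequential(nn.Linear(8, n_classes))
    return nn.Sequential(body, head)


# "Pretrain" a model, save only its weights, then load them into a model with a different head.
old = make_model(n_classes=5)
buffer = io.BytesIO()  # stands in for a weights file on disk
torch.save(old.state_dict(), buffer)
buffer.seek(0)
src_state = torch.load(buffer, weights_only=True)  # a plain state dict needs no arbitrary unpickling

new = make_model(n_classes=2)
copied = copy_matching_weights(new, src_state)
print(copied)  # ['0.0.weight', '0.0.bias']
```

The head's Linear is skipped here only because its shape changed (5 classes vs 2); a pure shape check would copy it if the class count were the same.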

Thanks! That’s exactly what I was looking for. You, sir, are a genius.

It appears that if some linear layers in the head happen to have the same shape as before, their weights get copied over as well. How would you go about resetting the head even when it happens to have the same shape?

For now, I just enumerate the items and stop copying once the head is reached, but obviously that isn’t a good general solution.
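
One way to avoid counting items is to filter by name instead. Assuming the model has the Sequential(body, head) layout that cnn_learner builds, body parameters have state-dict keys starting with "0." and head parameters keys starting with "1.". The hypothetical helper copy_body_only below copies only the body keys, so the head is never touched, even when its shapes happen to match.

```python
import torch
from torch import nn


def copy_body_only(model: nn.Module, src_state: dict, body_prefix: str = "0.") -> list:
    # Hypothetical helper: copy only keys under body_prefix (assumes a Sequential(body, head) layout).
    dst_state = model.state_dict()
    copied = []
    for key, tensor in src_state.items():
        if key.startswith(body_prefix) and key in dst_state and dst_state[key].shape == tensor.shape:
            dst_state[key] = tensor
            copied.append(key)
    model.load_state_dict(dst_state)
    return copied


def make_model(n_classes: int) -> nn.Module:
    # Stand-in for a Sequential(body, head) CNN.
    body = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
    head = nn.Sequential(nn.Linear(8, n_classes))
    return nn.Sequential(body, head)


torch.manual_seed(0)
old = make_model(n_classes=2)
new = make_model(n_classes=2)  # same head shape, so a shape-only check would copy the head too
copied = copy_body_only(new, old.state_dict())
print(copied)  # ['0.0.weight', '0.0.bias']
```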

You could look into init_cnn and see how fastai initializes its models. Basically, you would need to reinitialize that last layer.
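
A sketch of reinitializing the head in plain PyTorch, loosely following the scheme fastai's init_cnn applies to conv and linear layers as I understand it (Kaiming-normal weights, zero biases). The helper name reinit_head is hypothetical, and it goes a step further than init_cnn by also resetting BatchNorm layers, since a reused head would otherwise keep the old running statistics. It assumes the Sequential(body, head) layout, so the head is model[1].

```python
import torch
from torch import nn


def reinit_head(head: nn.Module) -> None:
    # Hypothetical helper: fresh init for every layer in the head, whatever its shape.
    for m in head.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight)  # Kaiming-normal weights, as in fastai's init_cnn
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.reset_parameters()  # running mean 0 / var 1, affine weight 1 / bias 0


# Stand-in model laid out as Sequential(body, head), with "old" values left in the head.
body = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Sequential(nn.BatchNorm1d(8), nn.Linear(8, 2))
model = nn.Sequential(body, head)
with torch.no_grad():
    head[0].running_mean.fill_(3.0)
    head[1].bias.fill_(1.0)

reinit_head(model[1])  # only the head is reset; the body keeps its trained weights
print(head[0].running_mean.tolist() == [0.0] * 8, head[1].bias.tolist())  # True [0.0, 0.0]
```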