Custom pretrained encoder in unet

Hi,

I am trying to load custom pretrained weights into the encoder part of a unet, instead of the ImageNet resnet34 weights. This seems like a very interesting way of doing transfer learning.

I am using the following code to do this:

(...)
learn = unet_learner(data, models.resnet34, pretrained=False)
# Load the custom pretrained checkpoint; the 'model' key holds the weights
state = torch.load(path / 'custom_pretrained_resnet34.pth')
# Copy whatever matches into the encoder (model[0]); strict=False skips mismatches silently
learn.model[0].load_state_dict(state['model'], strict=False)

Is it OK?
Is there a better way to do this?
Many thanks!!!

Hi Imanol,
If your approach works, I guess it is OK :).

I usually use learn.load('...') after I have defined my learner and want to load a different model. I guess that would work in your case too, but I'm not sure.
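For reference, a minimal sketch of that pattern (fastai v1 API; the checkpoint name custom_resnet34 is hypothetical, and learn.load only works if the saved state dict matches the learner's architecture exactly):

# In the notebook that pretrains the encoder:
encoder_learn.save('custom_resnet34')  # written to learn.path/models/custom_resnet34.pth

# In the unet notebook:
learn = unet_learner(data, models.resnet34, pretrained=False)
learn.load('custom_resnet34')  # raises if the architectures differ

Since the unet wraps the encoder inside a larger model, the two learners' state dicts probably won't match, which may be why this route is tricky here.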

It doesn’t work: if I randomly initialize the encoder instead, I get similar results. In fact, if I use strict=True, it raises errors.
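One way to see what is happening: in PyTorch, load_state_dict(..., strict=False) returns the keys it could not match, so you can check whether anything was actually copied (a diagnostic sketch, reusing learn and state from the first post):

result = learn.model[0].load_state_dict(state['model'], strict=False)
print(result.missing_keys)     # encoder parameters that received no weights
print(result.unexpected_keys)  # checkpoint entries that matched nothing
# If nearly every key ends up in these lists, the encoder stays randomly initialized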

However, this worked better:

Save the state_dict of the base of the model, i.e. the encoder (model[0]). We do not save the weights of the head of the model:

encoder_path = path/'resnet34_encoder.h5'
torch.save(encoder_learn.model[0].state_dict(), encoder_path)

Load it as the unet encoder (this is done in another notebook):

unet_learn.model[0].load_state_dict(torch.load(encoder_path), strict=True)
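With strict=True the load either succeeds completely or raises, but you can still double-check the transfer by comparing the tensors against the saved file (a quick sanity-check sketch, reusing encoder_path and unet_learn from above):

saved = torch.load(encoder_path)
loaded = unet_learn.model[0].state_dict()
# Every encoder tensor should now be identical to the pretrained checkpoint
assert all(torch.equal(loaded[k].cpu(), saved[k].cpu()) for k in saved)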
