Essentially, you copy the state dict for everything but the output classes (and the input, if that's what you want). Here is a function I used for loading pretrained weights that are not from ImageNet:
def load_diff_pretrained(learn, name:Union[Path,str], device:torch.device=None):
    """Load weights `name` into `learn.model`, skipping incompatible layers.

    Copies every tensor from the saved state dict whose key exists in the
    model AND whose shape matches; entries with a shape mismatch or keys
    missing from the checkpoint are reported via `print` and left at their
    current values. Useful for transferring pretrained weights between
    models that differ in their input/output layers.

    Args:
        learn: object exposing `model`, `model_dir`, and `data.device`
            (e.g. a fastai `Learner`).
        name: checkpoint name resolved against `learn.model_dir` (a `.pth`
            suffix is appended), or a direct path to a weights file.
        device: map location passed to `torch.load`; defaults to
            `learn.data.device`.
    """
    if device is None:
        device = learn.data.device
    # Prefer `<model_dir>/<name>.pth` when it exists; otherwise treat
    # `name` as a path to the file itself.
    candidate = (learn.model_dir / name).with_suffix('.pth')
    model_path = candidate if candidate.exists() else name
    new_state_dict = torch.load(model_path, map_location=device)
    learn_state_dict = learn.model.state_dict()
    # Loop variable renamed from `name` — the original shadowed the
    # `name` parameter above.
    for key, param in learn_state_dict.items():
        if key in new_state_dict:
            input_param = new_state_dict[key]
            if input_param.shape == param.shape:
                # In-place copy into the model's own tensor.
                param.copy_(input_param)
            else:
                print('Shape mismatch at:', key, 'skipping')
        else:
            print(f'{key} weight of the model not in pretrained weights')
    learn.model.load_state_dict(learn_state_dict)