This method does not work: setting `strict=False` does not solve the problem of mismatched weight sizes. With `strict=False`, PyTorch only ignores keys that are missing from one state dict or the other; for weights that share a name but differ in shape, it still throws a size-mismatch error.
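For example (a minimal sketch; the two `Linear` layers here are hypothetical stand-ins for a checkpoint and a model with different widths):

```python
import torch.nn as nn

src = nn.Linear(10, 5)  # checkpoint weight shape: (5, 10)
dst = nn.Linear(10, 8)  # model weight shape: (8, 10)

# strict=False only tolerates missing/unexpected keys;
# the shared key 'weight' with a different shape still
# raises a size-mismatch RuntimeError here.
dst.load_state_dict(src.state_dict(), strict=False)
```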
I have a slightly different workaround: I wrote a function similar to `learn.load`, closer in spirit to how PyTorch implements `model.load_state_dict()`:
```python
from pathlib import Path
from typing import Union

import torch


def load_diff_pretrained(learn, name: Union[Path, str], device: torch.device = None):
    "Load model `name` from `learn.model_dir` using `device`, defaulting to `learn.data.device`."
    if device is None:
        device = learn.data.device
    if (learn.model_dir/name).with_suffix('.pth').exists():
        model_path = (learn.model_dir/name).with_suffix('.pth')
    else:
        model_path = name
    new_state_dict = torch.load(model_path, map_location=device)
    learn_state_dict = learn.model.state_dict()
    # Copy over only the weights whose name and shape match the
    # current model; report and skip everything else.
    for key, param in learn_state_dict.items():
        if key in new_state_dict:
            input_param = new_state_dict[key]
            if input_param.shape == param.shape:
                param.copy_(input_param)
            else:
                print('Shape mismatch at:', key, 'skipping')
        else:
            print(f'{key} weight of the model not in pretrained weights')
    learn.model.load_state_dict(learn_state_dict)
```
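Usage would look something like this (hypothetical names: `learn` is any fastai Learner, and `'stage-1'` a checkpoint saved earlier with `learn.save('stage-1')`):

```python
# Weights matching in name and shape are loaded; the rest keep
# the model's current initialization.
load_diff_pretrained(learn, 'stage-1')
```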
I'm not sure how this could be incorporated into the library.