Loading pretrained weights that are not from ImageNet

This method does not work: setting `strict=False` does not solve the problem of mismatched weight sizes. With `strict=False`, PyTorch only tolerates keys that are missing from the checkpoint or unexpected by the model. Any key that exists in both but with a different shape still raises a size-mismatch `RuntimeError`.
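To see this concretely, here is a minimal reproduction in plain PyTorch (no fastai). It assumes two toy `nn.Sequential` models that differ only in the output size of their last layer, standing in for a pretrained head and a new head. The `2.weight` and `2.bias` entries exist in both state dicts with different shapes, so `load_state_dict(..., strict=False)` still fails:

```python
import torch
import torch.nn as nn

# Toy stand-ins: same architecture except the last layer's output size
# (e.g. a pretrained head with 5 classes vs. a new head with 3 classes).
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

raised, message = False, ""
try:
    # strict=False ignores missing/unexpected keys, but '2.weight' and '2.bias'
    # are present in both models with different shapes.
    target.load_state_dict(pretrained.state_dict(), strict=False)
except RuntimeError as err:
    raised, message = True, str(err)

print(raised)   # True
print(message)  # includes "size mismatch for 2.weight ..." and the same for 2.bias
```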

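I have a slightly different workaround: a function similar to `learn.load`, but closer in spirit to how PyTorch implements `model.load_state_dict()`. It copies every pretrained tensor whose name and shape match the model, and prints each parameter it has to skip.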

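```python
from pathlib import Path
from typing import Union

import torch

def load_diff_pretrained(learn, name: Union[Path, str], device: torch.device = None):
    "Load weights `name` from `learn.model_dir` onto `device` (default `learn.data.device`), skipping shape mismatches."
    if device is None: device = learn.data.device
    candidate = (Path(learn.model_dir)/name).with_suffix('.pth')
    model_path = candidate if candidate.exists() else name
    new_state_dict = torch.load(model_path, map_location=device)
    # `learn.save(..., with_opt=True)` stores {'model': ..., 'opt': ...}; keep only the model weights
    if set(new_state_dict.keys()) == {'model', 'opt'}: new_state_dict = new_state_dict['model']
    learn_state_dict = learn.model.state_dict()
    # Loop variable is `key`, not `name`, so the function argument is not shadowed
    for key, param in learn_state_dict.items():
        if key in new_state_dict:
            input_param = new_state_dict[key]
            if input_param.shape == param.shape:
                # state_dict() tensors share storage with the model, so this copies in place
                param.copy_(input_param)
            else:
                print(f'Shape mismatch at: {key}, skipping')
        else:
            print(f'{key} weight of the model not in pretrained weights')
    # Redundant after the in-place copies above, but harmless
    learn.model.load_state_dict(learn_state_dict)
```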
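For comparison, a common plain-PyTorch variant filters the checkpoint first and then relies on `strict=False` to ignore the keys it dropped. This is only a sketch, not a fastai API: `load_matching_weights` is a hypothetical helper name, and it takes an already-loaded state dict (e.g. the result of `torch.load`) instead of a file name. The toy models are the same assumed stand-ins as above.

```python
import torch
import torch.nn as nn

def load_matching_weights(model: nn.Module, pretrained_state: dict):
    """Hypothetical helper: load only entries whose name AND shape match.

    Returns (loaded, mismatched, missing) lists of state-dict keys.
    """
    model_state = model.state_dict()
    filtered, loaded, mismatched = {}, [], []
    for key, value in pretrained_state.items():
        if key not in model_state:
            continue  # unexpected key: not part of this model, ignore it
        if value.shape == model_state[key].shape:
            filtered[key] = value
            loaded.append(key)
        else:
            mismatched.append(key)
    # Keys in the model that the checkpoint does not provide at all
    missing = [k for k in model_state if k not in pretrained_state]
    # Every remaining entry has a matching shape, so strict=False only has to
    # tolerate the keys we dropped; no size-mismatch error can be raised.
    model.load_state_dict(filtered, strict=False)
    return loaded, mismatched, missing

# Usage with toy models whose last layers differ in output size
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
loaded, mismatched, missing = load_matching_weights(target, pretrained.state_dict())
print(loaded)      # ['0.weight', '0.bias']
print(mismatched)  # ['2.weight', '2.bias']
print(missing)     # []
```

With a fastai `Learner` this would presumably be called as `load_matching_weights(learn.model, torch.load(path, map_location=learn.data.device))`, where `path` is assumed to point at a file holding a plain state dict. Returning the skipped names instead of printing them lets the caller decide whether a mismatch is expected (a new classifier head) or a bug.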

Not sure how this could be incorporated into the library.
