Loading pretrained weights that are not from ImageNet

I want to load weights that I pretrained on a different dataset into a Learner object. Let’s assume the new dataset and the pretrained dataset have different numbers of classes.

Running learn.load('weights') does not work: when the two datasets have different numbers of classes, the final layer’s weights have different sizes, so loading fails with a size mismatch.

I decided to try and modify create_cnn to allow it to take a path to pretrained network weights and return a Learner object that has the pretrained weights.

To do this, I looked into what create_cnn is doing under the hood.

I summarised the steps as follows:

  1. Call the cnn_config() function to get meta information about the arch.
  2. Use the meta information to create a body and a new head using create_body() and create_head() respectively.
  3. Create a new model by putting the body and head together in an nn.Sequential.
  4. Use ClassificationLearner to return a Learner object.
  5. learn.split() seems to split learn.model at the defined split point. Functionally, this seems to split the model into head and body.
  6. If pretrained, freeze everything up to the last layer group.
  7. Initialise the new head with nn.init.kaiming_normal_.

I had a question about step 7 above: apply_init(model[1], nn.init.kaiming_normal_) is used to initialise the head, but it is called on model rather than learn.model. How does this initialise learn’s model? Is my understanding of the steps right?
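On the step 7 question: as far as I can tell, the Learner keeps a reference to the very nn.Sequential you pass in (calling `.to(device)` on an nn.Module moves it in place and returns the same object), so `apply_init(model[1], ...)` modifies `learn.model[1]` directly. The toy sketch below only illustrates that reference semantics; `ToyHead`, `ToyLearner` and `init_in_place` are hypothetical stand-ins, not fastai or PyTorch classes.

```python
class ToyHead:
    """Hypothetical stand-in for an nn.Module that owns a weight buffer."""
    def __init__(self, n: int):
        self.weight = [0.0] * n

class ToyLearner:
    """Hypothetical stand-in for a Learner: it stores the model object itself, not a copy."""
    def __init__(self, model):
        self.model = model

def init_in_place(module: ToyHead, value: float) -> None:
    # Mutate the existing buffer, the way nn.init.* functions mutate tensors in place.
    module.weight[:] = [value] * len(module.weight)

body, head = object(), ToyHead(3)
model = [body, head]          # plays the role of nn.Sequential(body, head)
learn = ToyLearner(model)
init_in_place(model[1], 1.0)  # analogous to apply_init(model[1], ...)

print(learn.model is model)   # True: both names point at the same object
print(learn.model[1].weight)  # [1.0, 1.0, 1.0]
```

Because nothing is copied, initialising through `model` or through `learn.model` touches the same parameters.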

Anyway, I wrote create_cnn_from_pretrained, a function closely modelled on create_cnn. It adds a step between steps 1 and 2 that loads the weights of the pretrained model; the rest is mostly the same.

def create_cnn_from_pretrained(data:DataBunch, arch:Callable, pretrained_model:PathOrStr, cut:Union[int,Callable]=None,
                lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                classification:bool=True, **kwargs:Any)->Learner:
    "Build convnet style learners and load pretrained model from other learn model"
    meta = cnn_config(arch)
    arch = arch(pretrained=False)
    arch.load_state_dict(torch.load(pretrained_model, map_location=data.device), strict=False)
    body = create_body(arch, ifnone(cut,meta['cut']))
    nf = num_features_model(body) * 2
    head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
    model = nn.Sequential(body, head)
    learn = ClassificationLearner(data, model, **kwargs)
    learn.split(ifnone(split_on,meta['split']))
    learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn

I’d appreciate it if others could confirm whether this makes sense, and whether it could be added to the library itself.

EDIT: This method does not work. Setting strict=False does not solve the weight-size mismatch: it only makes PyTorch tolerate missing or unexpected keys, and weights with the same name but a different size still cause an error. I have a slightly different workaround below.

Yes, that makes sense. I think the best approach would be to change the meaning of create_cnn’s pretrained parameter so that it can be any of:

  • False: random weights
  • True: default ImageNet weights
  • str: load weights from a URL (cached locally)
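A minimal sketch of how such a three-way `pretrained` argument could be dispatched. `resolve_pretrained` and `fake_loader` are hypothetical helpers, not part of fastai; the URL loader is injected so the logic can be shown without a network call (in practice something like `torch.hub.load_state_dict_from_url`, which downloads and caches the file, could be passed in).

```python
from typing import Callable, Dict, Optional, Tuple, Union

StateDict = Dict[str, object]  # parameter name -> tensor

def resolve_pretrained(pretrained: Union[bool, str],
                       load_from_url: Callable[[str], StateDict]) -> Tuple[bool, Optional[StateDict]]:
    """Return (use_imagenet_weights, custom_state_dict) for a create_cnn-style `pretrained` argument."""
    if isinstance(pretrained, bool):
        # False -> random init, True -> the architecture's default ImageNet weights
        return pretrained, None
    if isinstance(pretrained, str):
        # A URL: build the architecture without ImageNet weights, then load these weights instead
        return False, load_from_url(pretrained)
    raise TypeError(f"pretrained must be bool or str, got {type(pretrained).__name__}")

def fake_loader(url: str) -> StateDict:
    # Hypothetical stand-in for a real downloader such as torch.hub.load_state_dict_from_url
    return {"conv1.weight": f"weights from {url}"}

print(resolve_pretrained(True, fake_loader))  # (True, None)
print(resolve_pretrained("https://example.com/w.pth", fake_loader))
```

The custom state dict would still need shape-aware loading, since the head usually differs in size between datasets.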

This method does not work. Setting strict=False does not solve the weight-size mismatch. With strict=False, PyTorch ignores missing and unexpected keys, but it still throws a shape mismatch error for weights that have the same name and a different size.
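To make the distinction concrete: strict=False only relaxes the check on missing and unexpected keys, while a key present in both dicts with different shapes is still an error. The function below is a pure-Python emulation of that rule over `{name: shape}` dicts; `check_load` is hypothetical and only mimics PyTorch’s behaviour, it does not call PyTorch.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

def check_load(model: Dict[str, Shape], ckpt: Dict[str, Shape], strict: bool = True) -> List[str]:
    """Emulate which errors load_state_dict would report for these shapes."""
    errors = []
    for name, shape in ckpt.items():
        if name in model and model[name] != shape:
            # Shape mismatches are reported whatever the value of `strict`.
            errors.append(f"size mismatch for {name}: {shape} vs {model[name]}")
    if strict:
        # Only strict mode complains about keys present on one side only.
        missing = sorted(set(model) - set(ckpt))
        unexpected = sorted(set(ckpt) - set(model))
        if missing:
            errors.append(f"missing keys: {missing}")
        if unexpected:
            errors.append(f"unexpected keys: {unexpected}")
    return errors

model = {"body.weight": (64, 3), "fc.weight": (3, 512)}  # new 3-class head
ckpt = {"body.weight": (64, 3), "fc.weight": (60, 512)}  # checkpoint trained on 60 classes
print(check_load(model, ckpt, strict=False))
# ['size mismatch for fc.weight: (60, 512) vs (3, 512)']
```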

I have a slightly different workaround: I wrote a function similar to learn.load and closer in spirit to how PyTorch implements model.load_state_dict().

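from pathlib import Path
from typing import Union
import torch

def load_diff_pretrained(learn, name:Union[Path,str], device:torch.device=None):
    "Load weights from `name` (in `learn.path/learn.model_dir`, or a direct path) into `learn.model`, skipping mismatches."
    if device is None: device = learn.data.device
    # Resolve the path like learn.load does: model_dir is relative to learn.path
    candidate = (Path(learn.path)/learn.model_dir/name).with_suffix('.pth')
    model_path = candidate if candidate.exists() else name
    new_state_dict = torch.load(model_path, map_location=device)
    # fastai v1's learn.save() wraps the weights as {'model': ..., 'opt': ...} when it saves optimizer state
    if set(new_state_dict.keys()) == {'model', 'opt'}: new_state_dict = new_state_dict['model']
    learn_state_dict = learn.model.state_dict()
    for key, param in learn_state_dict.items():
        if key in new_state_dict:
            input_param = new_state_dict[key]
            if input_param.shape == param.shape:
                param.copy_(input_param)  # state_dict() tensors share storage with the model's parameters
            else:
                print('Shape mismatch at:', key, 'skipping')
        else:
            print(f'{key} weight of the model not in pretrained weights')
    learn.model.load_state_dict(learn_state_dict)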
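The core of this workaround is a partition of the model’s keys into three groups. Below is a framework-agnostic sketch, assuming only that values expose a `.shape` attribute (true of `torch.Tensor`); `unwrap_checkpoint`, `partition_state_dict` and `FakeTensor` are hypothetical names for illustration. It also unwraps a `{'model': ..., 'opt': ...}` checkpoint, the layout fastai v1’s `learn.save` writes when optimizer state is included.

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple

def unwrap_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Return the bare state dict from a fastai-v1-style {'model', 'opt'} checkpoint."""
    return ckpt["model"] if set(ckpt) == {"model", "opt"} else ckpt

def partition_state_dict(model_sd: Dict[str, Any],
                         ckpt_sd: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Split model keys into (loadable weights, shape mismatches, missing from checkpoint)."""
    loadable, mismatched, missing = {}, [], []
    for name, param in model_sd.items():
        if name not in ckpt_sd:
            missing.append(name)
        elif tuple(ckpt_sd[name].shape) == tuple(param.shape):
            loadable[name] = ckpt_sd[name]
        else:
            mismatched.append(name)
    return loadable, mismatched, missing

FakeTensor = namedtuple("FakeTensor", "shape")  # stand-in for torch.Tensor
model_sd = {"0.0.weight": FakeTensor((64, 3, 7, 7)), "1.8.weight": FakeTensor((3, 512))}
ckpt = {"model": {"0.0.weight": FakeTensor((64, 3, 7, 7)), "1.8.weight": FakeTensor((60, 512))},
        "opt": {}}

loadable, mismatched, missing = partition_state_dict(model_sd, unwrap_checkpoint(ckpt))
print(sorted(loadable), mismatched, missing)  # ['0.0.weight'] ['1.8.weight'] []
```

With real PyTorch state dicts, `learn.model.load_state_dict(loadable, strict=False)` would then load just the compatible tensors, since every key left in `loadable` matches in shape.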

Not sure how this could be incorporated into the library.

@viraat thank you very much for this! I’m working on an implementation where I want to do transfer learning and I was having issues loading the weights, and this fixed it perfectly for me! Thank you very much!

Hi @muellerzr, can you show me how to do it, please?

This seems to work for me:

learn = cnn_learner(data, models.resnet34, metrics=error_rate)
load_diff_pretrained(learn, 'weigths_from_another_learner')

But all this really seems to do is skip the mismatching weights (much like @viraat described strict=False doing). In my case I’m trying to go from a dataset with 60 classes to a dataset with 3 classes, and it seems to be basically useless:

0.0.weight weight of the model not in pretrained weights
0.1.weight weight of the model not in pretrained weights
0.1.bias weight of the model not in pretrained weights
0.1.running_mean weight of the model not in pretrained weights
0.1.running_var weight of the model not in pretrained weights
0.1.num_batches_tracked weight of the model not in pretrained weights
0.4.0.conv1.weight weight of the model not in pretrained weights
0.4.0.bn1.weight weight of the model not in pretrained weights
0.4.0.bn1.bias weight of the model not in pretrained weights
0.4.0.bn1.running_mean weight of the model not in pretrained weights
0.4.0.bn1.running_var weight of the model not in pretrained weights
0.4.0.bn1.num_batches_tracked weight of the model not in pretrained weights
0.4.0.conv2.weight weight of the model not in pretrained weights
0.4.0.bn2.weight weight of the model not in pretrained weights
0.4.0.bn2.bias weight of the model not in pretrained weights
0.4.0.bn2.running_mean weight of the model not in pretrained weights
[...truncated]

I’m assuming there’s another use case where this does work, e.g. a similar number of classes and tensors of the same shape (size), but I’m unclear how it differs from strict=False. In fact, the error_rate for my use case is pretty much the same every time, whether I use strict=False, load_diff_pretrained, or no pretrained learner at all.
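When every key is reported as missing, as in the log above, the first thing worth checking is whether the two sets of parameter names overlap at all. A small diagnostic sketch (`key_overlap` is a hypothetical helper); passing `learn.model.state_dict()` and the loaded checkpoint as the two arguments would show, for instance, whether the checkpoint’s top-level keys are a wrapper such as fastai v1’s `{'model', 'opt'}` or whether the names follow a different scheme.

```python
from typing import Dict, Iterable, List, Union

def key_overlap(model_keys: Iterable[str], ckpt_keys: Iterable[str],
                show: int = 3) -> Dict[str, Union[int, List[str]]]:
    """Count shared parameter names and list a few unmatched ones from each side."""
    model_keys, ckpt_keys = list(model_keys), list(ckpt_keys)  # iterating a dict yields its keys
    shared = set(model_keys) & set(ckpt_keys)
    return {
        "shared": len(shared),
        "model_only": [k for k in model_keys if k not in shared][:show],
        "ckpt_only": [k for k in ckpt_keys if k not in shared][:show],
    }

model_keys = ["0.0.weight", "0.1.weight", "0.4.0.conv1.weight"]  # fastai-style names from the log
print(key_overlap(model_keys, ["model", "opt"]))
print(key_overlap(model_keys, ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight"]))
```

A `shared` count of zero means nothing can be loaded, whatever the shapes, so the error rate would match training from scratch.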

Yeah, that’s because cnn_learner effectively renames the model layers to fastai’s style, e.g. “0.4.0.bn1 …”, while the pretrained weights use a different naming style. What did work for me was loading the weights into the model before passing it to a plain Learner rather than cnn_learner. You can try it yourself.

Using a plain Learner won’t work for @KristerV, since they are also changing the number of classes, so they need the functionality of cnn_learner to adapt the network head.
The easiest way is probably to map the original layer names to fastai layer names. Here is a notebook that does this. Note that the missing keys are just the layers in the fastai head added by cnn_learner; the rest of the state is loaded. It may need testing and fixing for other models. In particular, it assumes only Conv and BN layers need their names matched for loading (which should probably be true for most vision models).
After doing this you can save the learner or new model state to load in future rather than having to do it every time.
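A minimal sketch of that renaming for a torchvision ResNet, under these assumptions: the fastai body is built from the ResNet’s top-level children minus the last two (avgpool and fc) and sits at index `0` of `nn.Sequential(body, head)`, which is consistent with names like `0.4.0.conv1.weight` in the log above. `resnet_to_fastai` and `rename_state_dict` are hypothetical helpers; this is not the linked notebook’s code.

```python
from typing import Dict, Optional

# Top-level children of torchvision's ResNet, in definition order (assumption for this sketch).
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool", "layer1", "layer2",
                   "layer3", "layer4", "avgpool", "fc"]
CUT = -2  # the body keeps everything before avgpool and fc

BODY_INDEX = {name: i for i, name in enumerate(RESNET_CHILDREN[:CUT])}

def resnet_to_fastai(key: str) -> Optional[str]:
    """Map e.g. 'layer1.0.conv1.weight' -> '0.4.0.conv1.weight'; None for keys outside the body."""
    top, _, rest = key.partition(".")
    if top not in BODY_INDEX:
        return None  # e.g. 'fc.*', which the new fastai head replaces
    return f"0.{BODY_INDEX[top]}.{rest}"

def rename_state_dict(sd: Dict[str, object]) -> Dict[str, object]:
    """Rename body keys to fastai style and drop everything else."""
    renamed = {}
    for key, value in sd.items():
        new_key = resnet_to_fastai(key)
        if new_key is not None:
            renamed[new_key] = value
    return renamed

print(resnet_to_fastai("layer1.0.conv1.weight"))  # 0.4.0.conv1.weight
print(resnet_to_fastai("fc.weight"))              # None
```

The renamed dict could then be passed to `learn.model.load_state_dict(renamed, strict=False)`: only the head keys are missing and `fc` has been dropped, so no shape mismatch remains.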
