I wanted to load weights into a Learner object from a model I pretrained on a different dataset. Let’s assume that the number of classes in the new dataset and the pretrained dataset are different.
Running learn.load('weights') does not work: there is a size mismatch at the final layer because the two datasets have different numbers of classes.
I decided to try and modify create_cnn to allow it to take a path to pretrained network weights and return a Learner object that has the pretrained weights.
To do this, I looked into what create_cnn is doing under the hood.
I summarised the steps as follows:
1. Call the cnn_config() function to get meta information about arch.
2. Use the meta information to create a body and a new head using create_body() and create_head() respectively.
3. A new model is created by putting body and head together in an nn.Sequential.
4. ClassificationLearner is then used to return a Learner object.
5. learn.split() seems to split learn.model at the defined split; functionally, this splits the model into head and body layer groups.
6. If pretrained, freeze up to the last layer group.
7. Initialise the new head with nn.init.kaiming_normal_.
I had a question about step 7 above: apply_init(model[1], nn.init.kaiming_normal_) is used to initialise the new head. How does this end up initialising learn's model? Is my understanding of the steps right?
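My working assumption, shown as a minimal sketch (the body and head below are stand-ins, not the real create_body()/create_head() outputs): model[1] is the head module itself, and the Learner keeps a reference to the same nn.Sequential, so initialising model[1] in place also initialises the head reachable via learn.model.

import torch.nn as nn

# Stand-ins for what create_body() / create_head() produce in the real pipeline.
body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
head = nn.Sequential(nn.Flatten(), nn.Linear(8, 4))
model = nn.Sequential(body, head)

# model[1] is the very same object as head (not a copy), and the Learner holds
# a reference to this model rather than a copy of it, so re-initialising
# model[1] in place also changes the weights seen through learn.model.
assert model[1] is head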
Anyway, I wrote a function to load a different pretrained model, closely modelled on the create_cnn function: create_cnn_from_pretrained. I added a new step between steps 1 and 2 where I load the weights of the pretrained model. The rest is mostly the same.
def create_cnn_from_pretrained(data:DataBunch, arch:Callable, pretrained_model:PathOrStr, cut:Union[int,Callable]=None,
                lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                classification:bool=True, **kwargs:Any)->Learner:
    "Build a convnet style learner and load pretrained weights from another model"
    meta = cnn_config(arch)
    # New step: instantiate the arch without ImageNet weights, then load the
    # saved state dict instead (strict=False skips keys that don't match by name).
    arch = arch(pretrained=False)
    arch.load_state_dict(torch.load(pretrained_model, map_location=data.device), strict=False)
    body = create_body(arch, ifnone(cut,meta['cut']))
    nf = num_features_model(body) * 2  # doubled because the head uses concat pooling
    head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
    model = nn.Sequential(body, head)
    learn = ClassificationLearner(data, model, **kwargs)
    learn.split(ifnone(split_on,meta['split']))
    learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
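For reference, a usage sketch (the architecture, weights path, and metric are placeholders, not the exact call I used):

from fastai.vision import *  # brings in models, accuracy, etc. in fastai v1

# data is a DataBunch built on the new dataset; the .pth path is a placeholder.
learn = create_cnn_from_pretrained(data, models.resnet34, 'models/pretrained_weights.pth', metrics=accuracy)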
Would appreciate it if others could confirm that this makes sense, and whether it is something that could be added to the library itself.
EDIT: This method does not work. Setting strict=False does not solve the mismatch in weight sizes: with strict=False, PyTorch only ignores keys whose names don't match; it still cannot load a key that has the same name but a different size. I have a slightly different workaround below.
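To illustrate why strict=False does not help here, a minimal sketch assuming a recent PyTorch (the layer sizes are made up):

import torch.nn as nn

old_head = nn.Linear(512, 10)  # pretrained final layer, 10 classes
new_head = nn.Linear(512, 5)   # final layer for the new dataset, 5 classes

# strict=False only tolerates missing or unexpected key *names*; 'weight' and
# 'bias' exist in both heads but with different shapes, so loading still fails.
try:
    new_head.load_state_dict(old_head.state_dict(), strict=False)
except RuntimeError as e:
    print(e)  # "size mismatch for weight ..." / "size mismatch for bias ..."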