I wanted to load weights into a Learner object from a model I pretrained on a different dataset. Let's assume the new dataset and the pretraining dataset have different numbers of classes.
Running learn.load('weights') does not work in this case: there is a size mismatch at the final layer because the two datasets have different numbers of classes.
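To make the failure concrete, here is a minimal sketch of the situation (the weight file name and architecture are just placeholders, not my actual code):

# Sketch only: 'weights' was saved from a Learner whose head was built for a
# different number of classes than data_new.c, so the final layer's shapes do
# not match and loading raises a size-mismatch RuntimeError.
learn = create_cnn(data_new, models.resnet34)
learn.load('weights')   # RuntimeError: size mismatch at the final Linear layer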
I decided to try and modify create_cnn to allow it to take a path to pretrained network weights and return a Learner object that has the pretrained weights.
To do this, I looked into what create_cnn is doing under the hood.
I summarised the steps as follows:
1. Call the cnn_config() function to get meta information about arch.
2. Use the meta information to create a body and a new head using create_body() and create_head() respectively.
3. A new model is created by putting body and head together sequentially.
4. ClassificationLearner is then used to return a Learner object.
5. learn.split() seems to split learn.model at the defined split. Functionally, this seems to split the model into head and body.
6. If pretrained, freeze up to the last layer.
7. Initialize the new head with nn.init.kaiming_normal_.
I had a question about step 7 above: apply_init(model[1], nn.init.kaiming_normal_) is used to initialise the model. How does this end up initialising learn's model? Is my understanding of the steps right?
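My best guess, as a minimal sketch (plain PyTorch, not fastai code, and assuming the Learner just keeps a reference to the same model object):

import torch.nn as nn

body = nn.Linear(10, 4)
head = nn.Linear(4, 2)
model = nn.Sequential(body, head)

# model[1] is the very same object as head, so an in-place init on model[1]
# also re-initialises the head that any Learner holding this model will use.
assert model[1] is head
nn.init.kaiming_normal_(model[1].weight)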
Anyway, I wrote a function to load a different pretrained model closely modelled on the create_cnn function: create_cnn_from_pretrained. I added a new step in between steps 1 and 2 where I load the weights of the pretrained model. The rest is mostly the same.
# Assumes the fastai v1 helpers used by create_cnn (cnn_config, create_body,
# create_head, num_features_model, ClassificationLearner, apply_init, ifnone)
# are in scope, e.g. via from fastai.vision import *
def create_cnn_from_pretrained(data:DataBunch, arch:Callable, pretrained_model:PathOrStr, cut:Union[int,Callable]=None,
                               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                               custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                               classification:bool=True, **kwargs:Any)->Learner:
    "Build a convnet style Learner and load pretrained weights from another model"
    meta = cnn_config(arch)
    arch = arch(pretrained=False)
    # New step: load the saved weights into the architecture before cutting it
    arch.load_state_dict(torch.load(pretrained_model, map_location=data.device), strict=False)
    body = create_body(arch, ifnone(cut, meta['cut']))
    nf = num_features_model(body) * 2
    head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
    model = nn.Sequential(body, head)
    learn = ClassificationLearner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
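For reference, this is roughly how I intended it to be called (the file name is just a placeholder for a saved state dict):

# data is a DataBunch for the new dataset; 'old_model.pth' stands in for the
# state dict saved from the model pretrained on the other dataset.
learn = create_cnn_from_pretrained(data, models.resnet34, 'old_model.pth')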
I would appreciate it if others could confirm that this makes sense, and whether it is something that could be added to the library itself.
EDIT: This method does not work. Setting strict=False does not solve the size-mismatch problem: strict=False only relaxes the matching of parameter names, not of parameter sizes, so weights with matching names but different shapes still fail to load. I have a slightly different workaround below.
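To illustrate what I mean about strict=False, a minimal sketch independent of fastai:

import torch.nn as nn

src = nn.Linear(10, 5)   # stands in for a pretrained final layer with 5 classes
dst = nn.Linear(10, 3)   # stands in for the new final layer with 3 classes

try:
    # strict=False only tolerates missing or unexpected keys; keys that match
    # in name but differ in shape still produce a size-mismatch error.
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as e:
    print(e)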