I wanted to load weights that I had pretrained on a different dataset into a `Learner` object. Let's assume that the number of classes in the new dataset and the pretrained dataset are different.
`learn.load('weights')` does not work in this case: there is a size mismatch at the final layer because the two datasets have different numbers of classes.
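To see the mismatch outside of fastai, here is a minimal stand-alone reproduction with torchvision's `resnet34` (the architecture and the `weights.pth` path are just examples):

```python
import torch
from torchvision.models import resnet34

# Save weights from a model with 10 output classes...
torch.save(resnet34(num_classes=10).state_dict(), 'weights.pth')

# ...and try to load them into a model with 5 output classes.
model = resnet34(num_classes=5)
model.load_state_dict(torch.load('weights.pth'))
# RuntimeError: size mismatch for fc.weight: copying a param with shape
# torch.Size([10, 512]) from checkpoint, the shape in current model is
# torch.Size([5, 512]). (Same for fc.bias.)
```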
I decided to try and modify `create_cnn` to allow it to take a path to pretrained network weights and return a `Learner` object that has the pretrained weights.
To do this, I looked into what `create_cnn` is doing under the hood. I summarised the steps as follows:
1. Use the `cnn_config()` function to get the architecture's meta information.
2. Use the meta information to create a `body` and a new `head`.
3. A new `model` is created by putting the `body` and the new `head` together (a rough sketch of steps 2-3 follows this list).
4. `ClassificationLearner` is then used to return a `learn` object.
5. `learn.split()` seems to split `learn.model` at the defined split point. Functionally, this seems to split the model into layer groups (so they can be frozen or given different learning rates).
6. If pretrained, freeze up to the last layer group.
7. Initialize the new head with `apply_init(model, nn.init.kaiming_normal_)`.
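To make steps 2-3 concrete, here is a rough stand-alone sketch using plain PyTorch and torchvision's `resnet34` (just an example architecture). Note that fastai's real `create_head` uses an `AdaptiveConcatPool2d` that concatenates average and max pooling, which is why `create_cnn` computes `nf = num_features_model(body) * 2`; this sketch uses plain average pooling for simplicity:

```python
import torch.nn as nn
from torchvision.models import resnet34

arch = resnet34()  # example architecture, randomly initialised

# Step 2: cut off the existing pooling and classifier to get the body,
# then build a new head sized for the new dataset (say, 5 classes).
body = nn.Sequential(*list(arch.children())[:-2])
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(512, 5),  # 512 = resnet34's feature size before pooling
)

# Step 3: the model is just the two parts in sequence.
model = nn.Sequential(body, head)
```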
I had a question about step 7 above. `apply_init(model, nn.init.kaiming_normal_)` is used to initialise the model. How is this initialising `learn`'s model? Is my understanding of the steps right?
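My working assumption is that `model` and `learn.model` are the same Python object, so initialising `model` in place also changes what `learn` sees. A quick stand-alone check of that assumption, with a bare class standing in for the `Learner`:

```python
import torch.nn as nn

class Holder:
    """Stand-in for a Learner that just keeps a reference to the model."""
    def __init__(self, model): self.model = model

model = nn.Linear(512, 5)
learn = Holder(model)

nn.init.kaiming_normal_(model.weight)      # initialise in place...
assert learn.model is model                # ...same object, so
assert learn.model.weight is model.weight  # the holder sees the new weights
```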
Anyway, I wrote a function, `create_cnn_from_pretrained`, closely modelled on `create_cnn`, to load a different pretrained model. I added a new step between steps 1 and 2 where I load the weights of the pretrained model. The rest is mostly the same:
```python
# Assumes the usual fastai.vision imports (cnn_config, create_body,
# create_head, num_features_model, apply_init, ifnone, etc.) are in scope.
def create_cnn_from_pretrained(data:DataBunch, arch:Callable, pretrained_model:PathOrStr,
                               cut:Union[int,Callable]=None, lin_ftrs:Optional[Collection[int]]=None,
                               ps:Floats=0.5, custom_head:Optional[nn.Module]=None,
                               split_on:Optional[SplitFuncOrIdxList]=None, classification:bool=True,
                               **kwargs:Any)->Learner:
    "Build convnet style learner and load pretrained weights from another learner's model."
    meta = cnn_config(arch)
    # New step (between steps 1 and 2): instantiate the architecture and load
    # the saved weights into it. NB: see the EDIT below - strict=False does
    # not actually get around a size mismatch.
    arch = arch(pretrained=False)
    arch.load_state_dict(torch.load(pretrained_model, map_location=data.device), strict=False)
    body = create_body(arch, ifnone(cut,meta['cut']))
    nf = num_features_model(body) * 2
    head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
    model = nn.Sequential(body, head)
    learn = ClassificationLearner(data, model, **kwargs)
    learn.split(ifnone(split_on,meta['split']))
    learn.freeze()
    # Initialise only the new head (model[1]) so the loaded weights are kept.
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
```
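For completeness, this is roughly how I intended to use it (the weights path is a hypothetical example; see the EDIT below):

```python
# 'pretrained_weights.pth' would hold a state dict saved from the
# architecture trained on the old dataset (hypothetical path).
learn = create_cnn_from_pretrained(data, models.resnet34, 'pretrained_weights.pth')
learn.fit_one_cycle(4)
```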
Would appreciate it if others could confirm that this makes sense, and whether this is something that could be added to the library itself.
EDIT: This method does not work. Setting `strict=False` does not solve the problem of mismatched weight sizes: with `strict=False`, PyTorch only ignores missing or unexpected parameter names; a parameter that exists under the same name in both state dicts but with a different size still raises a size-mismatch error. I have a slightly different workaround below.
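For reference, here is a minimal stand-alone demonstration of that behaviour (nothing fastai-specific):

```python
import torch.nn as nn

src = nn.Linear(512, 10)   # "pretrained" head, 10 classes
dst = nn.Linear(512, 5)    # new head, 5 classes

# The keys ('weight', 'bias') match, but the shapes do not, so even with
# strict=False PyTorch raises: "size mismatch for weight: copying a param
# with shape torch.Size([10, 512]) ... current model is torch.Size([5, 512])"
try:
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as e:
    print(e)
```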