I wanted to load weights into a `Learner` object from a model I pretrained on a different dataset. Let's assume the number of classes in the new dataset and the pretrained dataset are different. Running `learn.load('weights')` does not work: there is a size mismatch at the final layer because the two datasets have different numbers of classes.
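For concreteness, here is a minimal plain-PyTorch sketch of the failure (the layer sizes are made up): copying a checkpoint whose final layer was trained for a different number of classes raises a size-mismatch error.

```python
import torch.nn as nn

old_head = nn.Linear(512, 100)  # final layer pretrained for 100 classes
new_head = nn.Linear(512, 10)   # current dataset only has 10 classes

try:
    new_head.load_state_dict(old_head.state_dict())
except RuntimeError as err:
    print(err)  # reports a size mismatch for 'weight' and 'bias'
```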
I decided to try modifying `create_cnn` so that it can take a path to pretrained network weights and return a `Learner` object with those weights loaded. To do this, I looked into what `create_cnn` is doing under the hood.
I summarised the steps as follows:
1. Call the `cnn_config()` function to get `meta` information about `arch`.
2. Use the `meta` information to create a `body` and a new `head` using `create_body()` and `create_head()` respectively.
3. A new `model` is created by putting `body` and `head` together sequentially (see the small sketch after this list).
4. `ClassificationLearner` is then used to return a `Learner` object.
5. `learn.split()` seems to split `learn.model` at the defined split. Functionally, this seems to split the `model` into `head` and `body`.
6. If pretrained, freeze up to the last layer.
7. Initialise the new head with `nn.init.kaiming_normal_`.
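As referenced in step 3, here is a tiny sketch with stand-in modules (plain PyTorch, not the real fastai `body`/`head`): because the final model is an `nn.Sequential(body, head)`, indexing it with `model[1]` gives back the head.

```python
import torch.nn as nn

# Stand-ins for what create_body()/create_head() would actually build.
body = nn.Sequential(nn.Conv2d(3, 64, 3), nn.AdaptiveAvgPool2d(1))
head = nn.Sequential(nn.Linear(64, 10))

model = nn.Sequential(body, head)
assert model[0] is body and model[1] is head  # step 7's apply_init(model[1], ...) targets the head
```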
I had a question about step 7 above: `apply_init(model[1], nn.init.kaiming_normal_)` is used to initialise the model. How is this initialising `learn`'s model? Is my understanding of the steps right?
Anyway, I wrote a function to load a different pretrained model, closely modelled on the `create_cnn` function: `create_cnn_from_pretrained`. I added a new step between steps 1 and 2 where I load the weights of the pretrained model. The rest is mostly the same.
```python
def create_cnn_from_pretrained(data:DataBunch, arch:Callable, pretrained_model:PathOrStr, cut:Union[int,Callable]=None,
                               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                               custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                               classification:bool=True, **kwargs:Any)->Learner:
    "Build convnet style learner and load pretrained weights from another model."
    # Assumes the usual fastai v1 imports (e.g. `from fastai.vision import *`); uses the
    # helpers cnn_config, create_body, create_head, num_features_model, ifnone and apply_init.
    meta = cnn_config(arch)
    arch = arch(pretrained=False)
    # New step between steps 1 and 2: load the previously trained weights into the architecture.
    arch.load_state_dict(torch.load(pretrained_model, map_location=data.device), strict=False)
    body = create_body(arch, ifnone(cut, meta['cut']))
    nf = num_features_model(body) * 2  # x2 because the default head concat-pools (avg + max)
    head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
    model = nn.Sequential(body, head)
    learn = ClassificationLearner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)  # re-initialise the new head only
    return learn
```
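A usage sketch under my assumptions (the architecture and the weights path are placeholders, and the `.pth` file is assumed to hold a plain state dict saved with `torch.save`):

```python
# Hypothetical call: `data` is an existing DataBunch for the new dataset, and the
# .pth file contains a state dict from a model pretrained on a different dataset.
learn = create_cnn_from_pretrained(data, models.resnet34, 'path/to/pretrained_weights.pth')
```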
I would appreciate it if others could confirm that this makes sense. Would this be something that could be added to the library itself?
EDIT: This method does not work. Setting `strict=False` does not solve the problem of mismatched weight sizes: with `strict=False`, PyTorch only loads weights whose names match, and it cannot handle parameters whose sizes differ. I have a slightly different workaround below.
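To be concrete about why the `strict=False` attempt fails, here is a toy check (plain PyTorch, made-up layer sizes): missing or unexpected keys are skipped, but a key whose name matches and whose shape differs still does not load.

```python
import torch.nn as nn

old_head = nn.Linear(512, 100)  # pretrained for 100 classes
new_head = nn.Linear(512, 10)   # current dataset has 10 classes

try:
    # strict=False skips missing/unexpected keys, but not shape mismatches.
    new_head.load_state_dict(old_head.state_dict(), strict=False)
except RuntimeError as err:
    print(err)  # still a size mismatch for 'weight' and 'bias'
```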