Hi,
I am working on a classification problem for which a taxonomy is available. To keep things simple, let's assume the taxonomy has 3 levels:
- level_1 with 3 classes
- level_2 with 10 classes (children of the level_1 classes)
- level_3 with 100 classes (children of the level_2 classes)
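Because every level_3 class has exactly one parent, the level_1 and level_2 labels can be derived from the level_3 label alone, so a single labelled dataset could feed all three stages. A minimal sketch, assuming the taxonomy is available as two child-to-parent dictionaries (the class names, `PARENT_OF_LEVEL_3`, `PARENT_OF_LEVEL_2` and `coarse_labels` are all hypothetical):

```python
from typing import Dict, Tuple

# Hypothetical child -> parent maps; in practice these would be built from the real taxonomy.
PARENT_OF_LEVEL_3: Dict[str, str] = {
    "l3_a": "l2_x",
    "l3_b": "l2_x",
    "l3_c": "l2_y",
}
PARENT_OF_LEVEL_2: Dict[str, str] = {
    "l2_x": "l1_p",
    "l2_y": "l1_q",
}


def coarse_labels(level_3_label: str) -> Tuple[str, str, str]:
    """Return the (level_1, level_2, level_3) labels for a given level_3 label."""
    level_2_label = PARENT_OF_LEVEL_3[level_3_label]
    level_1_label = PARENT_OF_LEVEL_2[level_2_label]
    return level_1_label, level_2_label, level_3_label


if __name__ == "__main__":
    print(coarse_labels("l3_c"))  # ('l1_q', 'l2_y', 'l3_c')
```

Combined with whatever maps an image to its level_3 label, a function like this could serve as the labelling function for each stage (for example through fastai v1's `label_from_func`), so the three datasets would share the same images.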
The ultimate goal is to predict the level_3 class. I would like to run the following experiment:
- Start from an ImageNet-pretrained model and train it on level_1 classification
- Transfer the weights from the previous step and train on level_2 classification
- Transfer the weights from the previous step and train on level_3 classification
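One way to avoid juggling state dicts between the steps above is to keep the same model object and only swap its final linear layer before each stage. A minimal plain-PyTorch sketch, assuming the model is laid out like fastai v1's `cnn_learner` models (`nn.Sequential(body, head)` with the head ending in `nn.Linear`); the toy body/head sizes and the `replace_classifier` helper are hypothetical:

```python
import torch
from torch import nn


def replace_classifier(model: nn.Sequential, n_classes: int) -> nn.Sequential:
    """Swap the head's last nn.Linear for a freshly initialised one with n_classes outputs."""
    head = model[-1]
    old = head[-1]
    assert isinstance(old, nn.Linear), "expected the head to end in nn.Linear"
    # Keep the input width, change only the number of outputs; the body is untouched.
    head[-1] = nn.Linear(old.in_features, n_classes).to(old.weight.device)
    return model


# Toy stand-in for a pretrained body + head (hypothetical sizes).
body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1000))  # 1000 = ImageNet classifier
model = nn.Sequential(body, head)

for n_classes in (3, 10, 100):  # level_1, level_2, level_3
    replace_classifier(model, n_classes)
    # ... create the optimizer and train `model` on this level's labels here ...
    out = model(torch.randn(2, 3, 32, 32))
    print(n_classes, tuple(out.shape))
```

Any optimizer has to be created after the swap, since it must see the new layer's parameters. This is untested against fastai's `Learner`; in fastai v1 terms it would roughly mean assigning a new `nn.Linear` to `learn.model[-1][-1]` before training on the next level's data.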
I could not find an elegant way to do this with the fastai library, so I am wondering whether I missed something. Here is what I do so far:
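```python
from fastai.vision import *  # fastai v1; also brings in torch

# level 1 training
learn = cnn_learner(level_1_data, models.resnet50, metrics=accuracy)
learn.fit_one_cycle(5)
learn.save('level_1')  # written to learn.path/learn.model_dir/'level_1.pth', not to 'level_1'
level_1_file = learn.path/learn.model_dir/'level_1.pth'

# level 2 training from level 1 model
learn = cnn_learner(level_2_data, models.resnet50, metrics=accuracy)
checkpoint_dict = torch.load(level_1_file, map_location='cpu')
pretrained_dict = checkpoint_dict['model']  # save() stores {'model': ..., 'opt': ...} by default
model_dict = learn.model.state_dict()
# 1. filter out weights whose shape no longer matches, i.e. the final linear layer ('1.8.*'),
#    instead of matching the substring '1.8', which could also hit unrelated keys
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
learn.model.load_state_dict(model_dict)
```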
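Steps 1 to 3 can also be packaged as a small, reusable helper that keeps only the checkpoint tensors whose name and shape match the target model, which drops the old classifier without hard-coding its key prefix. A plain-PyTorch sketch; `load_matching_weights` is a hypothetical name, not a fastai function:

```python
from typing import Dict, List

import torch
from torch import nn


def load_matching_weights(model: nn.Module, state_dict: Dict[str, torch.Tensor]) -> List[str]:
    """Copy every tensor whose key and shape match `model`; return the skipped keys."""
    target = model.state_dict()
    kept = {k: v for k, v in state_dict.items()
            if k in target and v.shape == target[k].shape}
    target.update(kept)           # non-matching entries keep their fresh initialisation
    model.load_state_dict(target)
    return sorted(k for k in state_dict if k not in kept)


# Toy example: same body, classifier resized from 3 to 10 classes.
src = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
dst = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
skipped = load_matching_weights(dst, src.state_dict())
print(skipped)  # ['2.bias', '2.weight']
```

With the level 1 checkpoint this would be called as `load_matching_weights(learn.model, checkpoint_dict['model'])`; printing the returned keys is a quick check that only the final layer (expected to be `1.8.weight` and `1.8.bias` with the default fastai v1 head) was left out.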
Any ideas would be welcome.
Thanks