Clarification needed about fine_tune function

Hello Everyone,

When we retrain a pre-trained model (originally trained for Task A) on, let's say, Task B, the usual transfer learning advice is this: if we have a large dataset for Task B, we retrain all the layers of the model; if we have a small dataset for Task B, we freeze the earlier layers of the pre-trained model and retrain only the head layer(s).

In lesson 1 we used the Oxford-IIIT Pet Dataset to create a dog vs cat classifier. There we took a pre-trained resnet34 model and applied transfer learning with the fine_tune method of the fastai library. The fastbook states that fine_tune retrains the whole architecture in such a way that the weights of the earlier layers change slowly and the weights of the head layer(s) change fast.

So I wanted to ask:

How can we customize the number and type of layers added as the head in the existing cnn_learner/fine_tune workflow?

Is there any way to freeze the earlier layers instead of updating the weights of all the layers of the pre-trained model while fine-tuning (which I think is how learn.fine_tune works)?


PS: The fastbook also mentions that the head need not be a single layer; it can be a collection of layers.

Ok, so I got the clarification thanks to @arora_aman. In the first epoch (by default), fine_tune in fastai trains only the head layers and keeps the earlier layers frozen; from the second epoch onwards it updates the weights of all the layers together. Also, it should be noted that the weights of the earlier layers change slowly compared to the head, whose weights change faster.
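To make the "earlier layers slow, head fast" part concrete, here is a minimal sketch of how discriminative learning rates can be spread across layer groups. It assumes three layer groups and a geometric spacing between the smallest and largest rate, which is my understanding of what fastai does with slice(lr_min, lr_max); the helper name discriminative_lrs and the group labels are made up for illustration and are not part of the fastai API.

```python
def discriminative_lrs(lr_min, lr_max, n_groups):
    """Spread learning rates geometrically from lr_min (earliest group)
    to lr_max (head). Hypothetical helper, not a fastai function."""
    if n_groups == 1:
        return [lr_max]
    # Constant multiplicative step between consecutive layer groups.
    step = (lr_max / lr_min) ** (1 / (n_groups - 1))
    return [lr_min * step ** i for i in range(n_groups)]


# Assumed values: 3 layer groups, rates spanning a factor of 100.
lrs = discriminative_lrs(1e-5, 1e-3, 3)
for name, lr in zip(["early layers", "later layers", "head"], lrs):
    print(f"{name}: lr = {lr:.1e}")
```

The earliest group gets the smallest rate and the head gets the largest, so the pre-trained features are only nudged while the new head adapts quickly.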

Hi, take a look at the source code; it should clear things up further. I think Jeremy walks through this in chapter 6.
There is a parameter that tells it how many epochs to train with the earlier layers frozen.

Fine tune with freeze for freeze_epochs, then with unfreeze for epochs, using discriminative LR

It basically freezes the model and runs fit_one_cycle for freeze_epochs epochs.
Then it unfreezes and fits for the number of epochs you passed.
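The two phases described above can be sketched in plain Python. This is only a rough model of the schedule, assuming the defaults I remember from the fastai source (base_lr=2e-3, freeze_epochs=1, lr_mult=100, and base_lr halved before the unfrozen phase); these may differ between versions, so check the source for the one you have installed. The helper fine_tune_plan is hypothetical and not part of fastai.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Describe the two phases of a fine_tune-style schedule as dicts.
    Hypothetical helper for illustration, not a fastai function."""
    # Phase 1: body frozen, only the head is trained.
    # Roughly: learn.freeze(); learn.fit_one_cycle(freeze_epochs, slice(base_lr))
    frozen = {"frozen": True, "epochs": freeze_epochs,
              "lr_min": None, "lr_max": base_lr}

    # Assumption: the base learning rate is halved before unfreezing.
    base_lr /= 2

    # Phase 2: everything trainable, with discriminative learning rates.
    # Roughly: learn.unfreeze()
    #          learn.fit_one_cycle(epochs, slice(base_lr / lr_mult, base_lr))
    unfrozen = {"frozen": False, "epochs": epochs,
                "lr_min": base_lr / lr_mult, "lr_max": base_lr}
    return [frozen, unfrozen]


plan = fine_tune_plan(epochs=3)
for phase in plan:
    print(phase)
```

If you want the earlier layers to stay frozen for the whole run, one option is to skip fine_tune and call learn.freeze() followed by learn.fit_one_cycle(...) yourself. For a different head, cnn_learner accepts a custom_head argument, though I have not tested that with this particular setup.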


Aha, thanks Surabhi !!