Transfer learning when the number of classes changes

Hi! First of all, thanks to all the people who made this course so awesome.

Following lesson 1, I decided to create my own dataset of skiers and snowboarders to see if I could teach resnet34 to recognize them. It worked pretty well, with an accuracy around 91%. (My very basic render app is available here if you want to give it a try.)

Going on with the course, I decided to spice things up by adding images of monoskiers, ski jumpers and cross-country skiers to my dataset. I trained a new resnet34 and obtained an accuracy around 75% (with monoskiers being the class with the worst performance, as I expected). Looking more closely at the results, I saw that there was more confusion between skiers and snowboarders in that second model than there was in the first one. The use of transfer learning in lesson 3 (for the planet dataset) made me think that there should be a way to inherit the knowledge acquired in my first model and apply it to the second.

I naively tried doing learn.data = data5 (where learn is the first model and data5 is the data bunch for the 5-class dataset) and then went on with the fitting, but it logically failed: the number of classes is different in the two classification problems, so you can't copy the weights of the last layer.
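To illustrate what I mean by not being able to copy the last layer's weights, here is a toy example (the layer sizes are illustrative, not my actual model):

```python
import torch.nn as nn

# Toy "heads": 512 features in, 2 classes out vs 5 classes out
head_2_classes = nn.Linear(512, 2)
head_5_classes = nn.Linear(512, 5)

# Loading the 2-class weights into the 5-class layer fails with a
# size-mismatch error: the weight matrix is [2, 512] vs [5, 512]
try:
    head_5_classes.load_state_dict(head_2_classes.state_dict())
except RuntimeError as e:
    print(e)
```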

My question is the following: is there a way to copy only a portion of the weights from one network to the other? I assume there is, but I can't find the correct way to do it.

Thanks a lot for your help!

You can remove the “head”, the part of the network that maps the features to classes, and then retrain it. You can freeze the weights of the previous layers, the “body”, so that they are not affected. The head is meant to be removed for exactly this kind of problem.
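Something like this in plain PyTorch, as a rough sketch (using torchvision's resnet34, where the final layer is called fc; your fastai model groups the layers differently):

```python
import torch.nn as nn
from torchvision import models

# Start from a pretrained resnet34 (or your own trained weights)
model = models.resnet34(pretrained=True)

# Freeze the "body" so its weights are not updated during training
for param in model.parameters():
    param.requires_grad = False

# Replace the "head" (the final fully connected layer) with a fresh
# one sized for the new number of classes, e.g. 5
model.fc = nn.Linear(model.fc.in_features, 5)

# Only the new head's parameters are left trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
```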

Does this all sound familiar to you? I haven't done part 1 in a while and don't remember what you have learned by this point.

I was working on a problem where my model was getting confused between 2 classes a lot (out of 5). So I decided to train it with the non-confused classes first and then add the others later. It didn't help much. You can find the Jupyter notebook here and the article here.


That could work, I'm gonna try it out! So you basically created a dataset with 6 classes, but where 2 of them had neither training nor validation examples, and then you retrained it by adding the last two classes?

A quick update: I've found what I needed by digging deeper into the forum. The answer is here, and the line I needed was:
learn_5_classes.model[0].load_state_dict(learn_2_classes.model[0].state_dict()) (that kind of does what @marii said).
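For completeness, the full workflow looks roughly like this, as a sketch assuming fastai v1's cnn_learner (the paths and variable names are just placeholders, not my actual code):

```python
from fastai.vision import *

# Hypothetical folder layouts for the 2-class and 5-class datasets
data2 = ImageDataBunch.from_folder(Path('data/ski_2_classes'),
                                   ds_tfms=get_transforms(), valid_pct=0.2, size=224)
data5 = ImageDataBunch.from_folder(Path('data/ski_5_classes'),
                                   ds_tfms=get_transforms(), valid_pct=0.2, size=224)

# Train the 2-class model first (or load a previously saved learner)
learn_2_classes = cnn_learner(data2, models.resnet34, metrics=accuracy)
learn_2_classes.fit_one_cycle(4)

# New learner for 5 classes: cnn_learner builds a fresh head sized
# for the new number of classes on top of a pretrained body
learn_5_classes = cnn_learner(data5, models.resnet34, metrics=accuracy)

# Copy only the body (model[0]) from the 2-class model; the new
# 5-class head (model[1]) keeps its freshly initialised weights
learn_5_classes.model[0].load_state_dict(learn_2_classes.model[0].state_dict())

# Train the new head with the body frozen, then optionally unfreeze
learn_5_classes.freeze()
learn_5_classes.fit_one_cycle(4)
```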