Splitting a pretrained model in groups of layers

In the docs for create_cnn, one reads:

The final model obtained by stacking the backbone and the head (custom or defined as we saw) is then separated into groups for gradual unfreezing or differential learning rates. You can specify how to split the backbone into groups with the optional argument split_on (should be a function that returns those groups when given the backbone).
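The docs name differential learning rates as one use of these groups. Here is a minimal sketch in plain PyTorch (not fastai's own code) that gives each layer group its own optimizer parameter group. The three-layer toy model and the helper spread_lrs are hypothetical. The geometric spacing reflects my understanding of what fastai v1 does when you pass max_lr=slice(1e-5, 1e-3).

```python
import torch
import torch.nn as nn

def spread_lrs(lo, hi, n):
    """Hypothetical helper: n learning rates spaced geometrically from lo to hi."""
    if n == 1:
        return [hi]
    step = (hi / lo) ** (1 / (n - 1))
    return [lo * step ** i for i in range(n)]

# Toy model standing in for backbone + head, split into three layer groups
model = nn.Sequential(
    nn.Linear(16, 16),  # group 0: early backbone layers
    nn.Linear(16, 16),  # group 1: later backbone layers
    nn.Linear(16, 4),   # group 2: head
)
layer_groups = [model[0], model[1], model[2]]

# One optimizer parameter group per layer group: the earliest layers get the
# smallest learning rate, the head gets the largest
lrs = spread_lrs(1e-5, 1e-3, len(layer_groups))
opt = torch.optim.SGD(
    [{"params": g.parameters(), "lr": lr} for g, lr in zip(layer_groups, lrs)],
    lr=1e-3,  # default, overridden by each group's own "lr"
)

for i, pg in enumerate(opt.param_groups):
    print(f"group {i}: lr = {pg['lr']:.0e}")
```

Lower learning rates for the early pretrained layers change their general-purpose features only slightly, while the freshly initialised head can move faster.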

I’d like to know:

  1. What is meant by gradual unfreezing?
  2. Whether someone could provide an example of a function that “returns those groups when given the backbone”.


For your first question, I would say it refers to the fact that when you use a pretrained model, only the last layer group is unfrozen at first (the custom head added by the fastai library, consisting of two fully connected layers with batch norm and dropout, or something like that; you can check it in the source code). You then unfreeze the layers of your model group by group.
If you create a model with learn = create_cnn(…), you can call learn.freeze_to(-2) to unfreeze the second-to-last layer group in addition to the last one, and then learn.unfreeze() to unfreeze all the layer groups. With three layer groups (I think this is the case for the resnet34 architecture), those two calls cover every group.
Each time you unfreeze a layer group, you train your model before unfreezing the next one. That is what gradual unfreezing means.
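Here is a minimal sketch of the freezing mechanics in plain PyTorch, assuming a model already split into three layer groups. freeze_to and trainable are hypothetical helpers that imitate learn.freeze_to / learn.unfreeze by toggling requires_grad. They are not the fastai implementation.

```python
import torch.nn as nn

# Hypothetical helper that mimics the idea behind fastai's Learner.freeze_to:
# groups before index n are frozen, groups from n onwards are trainable.
# (fastai's real version, as far as I know, also keeps BatchNorm layers
# trainable by default; this sketch skips that detail.)
def freeze_to(layer_groups, n):
    for group in layer_groups[:n]:
        for p in group.parameters():
            p.requires_grad = False
    for group in layer_groups[n:]:
        for p in group.parameters():
            p.requires_grad = True

def trainable(layer_groups):
    """True for each group whose parameters will be updated."""
    return [all(p.requires_grad for p in g.parameters()) for g in layer_groups]

# Toy stand-in for three layer groups: two backbone parts and a head
groups = [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 2)]

freeze_to(groups, -1)     # like learn.freeze(): only the head trains
print(trainable(groups))  # [False, False, True]
# ... train for a few epochs ...

freeze_to(groups, -2)     # like learn.freeze_to(-2)
print(trainable(groups))  # [False, True, True]
# ... train again ...

freeze_to(groups, 0)      # like learn.unfreeze()
print(trainable(groups))  # [True, True, True]
```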

For your second question, I think you can access the parts of your model through learn.model and index into it.
With create_cnn, learn.model[0] refers to the backbone (body) of your model and learn.model[-1] to the head. The layer groups themselves are kept in learn.layer_groups.
learn.model[0][0] gives you the first layer of the backbone.


Very informative answer, thank you!

We have already seen more than once how to refer to specific layers, but reading the docs, I thought one was supposed to provide a custom function.

Oh, sorry, I read too fast and didn’t see the split_on argument you were asking about. I haven’t used it yet :frowning:


The function needs to return a tuple of the split points, i.e. the modules at which each new layer group starts. Example from the library itself:

def _resnet_split(m): return (m[0][6],m[1])

from here:


No problem :slight_smile: Your answer was instructive regardless of that.

Thanks, problem solved! Looking at the code is always a good idea. I reckon I should have done it before asking.
