How do you verify that a learner's model layers are split into parameter groups?

Hi, I’m writing a splitter for my learner to group the model’s layers, so that I can later train each group with a different learning rate. I’ve built a simple learner as shown below. How can I verify that 3 groups were actually created in my learner’s model? Which function or attribute should I look at?

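from fastai.vision.all import *

class MyNeuralNetwork(nn.Module):
    def __init__(self, arch=models.resnet18):
        super().__init__()
        self.cnn = create_body(arch)  # convolutional body of `arch`
        self.head = create_head(num_features_model(self.cnn) * 2, 4)  # new head with 4 outputs

    def forward(self, im):
        x = self.cnn(im)
        x = self.head(x)
        return 2 * (x.sigmoid_() - 0.5)  # squash outputs into (-1, 1)

# three parameter groups: early body layers, later body layers, head
def learner_splitter(model):
    return [params(model.cnn[:5]), params(model.cnn[5:]), params(model.head)]

# `...` stands for your DataLoaders; the splitter is passed in so the optimizer uses these groups
learner = Learner(dls=..., model=MyNeuralNetwork(arch=models.resnet18), splitter=learner_splitter)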
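One direct check is to call the splitter on the model yourself and count what comes back. The sketch below does this in plain PyTorch rather than fastai: `ToyNet` is a hypothetical stand-in with the same `cnn`/`head` layout as `MyNeuralNetwork`, and `params` is a local helper assumed to behave like fastai's `params` (a list of a module's parameters). With the fastai learner above, the analogous call would be `learner.splitter(learner.model)`.

```python
import torch
from torch import nn

def params(m):
    # local stand-in for fastai's `params`: all parameters of a module, as a list
    return list(m.parameters())

class ToyNet(nn.Module):
    # hypothetical model with the same `cnn` / `head` layout as MyNeuralNetwork
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 8, 3), nn.ReLU(),
            nn.Conv2d(8, 16, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head = nn.Linear(16, 4)

    def forward(self, x):
        return self.head(self.cnn(x))

def learner_splitter(model):
    return [params(model.cnn[:5]), params(model.cnn[5:]), params(model.head)]

model = ToyNet()
groups = learner_splitter(model)
print("number of groups:", len(groups))
for i, g in enumerate(groups):
    print(f"group {i}: {sum(p.numel() for p in g)} parameters")

# every parameter should land in exactly one group (no gaps, no overlaps)
ids = [id(p) for g in groups for p in g]
assert len(ids) == len(set(ids)) == len(list(model.parameters()))

# the same groups handed to a PyTorch optimizer, one learning rate per group
lrs = [1e-4, 1e-3, 1e-2]
opt = torch.optim.SGD([{"params": g, "lr": lr} for g, lr in zip(groups, lrs)], lr=1e-3)
print("optimizer param groups:", len(opt.param_groups))
```

The coverage assertion is worth keeping: a slice like `cnn[:5]` that accidentally overlaps or skips layers shows up there before any training starts.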

learn.summary() should do the trick.

1 Like

learn.summary() doesn’t tell you directly, does it? (Unless you’re very familiar with the network architecture and can work it out from the Trainable column.)
I found that freezing the groups one by one with learn.freeze_to(i), for i from 1 to 3, and checking the non-trainable parameter count reported by learn.summary() after each call helps verify this. Is there an easier way?
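The freeze-and-count check can be reproduced outside of `learn.summary()`. Below is a plain-PyTorch sketch: `freeze_to` is a hypothetical helper that mimics the idea behind fastai's `learn.freeze_to(n)` (groups before `n` get `requires_grad=False`, the rest stay trainable). fastai's real implementation has extra behaviour, such as keeping BatchNorm layers trainable by default, that this sketch ignores. The three-group model is likewise hypothetical.

```python
import torch
from torch import nn

# hypothetical 3-group model: two "body" layers and a "head"
body1 = nn.Linear(10, 20)
body2 = nn.Linear(20, 20)
head = nn.Linear(20, 4)
model = nn.Sequential(body1, nn.ReLU(), body2, nn.ReLU(), head)
groups = [list(body1.parameters()), list(body2.parameters()), list(head.parameters())]

def freeze_to(groups, n):
    # hypothetical helper: freeze groups[:n], leave groups[n:] trainable
    for i, g in enumerate(groups):
        for p in g:
            p.requires_grad_(i >= n)

def non_trainable(model):
    # total number of parameter elements that will not receive gradients
    return sum(p.numel() for p in model.parameters() if not p.requires_grad)

counts = {}
for n in range(1, len(groups) + 1):
    freeze_to(groups, n)
    counts[n] = non_trainable(model)
    print(f"frozen up to group {n}: {counts[n]} non-trainable params")
```

Each step should add exactly the size of the next group to the non-trainable count, and once `n` equals the number of groups the whole model is frozen.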

1 Like

learn.summary() will tell you how many groups there are.

That seems like the right approach to me for verifying where the split points are… I wonder if summary could be modified to show a separate frozen section…
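As a rough idea of what a per-group frozen section could look like, the hypothetical `group_report` below prints one row per parameter group with its parameter count and whether it is trainable, frozen, or mixed. It is a plain-PyTorch sketch of the idea, not an existing fastai feature; with a fastai learner you could pass it the list returned by your splitter.

```python
import torch
from torch import nn

def group_report(groups):
    # hypothetical helper: one row per parameter group (index, #params, state)
    rows = []
    for i, g in enumerate(groups):
        n = sum(p.numel() for p in g)
        flags = {p.requires_grad for p in g}
        if not flags:
            state = "empty"
        elif flags == {True}:
            state = "trainable"
        elif flags == {False}:
            state = "frozen"
        else:
            state = "mixed"
        rows.append((i, n, state))
    print(f"{'group':>5}  {'params':>8}  state")
    for i, n, state in rows:
        print(f"{i:>5}  {n:>8}  {state}")
    return rows

# hypothetical example: three groups, the first one frozen
layers = [nn.Linear(10, 20), nn.Linear(20, 20), nn.Linear(20, 4)]
groups = [list(layer.parameters()) for layer in layers]
for p in groups[0]:
    p.requires_grad_(False)
rows = group_report(groups)
```

Reporting "mixed" separately is deliberate: a group where only some parameters are frozen (for example BatchNorm layers kept trainable) would otherwise look fully frozen or fully trainable.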

2 Likes

I found that information now, thanks @muellerzr!

Just an observation for anyone else looking into this: learn.summary() only tells you the number of groups in the model if you run learn.freeze_to(x) followed by learn.summary() with x >= the number of groups. The count then appears as a warning before the table output, since all groups are being frozen.

If x is less than the number of groups, the only information you get is the line:
Model frozen up to parameter group number x

2 Likes