Why are the `BatchNorm2d` layers in a frozen model trainable?


(Junlin) #1

Hi! I was just experimenting with the “lesson 6 pets more” notebook, and I noticed that all the BatchNorm2d layers in the frozen model are trainable. I am curious why.

print(learn.summary())
======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [1, 64, 176, 176]    9,408      False     
______________________________________________________________________
BatchNorm2d          [1, 64, 176, 176]    128        True      
______________________________________________________________________
ReLU                 [1, 64, 176, 176]    0          False     
______________________________________________________________________
MaxPool2d            [1, 64, 88, 88]      0          False     
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
ReLU                 [1, 64, 88, 88]      0          False     
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
ReLU                 [1, 64, 88, 88]      0          False     
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
ReLU                 [1, 64, 88, 88]      0          False     
______________________________________________________________________
Conv2d               [1, 64, 88, 88]      36,864     False     
______________________________________________________________________
BatchNorm2d          [1, 64, 88, 88]      128        True      
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     73,728     False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
ReLU                 [1, 128, 44, 44]     0          False     
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     8,192      False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
ReLU                 [1, 128, 44, 44]     0          False     
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
ReLU                 [1, 128, 44, 44]     0          False     
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
ReLU                 [1, 128, 44, 44]     0          False     
______________________________________________________________________
Conv2d               [1, 128, 44, 44]     147,456    False     
______________________________________________________________________
BatchNorm2d          [1, 128, 44, 44]     256        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     294,912    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     32,768     False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
ReLU                 [1, 256, 22, 22]     0          False     
______________________________________________________________________
Conv2d               [1, 256, 22, 22]     589,824    False     
______________________________________________________________________
BatchNorm2d          [1, 256, 22, 22]     512        True      
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     1,179,648  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
ReLU                 [1, 512, 11, 11]     0          False     
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     131,072    False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
ReLU                 [1, 512, 11, 11]     0          False     
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
ReLU                 [1, 512, 11, 11]     0          False     
______________________________________________________________________
Conv2d               [1, 512, 11, 11]     2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 11, 11]     1,024      True      
______________________________________________________________________
AdaptiveAvgPool2d    [1, 512, 1, 1]       0          False     
______________________________________________________________________
AdaptiveMaxPool2d    [1, 512, 1, 1]       0          False     
______________________________________________________________________
Flatten              [1, 1024]            0          False     
______________________________________________________________________
BatchNorm1d          [1, 1024]            2,048      True      
______________________________________________________________________
Dropout              [1, 1024]            0          False     
______________________________________________________________________
Linear               [1, 512]             524,800    True      
______________________________________________________________________
ReLU                 [1, 512]             0          False     
______________________________________________________________________
BatchNorm1d          [1, 512]             1,024      True      
______________________________________________________________________
Dropout              [1, 512]             0          False     
______________________________________________________________________
Linear               [1, 37]              18,981     True      
______________________________________________________________________
BatchNorm1d          [1, 37]              74         True      
______________________________________________________________________

Total params: 21,831,599
Total trainable params: 563,951
Total non-trainable params: 21,267,648


(Brad) #2

This is an option on `Learner` (it defaults to `True`):

If `train_bn`, batchnorm layer learnable params are trained even for frozen layer groups.

Via https://docs.fast.ai/basic_train.html#Learner
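To make the behavior concrete, here is a minimal sketch in plain PyTorch (not fastai's actual internals, and the `freeze` helper below is my own illustrative function) of what `train_bn=True` amounts to: when a layer group is frozen, gradients are turned off for every parameter except those belonging to BatchNorm layers.

```python
import torch.nn as nn

# A tiny stand-in for one frozen layer group: conv -> bn -> relu.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

def freeze(module, train_bn=True):
    """Disable gradients for all params, but keep BatchNorm's learnable
    weight/bias trainable when train_bn is True (mirroring fastai's default)."""
    for m in module.modules():
        is_bn = isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
        for p in m.parameters(recurse=False):
            p.requires_grad = is_bn and train_bn

freeze(model)
for name, p in model.named_parameters():
    print(name, p.requires_grad)
```

Running this prints `False` for the conv weight and `True` for the BatchNorm weight and bias, which is exactly the `Trainable` column pattern you see in `learn.summary()` above.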


(Junlin) #3

Thanks for the reply, Brad!

But I still can’t build an intuition for why that makes sense. Is there an explanation behind it? Or is it just a best practice from experience?