How can we see what layers are frozen?

In fastai/pytorch v1, how can we see which layers are trainable and which layers are frozen?

1 Like

I am getting this error when calling learn.model_summary():

AttributeError: 'WeightDropout' object has no attribute 'trainable'

1 Like

First of all, to see how many layer groups your model has, so you know what index to pass to .freeze_to, check learn.layer_groups like this:

for index, layer in enumerate(learn.layer_groups):
    print('Layer Group Index: ', index, layer)
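To make the effect of .freeze_to concrete, here is a plain-PyTorch sketch; the three layer groups and the freeze_to helper below are illustrative stand-ins, not the fastai implementation:

```python
import torch.nn as nn

# Three stand-in layer groups, mimicking embeddings / encoder / head
layer_groups = [
    nn.Sequential(nn.Embedding(10, 4)),  # group 0
    nn.Sequential(nn.Linear(4, 4)),      # group 1
    nn.Sequential(nn.Linear(4, 2)),      # group 2 (the head)
]

def freeze_to(groups, n):
    """Freeze every group before index n; leave the rest trainable."""
    for i, group in enumerate(groups):
        for p in group.parameters():
            p.requires_grad_(i >= n)

freeze_to(layer_groups, 2)  # only the last group stays trainable
for i, group in enumerate(layer_groups):
    trainable = any(p.requires_grad for p in group.parameters())
    print('Layer Group Index:', i, 'trainable:', trainable)
```

Passing 0 instead of 2 would unfreeze everything again, which is what learn.unfreeze() does in fastai.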

To see what happens when you do .freeze_to, I have created a function that prints the number of trainable parameters so you can see which layers are frozen.

def summary_trainable(learner):
  result = []
  total_params_element = 0
  def check_trainable(module):
    nonlocal total_params_element
    # only report leaf modules (no children), so each row is a single layer
    if len(list(module.children())) == 0:
      num_param = 0
      num_trainable_param = 0
      num_param_numel = 0
      for parameter in module.parameters():
        num_param += 1
        if parameter.requires_grad:
          num_param_numel += parameter.numel()
          total_params_element += parameter.numel()
          num_trainable_param += 1

      result.append({'module': module, 'num_param': num_param,
                     'num_trainable_param': num_trainable_param,
                     'num_param_numel': num_param_numel})
  learner.model.apply(check_trainable)
  
  print("{: <85} {: <17} {: <20} {: <40}".format('Module Name', 'Total Parameters', 'Trainable Parameters', '# Elements in Trainable Parameters'))
  for row in result:
    print("{: <85} {: <17} {: <20} {: <40,}".format(row['module'].__str__(), row['num_param'], row['num_trainable_param'], row['num_param_numel']))
  print('Total number of parameters elements {:,}'.format(total_params_element))

learn.unfreeze()
summary_trainable(learn)

Please note that learn.summary() doesn't show this information, at least not in the case of RNN_Learner.
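As a lighter-weight cross-check (plain PyTorch, not fastai API; the helper name below is mine), you can walk named_parameters directly:

```python
import torch.nn as nn

def trainable_report(model):
    """List (name, numel, requires_grad) for every parameter in the model."""
    return [(name, p.numel(), p.requires_grad)
            for name, p in model.named_parameters()]

# Tiny illustration: freeze the first layer of a small model
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad_(False)

for name, numel, trainable in trainable_report(model):
    print(f'{name:10} {numel:4d} trainable={trainable}')

total = sum(n for _, n, t in trainable_report(model) if t)
print('Trainable elements:', total)
```

For a fastai learner you would pass learn.model in place of the toy model above.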

for example:

learn_classifier = text_classifier_learner(data, AWD_LSTM, drop_mult=0)
learn_classifier.unfreeze()
learn_classifier.summary()
"""
SequentialRNN
======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
RNNDropout           [61, 400]            0          False     
______________________________________________________________________
RNNDropout           [61, 1152]           0          False     
______________________________________________________________________
RNNDropout           [61, 1152]           0          False     
______________________________________________________________________
BatchNorm1d          [1200]               2,400      True      
______________________________________________________________________
Linear               [50]                 60,050     True      
______________________________________________________________________
ReLU                 [50]                 0          False     
______________________________________________________________________
BatchNorm1d          [50]                 100        True      
______________________________________________________________________
Dropout              [50]                 0          False     
______________________________________________________________________
Linear               [2]                  102        True      
______________________________________________________________________

Total params: 62,652
Total trainable params: 62,652
Total non-trainable params: 0
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ 
Loss function : FlattenedLoss
======================================================================
Callbacks functions applied 
    RNNTrainer
"""

which is not correct. Each of the RNNDropout layers wraps an LSTM module with trainable parameters, and the total params reported in the summary do not include the embedding or LSTM parameters.

If you use the summary_trainable function instead (scroll right, the output is wider than what fits on screen):

learn_classifier.unfreeze()
summary_trainable(learn_classifier)
"""
Module Name                                                                           Total Parameters  Trainable Parameters # Elements in Trainable Parameters
Embedding(5536, 400, padding_idx=1)                                                   1                 1                    2,214,400                               
Embedding(5536, 400, padding_idx=1)                                                   1                 1                    2,214,400                               
LSTM(400, 1152, batch_first=True)                                                     4                 4                    7,160,832                               
LSTM(1152, 1152, batch_first=True)                                                    4                 4                    10,626,048                              
LSTM(1152, 400, batch_first=True)                                                     4                 4                    2,486,400                               
RNNDropout()                                                                          0                 0                    0                                       
RNNDropout()                                                                          0                 0                    0                                       
RNNDropout()                                                                          0                 0                    0                                       
RNNDropout()                                                                          0                 0                    0                                       
BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)     2                 2                    2,400                                   
Linear(in_features=1200, out_features=50, bias=True)                                  2                 2                    60,050                                  
ReLU(inplace=True)                                                                    0                 0                    0                                       
BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)       2                 2                    100                                     
Dropout(p=0.1, inplace=False)                                                         0                 0                    0                                       
Linear(in_features=50, out_features=2, bias=True)                                     2                 2                    102                                     
Total number of parameters elements 24,764,732
"""
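As a sanity check on those element counts: the Embedding and LSTM numbers follow directly from PyTorch's parameter shapes (an LSTM layer holds weight_ih_l0, weight_hh_l0, and two bias vectors, each with a leading 4*hidden_size dimension):

```python
# Embedding(5536, 400): one weight matrix of vocab_size x emb_size
emb_elements = 5536 * 400
print(f'{emb_elements:,}')  # 2,214,400

# LSTM(input_size, hidden_size): weight_ih_l0 (4*hidden x input),
# weight_hh_l0 (4*hidden x hidden), bias_ih_l0 and bias_hh_l0 (4*hidden each)
def lstm_elements(input_size, hidden_size):
    return 4 * hidden_size * (input_size + hidden_size + 2)

print(f'{lstm_elements(400, 1152):,}')   # 7,160,832
print(f'{lstm_elements(1152, 1152):,}')  # 10,626,048
print(f'{lstm_elements(1152, 400):,}')   # 2,486,400
```

These match the LSTM rows in the table above, which is a good sign the function is counting correctly.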
2 Likes