Keras-like summary of a model

Being used to Keras's model.summary, I couldn't help looking for something similar in PyTorch too.

This is the library we'll use: https://github.com/sksq96/pytorch-summary

Here is the only command you need to run after creating a Learner object to get the model's summary:

from torchsummary import summary

# shape of a single input sample: (channels, height, width) -- no batch dimension
channels = 3
H = 224
W = 224
summary(learn.model, input_size=(channels, H, W))

which gives the following: :slight_smile:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19           [-1, 64, 56, 56]          36,864
      BatchNorm2d-20           [-1, 64, 56, 56]             128
             ReLU-21           [-1, 64, 56, 56]               0
           Conv2d-22           [-1, 64, 56, 56]          36,864
      BatchNorm2d-23           [-1, 64, 56, 56]             128
             ReLU-24           [-1, 64, 56, 56]               0
       BasicBlock-25           [-1, 64, 56, 56]               0
           Conv2d-26          [-1, 128, 28, 28]          73,728
      BatchNorm2d-27          [-1, 128, 28, 28]             256
             ReLU-28          [-1, 128, 28, 28]               0
           Conv2d-29          [-1, 128, 28, 28]         147,456
      BatchNorm2d-30          [-1, 128, 28, 28]             256
           Conv2d-31          [-1, 128, 28, 28]           8,192
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 128, 28, 28]         147,456
      BatchNorm2d-36          [-1, 128, 28, 28]             256
             ReLU-37          [-1, 128, 28, 28]               0
           Conv2d-38          [-1, 128, 28, 28]         147,456
      BatchNorm2d-39          [-1, 128, 28, 28]             256
             ReLU-40          [-1, 128, 28, 28]               0
       BasicBlock-41          [-1, 128, 28, 28]               0
           Conv2d-42          [-1, 128, 28, 28]         147,456
      BatchNorm2d-43          [-1, 128, 28, 28]             256
             ReLU-44          [-1, 128, 28, 28]               0
           Conv2d-45          [-1, 128, 28, 28]         147,456
      BatchNorm2d-46          [-1, 128, 28, 28]             256
             ReLU-47          [-1, 128, 28, 28]               0
       BasicBlock-48          [-1, 128, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]         147,456
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
       BasicBlock-55          [-1, 128, 28, 28]               0
           Conv2d-56          [-1, 256, 14, 14]         294,912
      BatchNorm2d-57          [-1, 256, 14, 14]             512
             ReLU-58          [-1, 256, 14, 14]               0
           Conv2d-59          [-1, 256, 14, 14]         589,824
      BatchNorm2d-60          [-1, 256, 14, 14]             512
           Conv2d-61          [-1, 256, 14, 14]          32,768
      BatchNorm2d-62          [-1, 256, 14, 14]             512
             ReLU-63          [-1, 256, 14, 14]               0
       BasicBlock-64          [-1, 256, 14, 14]               0
           Conv2d-65          [-1, 256, 14, 14]         589,824
      BatchNorm2d-66          [-1, 256, 14, 14]             512
             ReLU-67          [-1, 256, 14, 14]               0
           Conv2d-68          [-1, 256, 14, 14]         589,824
      BatchNorm2d-69          [-1, 256, 14, 14]             512
             ReLU-70          [-1, 256, 14, 14]               0
       BasicBlock-71          [-1, 256, 14, 14]               0
           Conv2d-72          [-1, 256, 14, 14]         589,824
      BatchNorm2d-73          [-1, 256, 14, 14]             512
             ReLU-74          [-1, 256, 14, 14]               0
           Conv2d-75          [-1, 256, 14, 14]         589,824
      BatchNorm2d-76          [-1, 256, 14, 14]             512
             ReLU-77          [-1, 256, 14, 14]               0
       BasicBlock-78          [-1, 256, 14, 14]               0
           Conv2d-79          [-1, 256, 14, 14]         589,824
      BatchNorm2d-80          [-1, 256, 14, 14]             512
             ReLU-81          [-1, 256, 14, 14]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
       BasicBlock-85          [-1, 256, 14, 14]               0
           Conv2d-86          [-1, 256, 14, 14]         589,824
      BatchNorm2d-87          [-1, 256, 14, 14]             512
             ReLU-88          [-1, 256, 14, 14]               0
           Conv2d-89          [-1, 256, 14, 14]         589,824
      BatchNorm2d-90          [-1, 256, 14, 14]             512
             ReLU-91          [-1, 256, 14, 14]               0
       BasicBlock-92          [-1, 256, 14, 14]               0
           Conv2d-93          [-1, 256, 14, 14]         589,824
      BatchNorm2d-94          [-1, 256, 14, 14]             512
             ReLU-95          [-1, 256, 14, 14]               0
           Conv2d-96          [-1, 256, 14, 14]         589,824
      BatchNorm2d-97          [-1, 256, 14, 14]             512
             ReLU-98          [-1, 256, 14, 14]               0
       BasicBlock-99          [-1, 256, 14, 14]               0
          Conv2d-100            [-1, 512, 7, 7]       1,179,648
     BatchNorm2d-101            [-1, 512, 7, 7]           1,024
            ReLU-102            [-1, 512, 7, 7]               0
          Conv2d-103            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-104            [-1, 512, 7, 7]           1,024
          Conv2d-105            [-1, 512, 7, 7]         131,072
     BatchNorm2d-106            [-1, 512, 7, 7]           1,024
            ReLU-107            [-1, 512, 7, 7]               0
      BasicBlock-108            [-1, 512, 7, 7]               0
          Conv2d-109            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-110            [-1, 512, 7, 7]           1,024
            ReLU-111            [-1, 512, 7, 7]               0
          Conv2d-112            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-113            [-1, 512, 7, 7]           1,024
            ReLU-114            [-1, 512, 7, 7]               0
      BasicBlock-115            [-1, 512, 7, 7]               0
          Conv2d-116            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-117            [-1, 512, 7, 7]           1,024
            ReLU-118            [-1, 512, 7, 7]               0
          Conv2d-119            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-120            [-1, 512, 7, 7]           1,024
            ReLU-121            [-1, 512, 7, 7]               0
      BasicBlock-122            [-1, 512, 7, 7]               0
AdaptiveMaxPool2d-123            [-1, 512, 1, 1]               0
AdaptiveAvgPool2d-124            [-1, 512, 1, 1]               0
AdaptiveConcatPool2d-125           [-1, 1024, 1, 1]               0
          Lambda-126                 [-1, 1024]               0
     BatchNorm1d-127                 [-1, 1024]           2,048
         Dropout-128                 [-1, 1024]               0
          Linear-129                  [-1, 512]         524,800
            ReLU-130                  [-1, 512]               0
     BatchNorm1d-131                  [-1, 512]           1,024
         Dropout-132                  [-1, 512]               0
          Linear-133                    [-1, 3]           1,539
================================================================
Total params: 21,814,083
Trainable params: 546,435
Non-trainable params: 21,267,648
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 96.33
Params size (MB): 83.21
Estimated Total Size (MB): 180.12
----------------------------------------------------------------
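
(Side note: if your model lives on the CPU rather than the GPU, recent versions of torchsummary also accept a device argument -- worth checking your installed version:)

summary(learn.model, input_size=(channels, H, W), device="cpu")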

We can also easily verify the total number of params:

sum(param.nelement() for param in learn.model.parameters())

which gives 21814083, the same as above.
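
If you want the trainable/non-trainable split from the summary as well, plain PyTorch can give you that too (a quick sketch):

trainable = sum(p.nelement() for p in learn.model.parameters() if p.requires_grad)
frozen = sum(p.nelement() for p in learn.model.parameters() if not p.requires_grad)
print(trainable, frozen)  # should match 546,435 and 21,267,648 above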

Thanks for the tip. We do have a todo to add that functionality btw - just haven’t gotten around to it yet.

If anyone is interested, doing that PR would be a great way to learn about HookCallback etc in the fastai lib.

There’s also torchstat, which can infer shapes in a model automatically (it was recently starred by Soumith, so I got a notification).

I could probably give this a shot, Jeremy. Is there a rough API you have in mind for the ‘summarizer’? For example:

  • a method on the Learner object itself
  • a separate function that processes some data returned by the Learner object

I was just browsing HookCallback and it kinda makes sense to me (run some functions by hooking into the training lifecycle -- I know there are proper docs :smile:, I’ll dig into those too).

I’d be interested to have a look; I did try to implement a simple one for fastai 0.7.

Circling back: is this done, or is someone already working on it? I want to start looking at the callback this week.
Based on the info earlier in this thread, it seems not too difficult if we take advantage of the existing code.

I may start looking at:

  1. Adding a Keras-like summary function
  2. A simple graphviz plot of the model which takes the summary (a dict or namedtuple); rough sketch below
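
For point 2, roughly what I have in mind (summary_to_graph and the row format here are just hypothetical):

from graphviz import Digraph  # pip install graphviz

def summary_to_graph(rows):
    "Turn a list of (layer_name, output_shape) rows into a simple chain graph."
    dot = Digraph()
    for i, (name, shape) in enumerate(rows):
        dot.node(str(i), f"{name}\n{shape}")
        if i: dot.edge(str(i - 1), str(i))  # link each layer to the previous one
    return dot

# summary_to_graph(rows).render("model", format="png")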

Edited: I’ve started a PR for this, in case someone else is working on it.

The current Hook seems to only save the input and output of an nn.Module, but not the layers within a Module, e.g. the two conv layers and batchnorms inside a BasicBlock. The only way I can think of to recover these details is to just iterate over the entire model and read the shapes of the weights/biases, but I can’t see how to leverage hooks for that.

Those inner things are nn.Modules too. And you can pass any modules you like to the hook callbacks etc when you create them.
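
For illustration, here’s roughly what that means in plain PyTorch (collect_output_shapes is just a made-up helper, not fastai API):

import torch
import torch.nn as nn

def collect_output_shapes(model: nn.Module, x: torch.Tensor):
    "Hook every leaf module (Conv2d, BatchNorm2d, ...) and record its output shape."
    shapes, handles = [], []
    for mod in model.modules():
        if not list(mod.children()):  # leaf modules only
            handles.append(mod.register_forward_hook(
                lambda m, inp, out: shapes.append((m.__class__.__name__, tuple(out.shape)))))
    with torch.no_grad():
        model.eval()(x)
    for h in handles: h.remove()  # always remove hooks when done
    return shapes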

The current model_sizes function does not accept an nn.Module like BasicBlock. I tried to find references to this function, but couldn’t find any other module calling it. Can I modify this function, or is it used somewhere else? In the docs I found a comment suggesting it may be useful for DynamicUnet, but it is not used in the unet class.

def model_sizes(m:nn.Module, size:tuple=(64,64), full:bool=True) -> Tuple[Sizes,Tensor,Hooks]:
    "Pass a dummy input through the model to get the various sizes. Returns (res,x,hooks) if `full`"
    hooks = hook_outputs(m)
    ch_in = in_channels(m)
    x = next(m.parameters()).new(1,ch_in,*size)
    x = m.eval()(x)
    res = [o.stored.shape for o in hooks]
    if not full: hooks.remove()
    return (res,x,hooks) if full else res

How did you check? Your editor should be able to tell you what calls each function. Otherwise, use ack or grep, e.g. something like grep -rn "model_sizes" fastai/ from the repo root.

So yup, it is used in unet. The code shows it’s just calling hook_outputs, which according to the docs can take as many modules as you want.

Weird, I may need to check my editor. It does not show the reference in unet, but when I go to unet.py I do see that model_sizes is used…

It only works when I wrap an nn.Module inside a Sequential object, but what I want is the output of every layer inside a block. I am thinking about how to hook at the m.children() level; taking BasicBlock as an example, there should be hooks on all 7 layers inside it.


I assume you can just pass m.children() to it, since hook_outputs can take as many modules as you want.
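
Something like this, then (a sketch assuming fastai v1’s hook_outputs; the block index is just illustrative):

from fastai.callbacks.hooks import hook_outputs

block = learn.model[0][4][0]  # e.g. the first BasicBlock -- pick whichever block you want
hooks = hook_outputs(list(block.children()))  # one hook per inner layer
x = next(block.parameters()).new(1, 64, 56, 56)  # dummy input matching the block
block.eval()(x)
print([h.stored.shape for h in hooks])
hooks.remove()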

Thanks Jeremy, that’s very helpful. I think I have a rough idea of how I need to implement it.

Edited:

PR: https://github.com/fastai/fastai/pull/1198

It turns out it is extremely easy to implement, but I spent a lot of time reading the code. I was so excited when I finally got it right. I hadn’t realized I could use flatten_model to simply hook at the lowest Module level. The hooks are easy to use once you understand how they work. I struggled for 2 days, but really I only wrote a few lines of code for the functionality needed; the majority of the code is simply formatting the summary.

Reading the code for how Hook and Hooks are implemented is not strictly necessary; all I needed to know was how to use them. I hope I can come up with a simpler example later to help others understand their usage.
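
For instance, something like this (a rough sketch against fastai v1; import locations may vary between versions):

from fastai.callbacks.hooks import hook_outputs
from fastai.torch_core import flatten_model

leaves = flatten_model(learn.model)  # every lowest-level module in the model
hooks = hook_outputs(leaves)  # a Hooks object: one hook per leaf
x = next(learn.model.parameters()).new(1, 3, 224, 224)  # dummy input, as model_sizes does
learn.model.eval()(x)
shapes = [h.stored.shape for h in hooks]  # output shape of every layer
hooks.remove()  # clean up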

Nice work @nok! It’d be nice if we could have another column showing whether each layer is currently set to trainable or not. :slight_smile:

Thanks! Yeah, now that it is merged I plan to add some more information to the summary.

I’ve started a separate thread, following the new dev projects index.

I found an existing library that can plot a simple graph with graphviz like this; would this be useful?

There is also an interesting NN viewer:

And here is a demo link.
