Being used to Keras's model.summary, I couldn't help looking for something similar in PyTorch. This is the library we'll use: https://github.com/sksq96/pytorch-summary (installable with pip install torchsummary). Here is the only command you need to run after creating a Learner object to get the model's summary:
from torchsummary import summary

# Input dimensions the model expects: 3-channel (RGB), 224x224 images
channels = 3
H = 224
W = 224

summary(learn.model, input_size=(channels, H, W))
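The same call works on any plain nn.Module, not just a fastai Learner's model. A minimal self-contained sketch, assuming a torchvision ResNet-34 (the backbone behind the summary below) and running on CPU; torchsummary's device argument defaults to "cuda", so we override it here:

from torchvision.models import resnet34
from torchsummary import summary

# Backbone only, without fastai's custom head, so the tail of the
# summary will differ from the one shown below
model = resnet34(pretrained=True)
summary(model, input_size=(3, 224, 224), device="cpu")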
which gives the following:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 64, 56, 56] 36,864
BatchNorm2d-20 [-1, 64, 56, 56] 128
ReLU-21 [-1, 64, 56, 56] 0
Conv2d-22 [-1, 64, 56, 56] 36,864
BatchNorm2d-23 [-1, 64, 56, 56] 128
ReLU-24 [-1, 64, 56, 56] 0
BasicBlock-25 [-1, 64, 56, 56] 0
Conv2d-26 [-1, 128, 28, 28] 73,728
BatchNorm2d-27 [-1, 128, 28, 28] 256
ReLU-28 [-1, 128, 28, 28] 0
Conv2d-29 [-1, 128, 28, 28] 147,456
BatchNorm2d-30 [-1, 128, 28, 28] 256
Conv2d-31 [-1, 128, 28, 28] 8,192
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 128, 28, 28] 147,456
BatchNorm2d-36 [-1, 128, 28, 28] 256
ReLU-37 [-1, 128, 28, 28] 0
Conv2d-38 [-1, 128, 28, 28] 147,456
BatchNorm2d-39 [-1, 128, 28, 28] 256
ReLU-40 [-1, 128, 28, 28] 0
BasicBlock-41 [-1, 128, 28, 28] 0
Conv2d-42 [-1, 128, 28, 28] 147,456
BatchNorm2d-43 [-1, 128, 28, 28] 256
ReLU-44 [-1, 128, 28, 28] 0
Conv2d-45 [-1, 128, 28, 28] 147,456
BatchNorm2d-46 [-1, 128, 28, 28] 256
ReLU-47 [-1, 128, 28, 28] 0
BasicBlock-48 [-1, 128, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 147,456
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
BasicBlock-55 [-1, 128, 28, 28] 0
Conv2d-56 [-1, 256, 14, 14] 294,912
BatchNorm2d-57 [-1, 256, 14, 14] 512
ReLU-58 [-1, 256, 14, 14] 0
Conv2d-59 [-1, 256, 14, 14] 589,824
BatchNorm2d-60 [-1, 256, 14, 14] 512
Conv2d-61 [-1, 256, 14, 14] 32,768
BatchNorm2d-62 [-1, 256, 14, 14] 512
ReLU-63 [-1, 256, 14, 14] 0
BasicBlock-64 [-1, 256, 14, 14] 0
Conv2d-65 [-1, 256, 14, 14] 589,824
BatchNorm2d-66 [-1, 256, 14, 14] 512
ReLU-67 [-1, 256, 14, 14] 0
Conv2d-68 [-1, 256, 14, 14] 589,824
BatchNorm2d-69 [-1, 256, 14, 14] 512
ReLU-70 [-1, 256, 14, 14] 0
BasicBlock-71 [-1, 256, 14, 14] 0
Conv2d-72 [-1, 256, 14, 14] 589,824
BatchNorm2d-73 [-1, 256, 14, 14] 512
ReLU-74 [-1, 256, 14, 14] 0
Conv2d-75 [-1, 256, 14, 14] 589,824
BatchNorm2d-76 [-1, 256, 14, 14] 512
ReLU-77 [-1, 256, 14, 14] 0
BasicBlock-78 [-1, 256, 14, 14] 0
Conv2d-79 [-1, 256, 14, 14] 589,824
BatchNorm2d-80 [-1, 256, 14, 14] 512
ReLU-81 [-1, 256, 14, 14] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
BasicBlock-85 [-1, 256, 14, 14] 0
Conv2d-86 [-1, 256, 14, 14] 589,824
BatchNorm2d-87 [-1, 256, 14, 14] 512
ReLU-88 [-1, 256, 14, 14] 0
Conv2d-89 [-1, 256, 14, 14] 589,824
BatchNorm2d-90 [-1, 256, 14, 14] 512
ReLU-91 [-1, 256, 14, 14] 0
BasicBlock-92 [-1, 256, 14, 14] 0
Conv2d-93 [-1, 256, 14, 14] 589,824
BatchNorm2d-94 [-1, 256, 14, 14] 512
ReLU-95 [-1, 256, 14, 14] 0
Conv2d-96 [-1, 256, 14, 14] 589,824
BatchNorm2d-97 [-1, 256, 14, 14] 512
ReLU-98 [-1, 256, 14, 14] 0
BasicBlock-99 [-1, 256, 14, 14] 0
Conv2d-100 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-101 [-1, 512, 7, 7] 1,024
ReLU-102 [-1, 512, 7, 7] 0
Conv2d-103 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-104 [-1, 512, 7, 7] 1,024
Conv2d-105 [-1, 512, 7, 7] 131,072
BatchNorm2d-106 [-1, 512, 7, 7] 1,024
ReLU-107 [-1, 512, 7, 7] 0
BasicBlock-108 [-1, 512, 7, 7] 0
Conv2d-109 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-110 [-1, 512, 7, 7] 1,024
ReLU-111 [-1, 512, 7, 7] 0
Conv2d-112 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-113 [-1, 512, 7, 7] 1,024
ReLU-114 [-1, 512, 7, 7] 0
BasicBlock-115 [-1, 512, 7, 7] 0
Conv2d-116 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-117 [-1, 512, 7, 7] 1,024
ReLU-118 [-1, 512, 7, 7] 0
Conv2d-119 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-120 [-1, 512, 7, 7] 1,024
ReLU-121 [-1, 512, 7, 7] 0
BasicBlock-122 [-1, 512, 7, 7] 0
AdaptiveMaxPool2d-123 [-1, 512, 1, 1] 0
AdaptiveAvgPool2d-124 [-1, 512, 1, 1] 0
AdaptiveConcatPool2d-125 [-1, 1024, 1, 1] 0
Lambda-126 [-1, 1024] 0
BatchNorm1d-127 [-1, 1024] 2,048
Dropout-128 [-1, 1024] 0
Linear-129 [-1, 512] 524,800
ReLU-130 [-1, 512] 0
BatchNorm1d-131 [-1, 512] 1,024
Dropout-132 [-1, 512] 0
Linear-133 [-1, 3] 1,539
================================================================
Total params: 21,814,083
Trainable params: 546,435
Non-trainable params: 21,267,648
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 96.33
Params size (MB): 83.21
Estimated Total Size (MB): 180.12
----------------------------------------------------------------
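The footer numbers are easy to sanity-check by hand: torchsummary assumes float32 (4 bytes per value), so the input and parameter sizes follow directly from the shapes, and the forward/backward figure doubles the summed activation sizes to account for gradients. A quick check of the first and third figures:

# Sanity-check torchsummary's size estimates, assuming float32 (4 bytes/value)
input_mb = 3 * 224 * 224 * 4 / 1024 ** 2      # 0.57, matches "Input size (MB)"
params_mb = 21_814_083 * 4 / 1024 ** 2        # 83.21, matches "Params size (MB)"
print(f"{input_mb:.2f} {params_mb:.2f}")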
We can also verify the total number of params easily:

sum(p.nelement() for p in learn.model.parameters())

which gives:

21814083

the same as above.
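The trainable/non-trainable split can be checked the same way. The large non-trainable count is expected: assuming the Learner was created with fastai's defaults, the pretrained backbone is frozen and only the new head plus the body's BatchNorm layers stay trainable (fastai's train_bn=True default), which accounts for exactly the 546,435 above:

# Count trainable params only; requires_grad is False for the frozen backbone
sum(p.nelement() for p in learn.model.parameters() if p.requires_grad)
# 546435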