In lesson 7, Jeremy shows how to create a basic CNN with batchnorm using grayscale images. I tried to replicate this with colour images (3 channels), but the model summary shows the final Flatten layer producing 192 outputs instead of 12. What did I do wrong?
xb, yb = data.one_batch()
xb.shape, yb.shape
(torch.Size([32, 3, 128, 128]), torch.Size([32]))
def conv(ni, nf): return nn.Conv2d(ni, nf, 3, stride=2, padding=1)

# Trailing comments give the spatial size after each stride-2 conv (input is 128x128)
model = nn.Sequential(
    conv(3, 12),   # 64
    nn.BatchNorm2d(12),
    nn.ReLU(),
    conv(12, 24),  # 32
    nn.BatchNorm2d(24),
    nn.ReLU(),
    conv(24, 48),  # 16
    nn.BatchNorm2d(48),
    nn.ReLU(),
    conv(48, 24),  # 8
    nn.BatchNorm2d(24),
    nn.ReLU(),
    conv(24, 12),  # 4
    nn.BatchNorm2d(12),
    Flatten()      # remove (4,4) grid
)
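To check where the sizes in the comments come from, here is a minimal sketch in plain Python (no fastai needed). It assumes the standard convolution output-size formula, floor((n + 2*padding - kernel) / stride) + 1; the helper names `conv_out` and `trace_sizes` are made up for illustration, and the 28x28 case is only there as a smaller input to compare against.

```python
def conv_out(n, k=3, s=2, p=1):
    """Spatial size after one conv with kernel k, stride s, padding p."""
    return (n + 2 * p - k) // s + 1

def trace_sizes(n, layers=5):
    """Apply conv_out repeatedly and record the size after each layer."""
    sizes = []
    for _ in range(layers):
        n = conv_out(n)
        sizes.append(n)
    return sizes

sizes_128 = trace_sizes(128)  # 128x128 input, as in the batch shape above
sizes_28 = trace_sizes(28)    # a smaller 28x28 input, for comparison

# Flatten keeps the batch dimension and merges channels * height * width.
flat_128 = 12 * sizes_128[-1] ** 2
flat_28 = 12 * sizes_28[-1] ** 2

print(sizes_128, flat_128)
print(sizes_28, flat_28)
```

If this arithmetic is right, the 192 looks like 12 channels times a remaining 4x4 grid, which depends on the input height and width rather than on the number of input channels.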
Sequential
======================================================================
Layer (type)         Output Shape         Param #    Trainable
======================================================================
Conv2d               [12, 64, 64]         336        True
______________________________________________________________________
BatchNorm2d          [12, 64, 64]         24         True
______________________________________________________________________
ReLU                 [12, 64, 64]         0          False
______________________________________________________________________
Conv2d               [24, 32, 32]         2,616      True
______________________________________________________________________
BatchNorm2d          [24, 32, 32]         48         True
______________________________________________________________________
ReLU                 [24, 32, 32]         0          False
______________________________________________________________________
Conv2d               [48, 16, 16]         10,416     True
______________________________________________________________________
BatchNorm2d          [48, 16, 16]         96         True
______________________________________________________________________
ReLU                 [48, 16, 16]         0          False
______________________________________________________________________
Conv2d               [24, 8, 8]           10,392     True
______________________________________________________________________
BatchNorm2d          [24, 8, 8]           48         True
______________________________________________________________________
ReLU                 [24, 8, 8]           0          False
______________________________________________________________________
Conv2d               [12, 4, 4]           2,604      True
______________________________________________________________________
BatchNorm2d          [12, 4, 4]           24         True
______________________________________________________________________
Flatten              [192]                0          False
______________________________________________________________________
Total params: 26,604
Total trainable params: 26,604
Total non-trainable params: 0
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/
Loss function : CrossEntropyLoss
======================================================================
Callbacks functions applied
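As a cross-check on the Param # column in the summary above, here is a rough sketch of the parameter counts in plain Python. It assumes a 3x3 conv with bias has ni * nf * 9 + nf parameters and a BatchNorm2d layer has 2 * nf (one weight and one bias per channel); the helper names are hypothetical.

```python
def conv_params(ni, nf, k=3):
    """Weights (ni * nf * k * k) plus one bias per output channel."""
    return ni * nf * k * k + nf

def bn_params(nf):
    """One learnable weight and one learnable bias per channel."""
    return 2 * nf

# Channel counts through the model: input, then each conv's output.
channels = [3, 12, 24, 48, 24, 12]

per_conv = [conv_params(ni, nf) for ni, nf in zip(channels, channels[1:])]
per_bn = [bn_params(nf) for nf in channels[1:]]
total = sum(per_conv) + sum(per_bn)

print(per_conv)  # compare with the Conv2d rows
print(per_bn)    # compare with the BatchNorm2d rows
print(total)     # compare with "Total params"

# With a single input channel, only the first conv's count would differ.
print(conv_params(1, 12), conv_params(3, 12))
```

Under these assumptions the 3 input channels only show up in the first Conv2d row (336 parameters instead of the 120 a single-channel input would give), so the channel handling itself appears to be working as intended.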