Unexpected shape of last-layer activations: Imagenette, VGG16

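When running the following on Imagenette:

from fastai.vision.all import *

path = untar_data(URLs.IMAGENETTE)

def get_dls(bs, size):
    # get_items, get_y, item_tfms and Normalize were cut off in the post;
    # filled in here following the fastbook Imagenette example (assumed)
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       get_items=get_image_files,
                       get_y=parent_label,
                       item_tfms=Resize(460),
                       batch_tfms=[*aug_transforms(size=size, min_scale=0.75),
                                   Normalize.from_stats(*imagenet_stats)])
    return dblock.dataloaders(path, bs=bs)

dls = get_dls(128, 128)
learn = Learner(dls, vgg16_bn(), loss_func=CrossEntropyLossFlat())

learn.fit_one_cycle(1, 3e-3)

d = learn.model.state_dict()

# print every parameter/buffer name together with its shape
for k, v in d.items():
    print(k, v.shape)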

I get

features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.1.num_batches_tracked torch.Size([])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.4.num_batches_tracked torch.Size([])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.8.num_batches_tracked torch.Size([])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.11.num_batches_tracked torch.Size([])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.15.num_batches_tracked torch.Size([])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.18.num_batches_tracked torch.Size([])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.21.num_batches_tracked torch.Size([])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.25.num_batches_tracked torch.Size([])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.28.num_batches_tracked torch.Size([])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.31.num_batches_tracked torch.Size([])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.35.num_batches_tracked torch.Size([])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.38.num_batches_tracked torch.Size([])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
features.41.num_batches_tracked torch.Size([])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])

The output surprised me.
Why would I have 1000 activations in the last layer instead of 10?

I understand this is coming from ImageNet, and I know cnn_learner cuts off the head of the network and replaces it with fully-connected layers whose final output has as many activations as dls.c.
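The head swap that cnn_learner performs can be sketched roughly in plain PyTorch: keep the convolutional body (for VGG, assumed here to be the `features` block) and attach a fresh head whose last layer has one output per class. The helpers `make_simple_head` and `replace_head` are hypothetical and deliberately simplified; fastai's own `create_head` also uses concatenated pooling, batch norm, dropout and an extra linear layer, so this only illustrates the idea, not what cnn_learner actually builds.

```python
import torch
from torch import nn


def make_simple_head(n_in: int, n_out: int) -> nn.Sequential:
    """Hypothetical, simplified head: pool the feature map, flatten, one Linear layer."""
    return nn.Sequential(
        nn.AdaptiveAvgPool2d(1),  # [bs, n_in, h, w] -> [bs, n_in, 1, 1]
        nn.Flatten(),             # -> [bs, n_in]
        nn.Linear(n_in, n_out),   # -> [bs, n_out], one activation per class
    )


def replace_head(body: nn.Module, n_body_out: int, n_out: int) -> nn.Sequential:
    """Keep the convolutional body and attach a new head with n_out outputs."""
    return nn.Sequential(body, make_simple_head(n_body_out, n_out))


# Usage with VGG16-BN (assumes torchvision is installed; its last conv has 512 channels):
#   from torchvision.models import vgg16_bn
#   model = replace_head(vgg16_bn().features, 512, dls.c)  # dls.c == 10 for Imagenette

# Quick shape check on a tiny stand-in body (8 channels instead of VGG's 512)
demo = replace_head(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()), 8, 10)
print(demo(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
```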

Still, assuming Learner does not do that, how can it even train properly?
Or maybe what is happening is that the 10 Imagenette classes are just mapped onto the original 1000 ImageNet ones, i.e. the 990 left-out classes are kept in the output but in the end are meaningless?

Learner does not chop off the head like you say; if you want similar behavior, you have to chop off the head yourself. You can do this if you are familiar with PyTorch. Learner is like driving a manual, while cnn_learner is like driving an automatic.
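One way to do that by hand, assuming torchvision's VGG layout where the last layer is `model.classifier[6]`, a `Linear(4096, 1000)` (matching the `classifier.6.*` keys in the output above). The helper `set_num_outputs` is hypothetical. Building the model as `vgg16_bn(num_classes=10)` should have the same effect when no pretrained weights are loaded, since torchvision forwards extra keyword arguments to the `VGG` constructor.

```python
from torch import nn


def set_num_outputs(model: nn.Module, n_out: int) -> nn.Module:
    """Replace the last Linear layer of a torchvision-style VGG classifier in place."""
    old = model.classifier[-1]  # for VGG16-BN this is classifier[6], Linear(4096, 1000)
    if not isinstance(old, nn.Linear):
        raise TypeError(f"expected nn.Linear as last classifier layer, got {type(old).__name__}")
    # New layer keeps the input size but is freshly (randomly) initialized
    model.classifier[-1] = nn.Linear(old.in_features, n_out)
    return model


# Usage (assumes fastai/torchvision are installed and dls is defined as in the question):
#   model = set_num_outputs(vgg16_bn(), dls.c)
#   learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())
```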

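When your example runs, your dataset only uses labels 0-9 out of the possible 0-999, so the targets only ever point at the first 10 of the 1000 outputs. The 1000 outputs are there because torchvision's VGG constructors default to num_classes=1000, the ImageNet class count; since vgg16_bn() is called without pretrained weights, the head starts from random weights rather than ImageNet ones (pretrained head weights are generally thrown away anyway). For example, in ImageNet's ordering index 1 is a type of fish, but that meaning does not carry over here. So yes, 990 of the outputs are meaningless, and the weights in the last linear layer that produce them are wasted.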
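Where the labels 0-9 come from: fastai's CategoryBlock builds a vocabulary from the labels in your dataset (sorted by default), and each label's position in that vocab, not its ImageNet index, becomes the target. The pure-Python sketch below mimics that mapping; `build_vocab` is a hypothetical helper, not fastai's `CategoryMap`, and the labels are Imagenette's class folder names (WordNet IDs).

```python
def build_vocab(labels):
    """Hypothetical stand-in for a category vocab: sorted unique labels -> 0..n-1."""
    vocab = sorted(set(labels))
    return vocab, {label: i for i, label in enumerate(vocab)}


# Imagenette-style folder names used as class labels (one class appears twice)
labels = ["n03888605", "n01440764", "n03028079", "n01440764", "n02102040",
          "n03394916", "n03000684", "n02979186", "n03417042", "n03425413", "n03445777"]
vocab, o2i = build_vocab(labels)
targets = [o2i[label] for label in labels]
print(len(vocab), min(targets), max(targets))  # 10 0 9
```

However many outputs the model has, these targets can only ever select indices 0-9.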

Taken from Learner: self.loss_func(self.pred, *self.yb)
The loss is calculated with the call above. self.yb is a tuple holding the target tensor, so if your batch size is 2, that tensor contains 2 labels, say 2 and 4, and what happens is effectively: self.loss_func(self.pred, tensor([2, 4]))
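A small sketch of that call in plain PyTorch, using `F.cross_entropy` in place of fastai's CrossEntropyLossFlat (which gives the same result for a 2D prediction and 1D target like this). The random `pred` stands in for the model's output, and `yb` is a one-element tuple, mirroring how fastai's Learner splits a batch into inputs and targets.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for self.pred: raw outputs (logits) for 2 images and 1000 classes
pred = torch.randn(2, 1000)
# Stand-in for self.yb: a tuple with a single target tensor of class indices
yb = (torch.tensor([2, 4]),)

# *yb unpacks the tuple, so this is F.cross_entropy(pred, torch.tensor([2, 4]))
loss = F.cross_entropy(pred, *yb)
print(loss)
```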

Quote from pytorch docs for CrossEntropyLoss:

This criterion expects a class index in the range [0, C-1] as the target for each value of a 1D tensor of size minibatch; if ignore_index is specified, this criterion also accepts this class index (this index may not necessarily be in the class range).

So CrossEntropyLoss (and CrossEntropyLossFlat, which wraps it) expects a single class index per image as the target. It does not know how many classes are in your dataset; instead, it takes the number of classes C from the class dimension of your model's output.
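A quick way to see that C comes from the model output rather than the dataset: the same targets work against 10 or 1000 outputs, and a target outside [0, C-1] is rejected. This is a plain PyTorch sketch with random logits; on CPU the out-of-range case raises an IndexError, and the code also catches RuntimeError in case other builds report it differently.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
targets = torch.tensor([2, 4])  # one class index per image

# The same targets are valid for 10 or 1000 outputs: C is read from pred.shape[1]
loss_10 = F.cross_entropy(torch.randn(2, 10), targets)
loss_1000 = F.cross_entropy(torch.randn(2, 1000), targets)

# A target of 10 is out of range when the model only has 10 outputs (C - 1 == 9)
try:
    F.cross_entropy(torch.randn(2, 10), torch.tensor([2, 10]))
    out_of_range_rejected = False
except (IndexError, RuntimeError):
    out_of_range_rejected = True

print(loss_10.item(), loss_1000.item(), out_of_range_rejected)
```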

Going back to the call: self.loss_func(self.pred, [2, 4])
self.pred is a tensor of shape [bs, C], in this case [2, 1000], since 1000 is the number of ImageNet classes the model was built for. The labels [2, 4] mean that position 2 in the first row and position 4 in the second row are the correct classes. The loss is then calculated treating those values as correct and all other values in each row of the prediction as wrong.
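A worked version of that calculation, with random logits standing in for `self.pred`: each row's loss is minus the log-softmax value at the target position, averaged over the batch. The other 998 values in each row only enter through the softmax denominator, which is how training pushes them down. The result is checked against PyTorch's built-in `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(2, 1000)      # [bs, C] logits with bs=2, C=1000
targets = torch.tensor([2, 4])   # correct class for each row

# log_softmax turns each row into log-probabilities over all 1000 classes
log_probs = F.log_softmax(pred, dim=1)
# Pick the entry at the target position of each row: pred[0, 2] and pred[1, 4]
picked = log_probs[torch.arange(len(targets)), targets]
manual_loss = -picked.mean()

builtin_loss = F.cross_entropy(pred, targets)
print(manual_loss.item(), builtin_loss.item())
```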


Makes sense. Thanks a lot!