The following error is raised after running:
import torch
from torchvision import models
vgg = models.vgg.vgg16()
vgg.load_state_dict(torch.load(file))  # file: path to the downloaded VGG16 weights
KeyError: 'unexpected key "classifier.1.weight" in state_dict'
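To see every key that disagrees (not just the first one the error reports), one option is to compare the key sets of the model and the checkpoint. The helper below is a hypothetical sketch, not part of PyTorch; it only assumes that both state dicts expose their keys as iterables of strings, so it is shown with plain lists and the PyTorch call is left as a comment.

```python
def diff_state_dict_keys(expected_keys, loaded_keys):
    """Hypothetical helper: return (missing, unexpected) key lists, sorted."""
    expected, loaded = set(expected_keys), set(loaded_keys)
    missing = sorted(expected - loaded)     # the model wants these, the file lacks them
    unexpected = sorted(loaded - expected)  # the file has these, the model does not
    return missing, unexpected


# Usage with PyTorch (requires torch/torchvision; `file` is the checkpoint path):
# missing, unexpected = diff_state_dict_keys(vgg.state_dict().keys(),
#                                            torch.load(file).keys())

missing, unexpected = diff_state_dict_keys(
    ["classifier.0.weight", "classifier.0.bias"],
    ["classifier.1.weight", "classifier.1.bias"],
)
print(missing)     # ['classifier.0.bias', 'classifier.0.weight']
print(unexpected)  # ['classifier.1.bias', 'classifier.1.weight']
```

For this checkpoint the lists should pair up index by index (1 with 0, 4 with 3), which is what the renaming below relies on.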
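I found a solution at https://github.com/jcjohnson/pytorch-vgg/issues/3. The checkpoint stores the first two fully connected layers of the classifier under indices 1 and 4, whereas torchvision's vgg16() expects them at indices 0 and 3, so those keys have to be renamed before loading. Add a cell with the following code: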
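sd = torch.load(file)
# First Linear layer: classifier.1.* in the checkpoint -> classifier.0.* in torchvision
sd['classifier.0.weight'] = sd['classifier.1.weight']
sd['classifier.0.bias'] = sd['classifier.1.bias']
del sd['classifier.1.weight']
del sd['classifier.1.bias']
# Second Linear layer: classifier.4.* in the checkpoint -> classifier.3.* in torchvision
sd['classifier.3.weight'] = sd['classifier.4.weight']
sd['classifier.3.bias'] = sd['classifier.4.bias']
del sd['classifier.4.weight']
del sd['classifier.4.bias']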
Then load the renamed state dict:
vgg = models.vgg.vgg16()
vgg.load_state_dict(sd)
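The same renaming can be written once as a prefix map instead of eight hand-written lines. This is a sketch under assumptions: `rename_prefixes` and `VGG_CLASSIFIER_MAP` are hypothetical names, and the map simply encodes the 1 -> 0 and 4 -> 3 moves from the fix above. It builds a new OrderedDict rather than mutating the loaded one, and it does not copy any extra attributes the original dict may carry.

```python
from collections import OrderedDict


def rename_prefixes(state_dict, prefix_map):
    """Hypothetical helper: return a new OrderedDict with key prefixes rewritten.

    Keys matching no prefix are copied unchanged; values (tensors) are not copied.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed


# Assumed mapping for this checkpoint. The trailing dot keeps "classifier.1."
# from also matching a key such as "classifier.10.weight".
VGG_CLASSIFIER_MAP = {
    "classifier.1.": "classifier.0.",
    "classifier.4.": "classifier.3.",
}

# Usage with PyTorch (requires torch/torchvision; `file` is the checkpoint path):
# vgg = models.vgg.vgg16()
# vgg.load_state_dict(rename_prefixes(torch.load(file), VGG_CLASSIFIER_MAP))
```

Keys such as `classifier.6.*` or `features.*` pass through untouched, so only the two mismatched layers change.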