Train / Fine-Tune VGG-16 on Imagenette

Hey everyone!

I will be using VGG-16 pre-trained on ImageNet from the PyTorch model zoo and fine-tuning it on Imagenette. I will later use this model for some pruning and explainability use cases in my master's thesis.

Any recommendations on how to approach this transfer-learning task?
I thought of only training the classifier layers (the fully-connected layers at the end) because
(1) this should keep the network from overfitting to Imagenette, and
(2) since Imagenette is a subset of ImageNet, I expect knowledge to transfer very well.

What do you guys think? Any other recommendations?
Happy to hear some thoughts!

As you’ve mentioned, Imagenette is a small subset of ImageNet, so you would likely be fine with merely training the head. For slightly better performance, you could train the entire model for a few epochs.
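A minimal sketch of both options, assuming torchvision's VGG-16 layout (`model.features` is the convolutional backbone, `model.classifier` the fully-connected head). The helper `make_optimizer` is hypothetical, and the 10x smaller learning rate for the backbone during full fine-tuning is an illustrative choice, not something tested in this thread:

```python
import torch
from torch import nn


def make_optimizer(model: nn.Module, train_backbone: bool, lr: float = 0.01) -> torch.optim.Optimizer:
    """SGD for head-only (train_backbone=False) or full fine-tuning (True).

    Assumes a torchvision-style VGG with `.features` and `.classifier`.
    """
    # Freeze or unfreeze the convolutional backbone.
    for p in model.features.parameters():
        p.requires_grad = train_backbone
    # The fully-connected head is always trained.
    for p in model.classifier.parameters():
        p.requires_grad = True

    param_groups = [{"params": model.classifier.parameters(), "lr": lr}]
    if train_backbone:
        # Illustrative choice: update the pre-trained conv weights more gently.
        param_groups.append({"params": model.features.parameters(), "lr": lr / 10})
    return torch.optim.SGD(param_groups, lr=lr, momentum=0.9)
```

For example, train the head first with `make_optimizer(model, train_backbone=False)`, then rebuild the optimizer with `train_backbone=True` for a few more epochs.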

Be wary though: I’m not sure how Imagenette is split into training and validation (i.e. does Imagenette’s validation set overlap with ImageNet’s training set?), but I’m 99% certain Jeremy has already thought of that and divided Imagenette accordingly.
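One quick, heuristic way to look into the overlap question, assuming Imagenette kept the original ImageNet file names (ImageNet training images are named like `n01440764_10026.JPEG`, i.e. synset ID plus index, while the official validation images are named like `ILSVRC2012_val_00000001.JPEG`). The function `count_filename_origins` is hypothetical; it only counts name patterns and says nothing about duplicated pixels:

```python
import re
from collections import Counter
from pathlib import Path

# ImageNet naming conventions (assumption: Imagenette keeps them unchanged).
TRAIN_STYLE = re.compile(r"^n\d{8}_\d+\.jpe?g$", re.IGNORECASE)
VAL_STYLE = re.compile(r"^ILSVRC2012_val_\d{8}\.jpe?g$", re.IGNORECASE)


def count_filename_origins(folder: str) -> Counter:
    """Count files under `folder` by which ImageNet split their name suggests."""
    counts = Counter()
    for path in Path(folder).rglob("*"):
        if not path.is_file():
            continue
        if TRAIN_STYLE.match(path.name):
            counts["imagenet_train_style"] += 1
        elif VAL_STYLE.match(path.name):
            counts["imagenet_val_style"] += 1
        else:
            counts["other"] += 1
    return counts


# Hypothetical usage: print(count_filename_origins("data/imagenette2/val"))
```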

Thanks for your input!
Good catch, I hadn't thought about the train/valid split. I will keep an eye on that.

so you would likely be fine with merely training the head.

What do you mean by ‘head’? My approach would have been to just train the last output layer of VGG. Would you consider training all FC layers?

Head ≈ fully-connected layers :slight_smile:

Thanks for clarifying!
I will definitely consider training all FC layers as well! :slight_smile:

No problem! Feel free to ask any questions I could help with.


@BobMcDear So, fine-tuning VGG-16 on the output layer only, for 1 epoch, resulted in an accuracy of 97% (training) and 98.8% (validation).

These are better results than expected, especially since it was only trained for 1 epoch.
What do you think?

Huh, that’s weird. I ran a quick test and got a training accuracy of 97% and validation accuracy of 84%. Our implementation details likely differ, but it shouldn’t cause that big a gap between our scores.

Would you please post your code, or at least the part dealing with the data? I’m guessing this is a case of data leakage (maybe you’re validating on the training set, which the pre-trained model was trained on?).

Thanks!

Sure!

Loading Imagenette:

import torch
from torchvision import datasets, transforms


def load_imagenette(batch_size):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet values
    ])

    trainset = datasets.ImageFolder('data/imagenette2/train', transform=transform)
    validset = datasets.ImageFolder('data/imagenette2/val', transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    # no need to shuffle the validation set
    validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=False)

    return trainloader, validloader

I use VGG-16 with pretrained=True and freeze all layers except the output layer, which I replaced with one that has 10 output classes.

Other details:

batch_size = 64
criterion = CrossEntropyLoss
optimizer = SGD with lr = 0.01, momentum = 0.9
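With most layers frozen, one option is to hand the optimizer only the parameters that still require gradients. PyTorch's SGD already skips parameters whose `.grad` is `None`, so this mainly makes the intent explicit. A small stand-in model (arbitrary layer sizes, not VGG-16) keeps the snippet runnable without downloading weights:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the frozen VGG-16: a frozen "backbone" layer and a trainable output layer.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 10))
for p in model[0].parameters():
    p.requires_grad = False

criterion = nn.CrossEntropyLoss()
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

# One training step on random data (batch of 64, as in the thread).
frozen_before = model[0].weight.detach().clone()
head_before = model[2].bias.detach().clone()
x, y = torch.randn(64, 8), torch.randint(0, 10, (64,))
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()

print(torch.equal(model[0].weight, frozen_before))  # frozen layer is untouched
print(torch.equal(model[2].bias, head_before))      # output layer was updated
```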

Training Function:

import copy
import time

import torch


def train_model(model, trainloader, validloader, num_epochs, optimizer, criterion, device):
    starting_time = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_train_acc = 0.0
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch: {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # set model to training mode
                dataloader = trainloader
            else:
                model.eval()    # set model to eval mode
                dataloader = validloader

            running_loss = 0.0
            running_correct = 0

            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)

                # zero out gradients
                optimizer.zero_grad()

                # forward; only track gradients in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outs = model(images)
                    loss = criterion(outs, labels)
                    preds = outs.argmax(dim=1)

                    # backward & optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # CrossEntropyLoss returns the batch mean by default, so weight it by the batch size
                running_loss += loss.item() * images.size(0)
                running_correct += (preds == labels).sum().item()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_correct / len(dataloader.dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # the weights kept are those with the best *training* accuracy
            if phase == 'train' and epoch_acc > best_train_acc:
                best_train_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_val_acc:
                    best_val_acc = epoch_acc

    time_elapsed = time.time() - starting_time
    time_elapsed = "{:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60)
    print('Training complete in {}'.format(time_elapsed))

    model.load_state_dict(best_model_wts)

    return model, best_train_acc, best_val_acc, val_acc_history, time_elapsed

If you need more code, let me know; it's kind of hard to share since I tried to build it modularly so I can reuse parts of it.
Thanks!
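Why the per-sample loss needs each batch loss weighted by its batch size: `nn.CrossEntropyLoss` defaults to `reduction='mean'`, so each `loss.item()` is already a batch average. Summing those directly and dividing by the dataset size gives a value roughly `batch_size` times too small (accuracy is unaffected). A small check on random logits, with arbitrary sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()  # reduction='mean' by default

# 150 fake samples split into batches of 64, 64 and 22 (the last batch is smaller).
logits = torch.randn(150, 10)
labels = torch.randint(0, 10, (150,))
batches = [(logits[i:i + 64], labels[i:i + 64]) for i in range(0, 150, 64)]

sum_of_means = 0.0
weighted_sum = 0.0
for out, y in batches:
    loss = criterion(out, y)
    sum_of_means += loss.item()                  # unweighted accumulation
    weighted_sum += loss.item() * out.size(0)    # undo the per-batch averaging

true_mean = criterion(logits, labels).item()
print(sum_of_means / 150)   # far too small
print(weighted_sum / 150)   # equals true_mean up to float rounding
print(true_mean)
```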

The only issue I can think of is that the Imagenette validation set is actually part of the ImageNet training set.

Your code seems completely fine, and I ran a few other tests and was able to reach/exceed a validation accuracy of 98.8%.

There are thus two possible scenarios:

  1. Imagenette’s validation set overlaps with ImageNet’s training set
  2. Imagenette is just too easy :slightly_smiling_face:

Again, I’m pretty sure the core contributors of Imagenette, Imagewoof, and Image网 have ensured the former isn’t the case (though there is a slight chance it is), and the latter is more probable: a few epochs with no pre-training can easily get you 90% validation accuracy, and your model is pre-trained on millions of images, which include your training set.

My advice: to avoid worrying about the overlap issue altogether, just train and validate your model on another dataset (Oxford Pets, Oxford Flowers, etc.). Since your use case is pruning and interpretability, switching datasets should be no problem.

Let me know if I could help!

Thanks for your help and for testing my code! :slight_smile:

When I have cloud resources available (I have no dedicated GPU, hence only 1 epoch of training), I will also train VGG from scratch on Imagenette, just to see.

However, in the end I don’t really care how well the model performs, as long as I have a decent model that I can further prune and fine-tune iteratively. From here, I will start implementing a few pruning algorithms and assess their effect on a chosen interpretability method.
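As one possible starting point for the prune-and-fine-tune loop, here is a sketch using PyTorch's built-in `torch.nn.utils.prune` with global L1 (magnitude) unstructured pruning. The helpers and the `fine_tune` callback are hypothetical; in practice the callback could run a few epochs of `train_model`:

```python
from torch import nn
from torch.nn.utils import prune


def prunable(model: nn.Module):
    """(module, 'weight') pairs for every Conv2d and Linear layer."""
    return [(m, "weight") for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]


def sparsity(model: nn.Module) -> float:
    """Fraction of exactly-zero entries across all Conv2d/Linear weights."""
    pairs = prunable(model)
    zeros = sum(int((m.weight == 0).sum()) for m, _ in pairs)
    total = sum(m.weight.numel() for m, _ in pairs)
    return zeros / total


def iterative_prune(model: nn.Module, rounds: int, amount_per_round: float, fine_tune) -> nn.Module:
    """Global magnitude pruning repeated `rounds` times, fine-tuning after each round."""
    for r in range(rounds):
        prune.global_unstructured(
            prunable(model),
            pruning_method=prune.L1Unstructured,
            amount=amount_per_round,  # fraction of the *remaining* weights to remove
        )
        fine_tune(model)  # hypothetical callback, e.g. a few epochs of training
        print(f"round {r + 1}: sparsity = {sparsity(model):.3f}")
    # Make pruning permanent: drop the weight_orig/weight_mask reparametrization.
    for module, name in prunable(model):
        prune.remove(module, name)
    return model
```

Note that on VGG-16 the FC layers hold most of the weights, so global pruning will mostly hit them; restricting `prunable` to `nn.Conv2d` is one way to focus on the convolutional layers.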

Thanks again! Very much appreciated.

Always happy to help!

P.S.: Though not a viable long-term option, Colab is a great tool for small experiments, and training on Imagenette should take no more than a minute or so there.