Using a pretrained ViT from Timm

Hi guys, I am new to modifying fastai code and am currently trying to use the pretrained Vision Transformer (ViT) from timm.

However, the transformer doesn’t have a pooling layer, so the usual fastai code for creating the head doesn’t apply. Instead, I use the following code:

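import timm
import torch.nn as nn

class ViTBase16(nn.Module):
    def __init__(self, n_classes, pretrained=True):
        super().__init__()
        # Load ViT-B/16; pass `pretrained` through instead of hardcoding True
        self.model = timm.create_model("vit_base_patch16_224", pretrained=pretrained)
        # Replace the ImageNet classifier with a linear layer for n_classes outputs
        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        return self.model(x)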

But the model is untrainable and all frozen even if I unfreeze it.

# Sum of sizes of parameters with requires_grad=False
unfrozen_params = sum(p.numel() for p in learn.model.parameters() if not p.requires_grad)
# Sum of sizes of parameters with requires_grad=True
frozen_params = sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

unfrozen_params, frozen_params
(0, 85802501)
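A minimal plain-PyTorch sketch for reading such counts without ambiguity: the hypothetical helpers `count_params` and `set_trainable` below label each count explicitly and toggle `requires_grad` on every parameter of a module. A tiny `nn.Sequential` stands in for `learn.model` so the example runs without downloading weights.

```python
import torch.nn as nn


def count_params(model: nn.Module) -> dict:
    """Hypothetical helper: parameter counts split by requires_grad."""
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"frozen": frozen, "trainable": trainable}


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Hypothetical helper: freeze (False) or unfreeze (True) all params in place."""
    for p in module.parameters():
        p.requires_grad_(trainable)


# Stand-in model: Linear(4, 3) has 15 parameters, Linear(3, 2) has 8
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

set_trainable(model, False)
print(count_params(model))  # {'frozen': 23, 'trainable': 0}

set_trainable(model, True)
print(count_params(model))  # {'frozen': 0, 'trainable': 23}
```

Note that the filter `not p.requires_grad` selects the frozen parameters, so a sum built on it is the frozen count.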

Any idea how to make it trainable again? Also, should I make a larger custom head?
Thanks guys.
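On the head question, one option is a small MLP in place of the single `nn.Linear`. The sketch below uses a hypothetical `make_head` helper (not fastai's `create_head`); the hidden size (512), dropout (0.1) and 5 classes are arbitrary assumptions, and 768 is the feature width that `vit_base_patch16_224` feeds into its head. Whether it beats a single linear layer depends on the dataset.

```python
import torch
import torch.nn as nn


def make_head(in_features: int, n_classes: int, hidden: int = 512, p: float = 0.1) -> nn.Module:
    """Hypothetical two-layer classification head for pooled ViT features."""
    return nn.Sequential(
        nn.Dropout(p),
        nn.Linear(in_features, hidden),
        nn.GELU(),
        nn.Dropout(p),
        nn.Linear(hidden, n_classes),
    )


# 768 = feature width of vit_base_patch16_224; 5 classes is an arbitrary example
head = make_head(768, 5)
features = torch.randn(2, 768)  # fake batch of 2 pooled feature vectors
print(head(features).shape)  # torch.Size([2, 5])
```

In the `ViTBase16` class above, it would replace the linear head assignment: `self.model.head = make_head(self.model.head.in_features, n_classes)`.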


@kachun1017 did you ever get this to work? I was just searching for the same thing!


I have tried this code and it works for me.

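import timm
import torch.nn as nn

n_classes = 5  # placeholder: set to the number of classes in your dataset
model = timm.create_model("vit_base_patch16_224", pretrained=True)
# Freeze all pretrained weights
for param in model.parameters():
    param.requires_grad = False
# Replace the classifier; a freshly created nn.Linear is trainable by default
num_inputs = model.head.in_features
model.head = nn.Linear(num_inputs, n_classes)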
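The order in that snippet matters: the new head is created after freezing, and new layers start with `requires_grad=True`, so only the head ends up trainable. A minimal sketch of checking that and building an optimizer over the trainable parameters only. `TinyBackbone` is a hypothetical stand-in for the timm model (it just exposes a `.head` attribute) so the example runs offline; AdamW and the learning rate are arbitrary choices.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a timm ViT: a feature layer plus a `.head` classifier."""

    def __init__(self, n_features: int = 8, n_out: int = 1000):
        super().__init__()
        self.features = nn.Linear(16, n_features)
        self.head = nn.Linear(n_features, n_out)

    def forward(self, x):
        return self.head(self.features(x))


n_classes = 5
model = TinyBackbone()

# Freeze everything that came with the pretrained model
for param in model.parameters():
    param.requires_grad = False

# Swap in a fresh head; newly created layers default to requires_grad=True
model.head = nn.Linear(model.head.in_features, n_classes)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['head.weight', 'head.bias']

# Give the optimizer only the trainable parameters
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

# One step on random data: the head should change, the frozen layer should not
x, y = torch.randn(4, 16), torch.randint(0, n_classes, (4,))
before = model.features.weight.detach().clone()
head_before = model.head.weight.detach().clone()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
assert torch.equal(before, model.features.weight)
assert not torch.equal(head_before, model.head.weight)
```

With the real model, `TinyBackbone()` would be replaced by `timm.create_model("vit_base_patch16_224", pretrained=True)`.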


Could you please share your notebook?

Thank you for sharing the code! One question, though: using plain PyTorch with timm, one epoch took 30 minutes to train, while fastai took only 4 minutes per epoch on the same dataset with the same batch size. Is that normal? fastai also used 75% less GPU memory than plain PyTorch.

I have the same problem! Any solution, please?