Using a pretrained ViT from Timm

Hi guys, I am new to modifying fastai code and am currently trying to use a pretrained Vision Transformer from timm.

However, the transformer doesn't have a pooling layer, so the old code for creating a head doesn't apply. Instead, I use the following code:

import timm
import torch.nn as nn

class ViTBase16(nn.Module):
    def __init__(self, n_classes, pretrained=True):
        super().__init__()
        # Load the pretrained ViT and swap its head for a fresh linear layer
        self.model = timm.create_model("vit_base_patch16_224", pretrained=pretrained)
        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        return self.model(x)
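For reference, a minimal sketch of wiring this wrapper into a fastai Learner; it assumes a DataLoaders object named dls already exists and that dls.c gives the number of classes:

from fastai.vision.all import *

model = ViTBase16(n_classes=dls.c, pretrained=True)
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.fit_one_cycle(3, 1e-4)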

But the model seems untrainable, as if everything were still frozen, even after I unfreeze it.

import numpy as np

# Parameters with requires_grad=False are frozen; the rest are trainable
frozen = filter(lambda p: not p.requires_grad, learn.model.parameters())
frozen_params = sum(np.prod(p.size()) for p in frozen)
trainable = filter(lambda p: p.requires_grad, learn.model.parameters())
trainable_params = sum(np.prod(p.size()) for p in trainable)

frozen_params, trainable_params
(0, 85802501)

Any idea how to make it trainable again? Also, should I make a larger custom head?
Thanks guys.

@kachun1017 did you ever get this to work? I was just searching for the same thing!

I have tried this code and it works for me.

import timm
import torch.nn as nn

model = timm.create_model("vit_base_patch16_224", pretrained=True)

# Freeze every pretrained parameter
for param in model.parameters():
    param.requires_grad = False

# Swap in a fresh, trainable classification head (n_classes = your number of targets)
num_inputs = model.head.in_features
model.head = nn.Linear(num_inputs, n_classes)
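To then train only the new head, one option is to hand just the still-trainable parameters to the optimizer. A minimal sketch in plain PyTorch; train_dl is an assumed DataLoader of (image, label) batches:

import torch
import torch.nn as nn

# Only the new head has requires_grad=True, so only it gets updated
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
criterion = nn.CrossEntropyLoss()

model.train()
for xb, yb in train_dl:  # train_dl: assumed DataLoader of (image, label) batches
    optimizer.zero_grad()
    loss = criterion(model(xb), yb)
    loss.backward()
    optimizer.step()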

Could you please share your notebook?

Thank you for sharing the code! Just one doubt though: using PyTorch with timm, it took 30 minutes to train one epoch, while fastai took only 4 minutes per epoch on the same dataset with the same batch size. I wonder if that's normal. fastai also used 75% less GPU memory than PyTorch.

I have the same problem! Any solution, please?

Wrapper Class Solution

According to the timm source code, VisionTransformer.forward() just runs forward_features() and forward_head() in sequence. Disabling the final linear layer by setting pre_logits=True lets you get the embedding instead of the class logits.

Wrapper Class Example

import timm
import torch.nn as nn

class VisionTransformerLogit(nn.Module):
    def __init__(self, variant='vit_small_r26_s32_384', pretrained=True):
        super().__init__()
        self.model = timm.create_model(variant, pretrained=pretrained)

    def forward(self, x):
        # Run the backbone, then the head with pre_logits=True to skip
        # the final linear layer and return the pooled embedding
        x = self.model.forward_features(x)
        return self.model.forward_head(x, pre_logits=True)

Usage

body = VisionTransformerLogit('vit_small_r26_s32_384', pretrained=True)
emb = body(x)  # x: a batch containing one 3x384x384 image
emb.shape, emb

# outputs
> (torch.Size([1, 384]),
>  TensorImage([[ 5.7661e+00, -4.5164e-01, -1.1806e+00, -2.3579e+00,  1.2144e+00,
>           -1.3464e-01,  3.4350e+00,  1.8872e+00,  1.7765e+00, -4.7565e-01,
>            2.3307e+00, -1.0278e+00, -2.3151e+00,  1.7306e+00,  3.7990e-01,
>            2.7964e+00, -1.0973e+00,  1.5270e+00, -2.1416e-01,  1.0004e+00,
> ...
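
Regarding the earlier question about a larger custom head: since the wrapper is an nn.Module, one way to use it as a fastai-style body is to stack a small head on top with nn.Sequential. A minimal sketch; the 384 input size matches the embedding shape above, while the hidden size, dropout, and n_classes are assumptions:

import torch.nn as nn

n_classes = 10  # assumed number of target classes

model = nn.Sequential(
    VisionTransformerLogit('vit_small_r26_s32_384', pretrained=True),
    nn.Linear(384, 512),   # 384 = embedding size of this variant (see output above)
    nn.ReLU(),
    nn.Dropout(0.25),
    nn.Linear(512, n_classes),
)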