I have tried the following code, and it works for me:
# Fine-tune only the classification head of the pretrained timm model
# "vit_base_patch16_224": freeze the backbone, then swap in a new linear head.
# NOTE(review): assumes `timm`, `nn` (torch.nn) and `n_classes` are imported /
# defined earlier in the surrounding script; they are not visible here.
model = timm.create_model("vit_base_patch16_224", pretrained=True)

# Freeze every pretrained parameter so the optimizer cannot update the backbone.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head with a fresh linear layer sized for our classes.
# The layer is created *after* the freeze loop, so its parameters keep PyTorch's
# default requires_grad=True and become the only trainable parameters. The old
# head's parameters leave the model when the attribute is reassigned.
outputs_attrs = n_classes
num_inputs = model.head.in_features  # feature width the original head consumed
last_layer = nn.Linear(num_inputs, outputs_attrs)
model.head = last_layer