Interpretability of timm models in fastai

Is there a way to visualize the heatmaps/grad-cam/attention-maps of timm models using fastai?


Could you clarify your goals?


Yes, you can visualize Grad-CAM heatmaps or attention maps for timm models in fastai, but it takes a few custom steps, since fastai doesn't ship a built-in Grad-CAM that works out of the box on arbitrary timm architectures.

Steps to implement (a minimal sketch follows the list):

  1. Build your Learner as usual; recent fastai versions can construct timm models directly when you pass the architecture name as a string to vision_learner.
  2. Attach PyTorch hooks to the final convolutional layer to capture its activations (and, for Grad-CAM, its gradients).
  3. Use the CAM or Grad-CAM technique to compute heatmaps from those activations.
  4. Overlay the heatmap on the input image for visualization.
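
Here's a minimal Grad-CAM sketch in plain PyTorch to make those steps concrete. The architecture ('resnet50'), the target layer (model.layer4), and the random input tensor are placeholder assumptions; substitute your own model (e.g., learn.model) and a properly preprocessed image:

```python
import timm
import torch
import torch.nn.functional as F

# Assumed architecture and layer; swap in your own timm model / learn.model.
model = timm.create_model('resnet50', pretrained=True).eval()

activations, gradients = {}, {}

def fwd_hook(module, inp, out):
    activations['feat'] = out.detach()

def bwd_hook(module, grad_in, grad_out):
    gradients['feat'] = grad_out[0].detach()

target_layer = model.layer4                       # last conv stage of a timm ResNet
h_fwd = target_layer.register_forward_hook(fwd_hook)
h_bwd = target_layer.register_full_backward_hook(bwd_hook)

x = torch.randn(1, 3, 224, 224)                   # stand-in for a preprocessed image
logits = model(x)
cls = logits.argmax(dim=1).item()                 # class to explain
model.zero_grad()
logits[0, cls].backward()

# Grad-CAM: weight each channel by its spatially pooled gradient,
# sum over channels, ReLU, then normalize to [0, 1].
weights = gradients['feat'].mean(dim=(2, 3), keepdim=True)   # (1, C, 1, 1)
cam = F.relu((weights * activations['feat']).sum(dim=1))     # (1, h, w)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
cam = F.interpolate(cam[None], size=x.shape[-2:],
                    mode='bilinear', align_corners=False)[0, 0]

h_fwd.remove(); h_bwd.remove()
```

From there, the overlay is just plotting the input image and then cam on top with a colormap and some transparency (e.g., plt.imshow(cam, cmap='jet', alpha=0.5)).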

Helpful starting points are the CAM notebook from fastbook and the walkwithfastai guide on timm integration.

Let me know if you run into trouble adapting this to your model.

Best Regards,
Thomas Batson

The steps mentioned sound interesting.

Hello!
Yes, you can visualize heatmaps (Grad-CAM/CAM) and attention maps for models loaded from the timm library within fastai, but it requires manual intervention since timm models don't always follow the sequential structure fastai expects.

For CNNs (like ResNet or ConvNeXt), first identify the last convolutional layer (e.g., learn.model[0].layer4[-1].conv3) by inspecting the model's summary, then use that specific layer as the hook target for your CAM/Grad-CAM code.

For Vision Transformers (ViTs), which use attention maps instead, you generally need PyTorch hooks to intercept and save the internal attention score tensors from the transformer blocks during the forward pass, then aggregate these scores (e.g., via Attention Rollout) into the final visualization. A sketch of that approach follows.
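
Here's a hedged sketch of the hook-plus-rollout approach for ViTs. It assumes timm's vit_base_patch16_224, whose blocks expose qkv, num_heads, and scale on their attn modules (true for timm's standard ViTs, but worth verifying on your version):

```python
import timm
import torch

# Assumed model; timm's standard ViTs expose qkv, num_heads and scale
# on each block's attn module (verify on your timm version).
model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()

attn_maps = []

def attn_hook(module, inp, out):
    # Recent timm uses fused attention and never materializes the attention
    # matrix, so we recompute it here from the block input.
    x = inp[0]
    B, N, C = x.shape
    qkv = (module.qkv(x)
           .reshape(B, N, 3, module.num_heads, C // module.num_heads)
           .permute(2, 0, 3, 1, 4))
    q, k = qkv[0], qkv[1]
    attn = (q @ k.transpose(-2, -1)) * module.scale
    attn_maps.append(attn.softmax(dim=-1).detach())       # (B, heads, N, N)

handles = [blk.attn.register_forward_hook(attn_hook) for blk in model.blocks]

x = torch.randn(1, 3, 224, 224)                           # placeholder image
with torch.no_grad():
    model(x)
for h in handles:
    h.remove()

# Attention Rollout: average heads, add identity for the residual path,
# renormalize rows, then multiply the per-layer matrices together.
rollout = torch.eye(attn_maps[0].shape[-1])
for a in attn_maps:
    a = a.mean(dim=1)[0]                                  # (N, N), heads averaged
    a = a + torch.eye(a.shape[-1])
    a = a / a.sum(dim=-1, keepdim=True)
    rollout = a @ rollout

cls_to_patches = rollout[0, 1:]                           # CLS row, minus CLS itself
heatmap = cls_to_patches.reshape(14, 14)                  # 224 / patch size 16 = 14
```

heatmap can then be upsampled to the input resolution and overlaid on the image exactly like a CNN CAM.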