FlashTorch - feature visualisation toolkit in PyTorch

Hi all,

I’ve been fascinated by feature visualisation techniques.

There was no tool available to easily apply these techniques to networks I’ve built in PyTorch, so I created one - FlashTorch :flashlight:

I wrote a blog post to introduce:

  • Feature visualisation as an area of research
  • How to use FlashTorch to create saliency maps from AlexNet
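For readers new to saliency maps, here is a minimal plain-PyTorch sketch of the vanilla-gradient idea. It takes the absolute gradient of one class score with respect to the input pixels, then keeps the strongest colour channel per pixel. This is not FlashTorch's implementation or API. `saliency_map` and the tiny CNN standing in for AlexNet are hypothetical, chosen so the example runs without downloading pretrained weights.

```python
import torch
import torch.nn as nn


def saliency_map(model: nn.Module, image: torch.Tensor, target_class: int) -> torch.Tensor:
    """Vanilla gradient saliency: |d score[target] / d input|, max over colour channels.

    `image` is a single normalised image of shape (1, C, H, W).
    Returns an (H, W) tensor on the CPU.
    """
    model.eval()
    # Keep the input on the same device as the model's weights.
    device = next(model.parameters()).device
    x = image.detach().to(device).requires_grad_(True)

    model.zero_grad()
    scores = model(x)                   # shape (1, num_classes)
    scores[0, target_class].backward()  # gradient of one class score w.r.t. the input

    # Absolute gradient, keeping the strongest channel for each pixel.
    return x.grad.detach().abs().amax(dim=1)[0].cpu()


# Hypothetical tiny CNN standing in for AlexNet (no pretrained weights needed).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
image = torch.randn(1, 3, 32, 32)
sal = saliency_map(model, image, target_class=3)
print(sal.shape)  # torch.Size([32, 32])
```

Swapping in a real pretrained CNN and a properly normalised image should work the same way, since only the model and the input change.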

Here is the GitHub repo for source code and example notebooks. I’ve also created a Jupyter notebook hosted on Google Colab, so you can try it out without installing anything!

My hope is that FlashTorch will help make your CNN projects more interpretable and explainable.

It is very much a work in progress, and I would love to have your constructive feedback, comments and suggestions.

Thanks so much in advance.


Nice work.



Hi, does it work with FastAI? Can you help me with it?


Thank you so much for FlashTorch, it seems really useful.
@misaogura I get a 404 error when I open the Colab notebook.
I am also having a hard time understanding how to import the model: does it have to be in .pt format? Can I import a .pkl file?
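On the file-format question: in this thread `Backprop` is handed a model object (`Backprop(learn.model)`), so what matters is getting an ordinary `nn.Module` back out of the file. A minimal sketch, assuming the .pt file holds a `state_dict` and that you have the model's class definition; `TinyNet` and the file name are hypothetical. A .pkl produced by fastai's `learn.export()` is a pickled Learner rather than bare weights. Loading it with fastai's `load_learner` and passing `learn.model` on should work, but that is an assumption this sketch does not cover.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model class; substitute your own architecture."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(torch.flatten(self.features(x), 1))


path = os.path.join(tempfile.mkdtemp(), "tiny_net.pt")

# Saving side: store only the weights (a state_dict), not the whole pickled object.
original = TinyNet()
torch.save(original.state_dict(), path)

# Loading side: rebuild the architecture, then load the weights into it.
model = TinyNet()
state_dict = torch.load(path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# `model` is now a plain nn.Module, the kind of object passed to Backprop(...) above.
print(type(model).__name__)  # TinyNet
```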

Congratulations, awesome work!

I almost made it work with fast.ai, but I can’t manage to get guided gradients to work.
Here is what I have (keep in mind this is using fastai2).

```python
from flashtorch.utils import apply_transforms, load_image, visualize
from flashtorch.saliency import Backprop

backprop = Backprop(learn.model)
image = load_image('/home/jupyter/course-v4/nbs/food101/data/food-101/food-101/images/apple_pie/416233.jpg')

# Transform the input image to a tensor
img = apply_transforms(image)
target_class = 0  # TODO: get the class index from the dls vocab

backprop.visualize(img, target_class, guided=True, use_gpu=True)  # this works

# This doesn't work; it raises:
# RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
gradients = backprop.calculate_gradients(img, target_class)
max_gradients = backprop.calculate_gradients(img, target_class, take_max=True)
visualize(img, gradients, max_gradients)
```

If someone knows how I can make it work, that'd be great.
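A likely cause, going by the error message: fastai leaves `learn.model` on the GPU, while `img` from `apply_transforms` is a CPU tensor. `backprop.visualize(..., use_gpu=True)` presumably works because it moves both to the GPU, and the direct `calculate_gradients` calls do not. If your FlashTorch version's `calculate_gradients` accepts the same `guided` and `use_gpu` keywords as `visualize` (an assumption worth checking in its source), passing `guided=True, use_gpu=True` to both calls may be enough. Otherwise, the idea can be sketched in plain PyTorch: move the input to whatever device the model's weights are on, and implement guided backpropagation with backward hooks on `nn.ReLU` modules. `guided_gradients` is a hypothetical helper, not part of FlashTorch, and the tiny model stands in for `learn.model`.

```python
import torch
import torch.nn as nn


def guided_gradients(model: nn.Module, image: torch.Tensor, target_class: int) -> torch.Tensor:
    """Guided backprop: only positive gradients flow back through each ReLU.

    Returns the gradient w.r.t. `image` with shape (C, H, W), on the CPU.
    """
    model.eval()
    device = next(model.parameters()).device  # e.g. cuda:0 for a model fastai moved to the GPU
    # Moving the input avoids the FloatTensor vs cuda.FloatTensor mismatch.
    x = image.detach().to(device).requires_grad_(True)

    def clamp_grad(module, grad_input, grad_output):
        # grad_input[0] is grad_output masked by (input > 0); also drop negative gradients.
        return (grad_input[0].clamp(min=0),)

    handles = []
    for m in model.modules():
        if isinstance(m, nn.ReLU):
            m.inplace = False  # in-place ReLUs can clash with full backward hooks (mutates the model)
            handles.append(m.register_full_backward_hook(clamp_grad))
    try:
        model.zero_grad()
        model(x)[0, target_class].backward()
    finally:
        for h in handles:
            h.remove()  # remove the hooks so later forward/backward passes are unaffected
    return x.grad.detach()[0].cpu()


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
if torch.cuda.is_available():
    model = model.cuda()         # mimic fastai leaving the model on the GPU
img = torch.randn(1, 3, 32, 32)  # CPU tensor, like the output of apply_transforms
grads = guided_gradients(model, img, target_class=0)
print(grads.shape)  # torch.Size([3, 32, 32])
```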

Has anyone figured out how to integrate it with fastai?