Hi all,
I am working on an image classification project (classifying scenes and characters from various animated movies) for fun during the break, and I have realized how important it is to understand and interpret what your model learns so that you can improve the model or the dataset. So far it has helped me build a solid validation set and fix a data collection error that could have led to data leakage. Based on Jeremy’s lesson 6 pets notebook, henripal’s notebook, and the Grad-CAM paper, I wrote some code to quickly generate Grad-CAM and guided backprop visualizations, and I would love to contribute it to the fastai library.
Here are a few examples.
(Notation: Grad-CAM -> GC, Guided Backprop -> GBP.
Images from left to right: original image / GC w.r.t. predicted label (with probability) / GBP w.r.t. predicted label / GC w.r.t. actual label (with probability) / GBP w.r.t. actual label)
Another example showing only the Grad-CAM w.r.t. the predicted label (what gets plotted is configurable):
Here is the code and an example. In short, you only need these lines of code:
# from a ClassificationInterpretation object
interp = ClassificationInterpretation.from_learner(learn, ds_type=DatasetType.Valid)
gcam = GradCam.from_interp(learn, interp, image_idx)  # image_idx from ds.valid_ds or ds.test_ds
gcam.plot()  # plots both GradCam and GuidedBackprop
# You can also plot just one of them by passing parameters to plot()
This could be an addition to ClassificationInterpretation, e.g. after calling interp.most_confused() you may want to find out more about the most confused classes (example in the gist). You can also run Grad-CAM and guided backprop on the test set by changing ds_type to DatasetType.Test (could be useful for Kaggle competitions).
You can also plot the heatmap and GBP for a single Image object (see the notebook for more):
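For anyone not familiar with how the Grad-CAM heatmap itself is computed, here is a minimal NumPy sketch of the core step as I understand it from the paper. It is not the code from the gist: the function name grad_cam_heatmap is made up for illustration, and it assumes you have already captured the activations of the last conv layer and the gradients of the chosen class score w.r.t. those activations (e.g. with hooks), both with shape (channels, height, width).

```python
import numpy as np

def grad_cam_heatmap(activations, gradients):
    """Hypothetical helper: both inputs have shape (channels, height, width)."""
    # Channel weights: average each channel's gradient over the spatial dims.
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of the activation maps over channels -> (height, width).
    cam = np.tensordot(weights, activations, axes=1)
    # ReLU: keep only locations that push the class score up.
    cam = np.maximum(cam, 0)
    # Scale to [0, 1] for plotting; leave an all-zero map untouched.
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

The resulting map has the spatial size of the conv feature map, so it still has to be upsampled to the image size before it is overlaid on the original image.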
# from a single Image object.
img = open_image(path)
gcam = GradCam.from_one_img(learn,img)
gcam.plot()
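Guided backprop differs from plain backprop only in how gradients pass through each ReLU. Here is a minimal NumPy sketch of that rule, again not taken from the gist: guided_relu_backward is a hypothetical name, and in a real model this logic would sit in a backward hook on every ReLU layer.

```python
import numpy as np

def guided_relu_backward(forward_input, grad_output):
    """Hypothetical helper showing the guided backprop rule for one ReLU."""
    # Plain ReLU backward: pass gradient only where the forward input was positive.
    positive_input = forward_input > 0
    # Guided backprop additionally drops negative incoming gradients.
    positive_grad = grad_output > 0
    return grad_output * positive_input * positive_grad
```

A gradient therefore survives only where both the forward input and the incoming gradient are positive, which is what makes the resulting maps look sharper than plain gradients.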
Here is the project I built, where I use Grad-CAM to visualize the model and troubleshoot it: https://quantran.xyz/blog/building-an-image-classification-model-from-a-to-z/
Let me know if this looks good, and I will refactor it a bit and open a PR.