Resolving an import error and configuring Captum for multi-label classification in fastai

Hi,

I am trying to learn Captum for interpretability. Thanks to Jeremy and the fastai development team, it has been integrated into the library. I am following the steps explained in the library docs (fastai - Captum), but unfortunately it does not work as described. I keep getting the error:

NameError: name 'CaptumInterpretation' is not defined

I tried importing it explicitly with from fastai.callback.captum.all import *, but that gives the error ModuleNotFoundError: No module named 'captum'. I checked the captum.py file inside the mamba directory, and it is present.

Any hints as to what might be wrong?

Thanks in advance

Kind regards,
Bilal

Have you installed captum? I needed to install it along with flask and flask-compress, so run:

mamba install -c pytorch captum flask flask-compress

or the pip equivalent (pip install captum flask flask-compress).
Then

from fastai.callback.captum import *

should work.

Edit: so the ModuleNotFoundError refers to the standalone captum package, which the fastai Captum callback is built on…
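Before retrying the import, it can help to confirm which of these packages are actually visible to the Python interpreter your notebook runs. A minimal sketch using only the standard library; the helper name missing_packages is made up for illustration, and it assumes flask-compress is importable as flask_compress:

```python
import importlib.util

def missing_packages(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(["captum", "flask", "flask_compress"])
if missing:
    print("Not installed in this environment:", missing)
else:
    print("All found; `from fastai.callback.captum import *` should import now.")
```

If captum is reported missing even though you installed it, the notebook kernel is probably running a different environment from the one mamba installed into.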


Thanks benkarr, it is working now. The imports go through.


However, the call to captum.visualize(f) now raises an error:


TypeError Traceback (most recent call last)
Cell In [31], line 1
----> 1 captum.visualize(f)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/captum.py:54, in CaptumInterpretation.visualize(self, inp, metric, n_steps, baseline_type, nt_type, strides, sliding_window_shapes)
52 raise Exception(f"Metric {metric} is not supported. Currently {self.supported_metrics} are only supported")
53 tls = L([TfmdLists(inp, t) for t in L(ifnone(self.dls.tfms,[None]))])
---> 54 inp_data=list(zip(*(tls[0],tls[1])))[0]
55 enc_data,dec_data=self._get_enc_dec_data(inp_data)
56 attributions=self._get_attributions(enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes)

File ~/mambaforge/lib/python3.10/site-packages/fastai/data/core.py:376, in <genexpr>(.0)
--> 376 def __iter__(self): return (self[i] for i in range(len(self)))

File ~/mambaforge/lib/python3.10/site-packages/fastai/data/core.py:414, in TfmdLists.__getitem__(self, idx)
412 res = super().__getitem__(idx)
413 if self._after_item is None: return res
--> 414 return self._after_item(res) if is_indexer(idx) else res.map(self._after_item)

File ~/mambaforge/lib/python3.10/site-packages/fastai/data/core.py:374, in TfmdLists._after_item(self, o)
--> 374 def _after_item(self, o): return self.tfms(o)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:208, in Pipeline.__call__(self, o)
--> 208 def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:158, in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
156 for f in tfms:
157 if not is_enc: f = f.decode
--> 158 x = f(x, **kwargs)
159 return x

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:81, in Transform.__call__(self, x, **kwargs)
---> 81 def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:91, in Transform._call(self, fn, x, split_idx, **kwargs)
89 def _call(self, fn, x, split_idx=None, **kwargs):
90 if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 91 return self._do_call(getattr(self, fn), x, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:97, in Transform._do_call(self, f, x, **kwargs)
95 if f is None: return x
96 ret = f.returns(x) if hasattr(f,'returns') else None
---> 97 return retain_type(f(x, **kwargs), x, ret)
98 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
99 return retain_type(res, x)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/dispatch.py:120, in TypeDispatch.__call__(self, *args, **kwargs)
118 elif self.inst is not None: f = MethodType(f, self.inst)
119 elif self.owner is not None: f = MethodType(f, self.owner)
--> 120 return f(*args, **kwargs)

Cell In [4], line 6, in get_x(r)
----> 6 def get_x(r): return r['image_id']

TypeError: string indices must be integers

It seems Captum ends up calling the DataBlock's get_x function, which it shouldn't, since I am passing it the path to an image file.

Any idea how to resolve this error?

I guess you are using a DataFrame as the source, and your DataLoaders are set up for that.

Cell In [4], line 6, in get_x(r)
----> 6 def get_x(r): return r['image_id']

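This tries to access the image_id column of a DataFrame row, which doesn't work since you are passing the path directly: indexing a plain str with 'image_id' is exactly what raises "string indices must be integers".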
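A small illustration of that failure mode, with a made-up DataFrame (the column and file names are hypothetical; the exact wording of the TypeError differs slightly between Python versions):

```python
import pandas as pd

def get_x(r):
    # same shape as the get_x in the traceback: expects a row with an 'image_id' column
    return r['image_id']

df = pd.DataFrame({'image_id': ['img_001.jpg', 'img_002.jpg'],
                   'labels': ['cat dog', 'dog']})

print(get_x(df.iloc[0]))       # a row (pd.Series) indexed by column name: prints img_001.jpg

try:
    get_x('img_001.jpg')       # a bare path string, as originally passed to visualize()
except TypeError as err:
    print('TypeError:', err)   # "string indices must be integers..."
```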

Try:

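from random import randint
from fastai.callback.captum import CaptumInterpretation

idx = randint(0, len(dls.valid.items) - 1)  # random.randint includes both endpoints
f = dls.valid.items.iloc[[idx]]             # double brackets -> one-row DataFrame

captum = CaptumInterpretation(learn)
captum.visualize(f)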

Note the double brackets in f = dls.valid.items.iloc[[idx]]: they are necessary to retrieve a single-row pd.DataFrame (whereas f = dls.valid.items.iloc[idx] returns a pd.Series).
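The difference between the two indexing forms is easy to check in isolation. A sketch with a made-up two-row DataFrame standing in for dls.valid.items:

```python
import pandas as pd

items = pd.DataFrame({'image_id': ['a.jpg', 'b.jpg'], 'labels': ['cat', 'dog']})

row_series = items.iloc[1]     # single brackets -> pd.Series (one row, indexed by column)
row_frame = items.iloc[[1]]    # double brackets -> pd.DataFrame containing one row

print(type(row_series).__name__, row_series.shape)   # Series (2,)
print(type(row_frame).__name__, row_frame.shape)     # DataFrame (1, 2)
```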

Hope it works now :slight_smile:


Thanks, it worked. The code now picks a row from the DataFrame.

But now I've stumbled upon a 'wrong size in the last dimension' error. Please see below:


RuntimeError Traceback (most recent call last)
Cell In [7], line 5
2 f = dls.valid.items.iloc[[idx]]
4 captum=CaptumInterpretation(learn)
----> 5 captum.visualize(f)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/captum.py:55, in CaptumInterpretation.visualize(self, inp, metric, n_steps, baseline_type, nt_type, strides, sliding_window_shapes)
53 tls = L([TfmdLists(inp, t) for t in L(ifnone(self.dls.tfms,[None]))])
54 inp_data=list(zip(*(tls[0],tls[1])))[0]
---> 55 enc_data,dec_data=self._get_enc_dec_data(inp_data)
56 attributions=self._get_attributions(enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes)
57 self._viz(attributions,dec_data,metric)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/captum.py:73, in CaptumInterpretation._get_enc_dec_data(self, inp_data)
71 def _get_enc_dec_data(self,inp_data):
72 dec_data=self.dls.after_item(inp_data)
---> 73 enc_data=self.dls.after_batch(to_device(self.dls.before_batch(dec_data),self.dls.device))
74 return(enc_data,dec_data)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:208, in Pipeline.__call__(self, o)
--> 208 def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:158, in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
156 for f in tfms:
157 if not is_enc: f = f.decode
--> 158 x = f(x, **kwargs)
159 return x

File ~/mambaforge/lib/python3.10/site-packages/fastai/vision/augment.py:49, in RandTransform.__call__(self, b, split_idx, **kwargs)
43 def __call__(self,
44 b,
45 split_idx:int=None, # Index of the train/valid dataset
46 **kwargs
47 ):
48 self.before_call(b, split_idx=split_idx)
---> 49 return super().__call__(b, split_idx=split_idx, **kwargs) if self.do else b

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:81, in Transform.__call__(self, x, **kwargs)
---> 81 def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:91, in Transform._call(self, fn, x, split_idx, **kwargs)
89 def _call(self, fn, x, split_idx=None, **kwargs):
90 if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 91 return self._do_call(getattr(self, fn), x, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:98, in Transform._do_call(self, f, x, **kwargs)
96 ret = f.returns(x) if hasattr(f,'returns') else None
97 return retain_type(f(x, **kwargs), x, ret)
---> 98 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
99 return retain_type(res, x)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:98, in <genexpr>(.0)
96 ret = f.returns(x) if hasattr(f,'returns') else None
97 return retain_type(f(x, **kwargs), x, ret)
---> 98 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
99 return retain_type(res, x)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/transform.py:97, in Transform._do_call(self, f, x, **kwargs)
95 if f is None: return x
96 ret = f.returns(x) if hasattr(f,'returns') else None
---> 97 return retain_type(f(x, **kwargs), x, ret)
98 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
99 return retain_type(res, x)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/dispatch.py:120, in TypeDispatch.__call__(self, *args, **kwargs)
118 elif self.inst is not None: f = MethodType(f, self.inst)
119 elif self.owner is not None: f = MethodType(f, self.owner)
--> 120 return f(*args, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastai/vision/augment.py:501, in AffineCoordTfm.encodes(self, x)
--> 501 def encodes(self, x:TensorImage): return self._encode(x, self.mode)

File ~/mambaforge/lib/python3.10/site-packages/fastai/vision/augment.py:499, in AffineCoordTfm._encode(self, x, mode, reverse)
497 def _encode(self, x, mode, reverse=False):
498 coord_func = None if len(self.coord_fs)==0 or self.split_idx else partial(compose_tfms, tfms=self.coord_fs, reverse=reverse)
--> 499 return x.affine_coord(self.mat, coord_func, sz=self.size, mode=mode, pad_mode=self.pad_mode, align_corners=self.align_corners)

File ~/mambaforge/lib/python3.10/site-packages/fastai/vision/augment.py:391, in affine_coord(x, mat, coord_tfm, sz, mode, pad_mode, align_corners)
389 coords = affine_grid(mat, x.shape[:2] + size, align_corners=align_corners)
390 if coord_tfm is not None: coords = coord_tfm(coords)
--> 391 return TensorImage(_grid_sample(x, coords, mode=mode, padding_mode=pad_mode, align_corners=align_corners))

File ~/mambaforge/lib/python3.10/site-packages/fastai/vision/augment.py:364, in _grid_sample(x, coords, mode, padding_mode, align_corners)
362 if d>1 and d>z:
363 x = F.interpolate(x, scale_factor=1/d, mode='area', recompute_scale_factor=True)
--> 364 return F.grid_sample(x, coords, mode=mode, padding_mode=padding_mode, align_corners=align_corners)

File ~/mambaforge/lib/python3.10/site-packages/torch/nn/functional.py:4197, in grid_sample(input, grid, mode, padding_mode, align_corners)
4097 r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
4098 output using :attr:`input` values and pixel locations from :attr:`grid`.
4099
(...)
4194 .. _OpenCV: opencv/resize.cpp at f345ed564a06178670750bad59526cfa4033be55 · opencv/opencv · GitHub
4195 """
4196 if has_torch_function_variadic(input, grid):
-> 4197 return handle_torch_function(
4198 grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
4199 )
4200 if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
4201 raise ValueError(
4202 "nn.functional.grid_sample(): expected mode to be "
4203 "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode)
4204 )

File ~/mambaforge/lib/python3.10/site-packages/torch/overrides.py:1534, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1528 warnings.warn("Defining your `__torch_function__` as a plain method is deprecated and "
1529 "will be an error in future, please define it as a classmethod.",
1530 DeprecationWarning)
1532 # Use `public_api` instead of `implementation` so __torch_function__
1533 # implementations can do equality/identity comparisons.
-> 1534 result = torch_func_method(public_api, types, args, kwargs)
1536 if result is not NotImplemented:
1537 return result

File ~/mambaforge/lib/python3.10/site-packages/fastai/torch_core.py:378, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
376 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
377 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 378 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
379 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
380 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File ~/mambaforge/lib/python3.10/site-packages/torch/_tensor.py:1278, in Tensor.__torch_function__(cls, func, types, args, kwargs)
1275 return NotImplemented
1277 with _C.DisableTorchFunction():
-> 1278 ret = func(*args, **kwargs)
1279 if func in get_default_nowrap_functions():
1280 return ret

File ~/mambaforge/lib/python3.10/site-packages/torch/nn/functional.py:4235, in grid_sample(input, grid, mode, padding_mode, align_corners)
4227 warnings.warn(
4228 "Default grid_sample and affine_grid behavior has changed "
4229 "to align_corners=False since 1.3.0. Please specify "
4230 "align_corners=True if the old behavior is desired. "
4231 "See the documentation of grid_sample for details."
4232 )
4233 align_corners = False
-> 4235 return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)

RuntimeError: grid_sampler(): expected grid to have size 1 in last dimension, but got grid with sizes [3, 90, 160, 2]

Any idea how to get through this?

OK, it seems that you are using a model I'm not familiar with, so I probably can't provide specific help anymore, but feel free to share more of your code/process.

In particular, I can't follow along at this point in the stack trace:

---> 73 enc_data=self.dls.after_batch(to_device(self.dls.before_batch(dec_data),self.dls.device))

since I don't know what happens in .before_batch and .after_batch for your Learner, but those are the places where I'd start investigating!
One thing that could go wrong here is that one of the .after_batch steps relies on the forward pass having happened (or on some other callback that would run earlier in the train/eval loop but isn't triggered by .visualize).
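One low-effort way to start that investigation is to list which transforms sit in each pipeline. A minimal sketch, assuming only that a fastcore Pipeline keeps its transforms in its .fs attribute; the helper name transform_names and the stand-in classes are hypothetical:

```python
from types import SimpleNamespace

def transform_names(pipeline):
    """Class names of the transforms in a fastcore-style Pipeline, in call order."""
    return [type(t).__name__ for t in pipeline.fs]

# Stand-ins for real fastai transforms, so the sketch runs without fastai installed
class IntToFloatTensor: pass
class Flip: pass

fake_after_batch = SimpleNamespace(fs=[IntToFloatTensor(), Flip()])
print(transform_names(fake_after_batch))   # ['IntToFloatTensor', 'Flip']
```

With a trained learner you would call transform_names(learn.dls.before_batch) and transform_names(learn.dls.after_batch); simply printing learn.dls.after_batch also gives a readable summary.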

Also: did you try the docs example without modifications? If you haven't, I would check that to rule out installation issues.


Hi Ben,

Thanks for such thoughtful debugging steps.

I tried running the fastai tutorial first. The tutorial is missing the import from fastai.callback.captum import *; apart from that, it worked. See the output:

I got the error on the following command. Do you think this might be an installation issue?

The specification of the DataBlock and the other functions used to train the model is below:

I tried using learn.dls.train_dl([f]) to see if the same augmentations could be applied, but it gives the following error:

I have no idea how to apply the same augmentations with a DataFrame as the data source.

TL;DR: Captum isn't made for multi-label classification, but there might be a workaround.

First, the easy issue:

No, this is actually a bug in the callback's code: the n_samples argument passed to self._noise_tunnel.attribute should be named nt_samples. You can use this patch:

from captum.attr import IntegratedGradients, NoiseTunnel, GradientShap, Occlusion
from fastcore.basics import patch
from fastai.callback.captum import CaptumInterpretation

@patch
def _get_attributions(self:CaptumInterpretation, enc_data, metric, n_steps, nt_type, baseline_type, strides, sliding_window_shapes):
    # Baseline image that the attribution methods compare the input against
    baseline = self.get_baseline_img(enc_data[0], baseline_type)
    if metric == 'IG':
        self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
        # note: n_steps is hardcoded to 200 here, so the n_steps argument is ignored
        return self._int_grads.attribute(enc_data[0], baseline, target=enc_data[1], n_steps=200)
    elif metric == 'NT':
        self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
        self._noise_tunnel = self._noise_tunnel if hasattr(self,'_noise_tunnel') else NoiseTunnel(self._int_grads)
        # the fix: NoiseTunnel.attribute expects nt_samples, not n_samples
        return self._noise_tunnel.attribute(enc_data[0].to(self.dls.device), nt_samples=1, nt_type=nt_type, target=enc_data[1])
    elif metric == 'Occl':
        self._occlusion = self._occlusion if hasattr(self,'_occlusion') else Occlusion(self.model)
        return self._occlusion.attribute(enc_data[0].to(self.dls.device),
                                         strides=strides,
                                         target=enc_data[1],
                                         sliding_window_shapes=sliding_window_shapes,
                                         baselines=baseline)
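Keyword names like this one apparently changed between Captum releases, so it may be worth checking what the installed version accepts before patching. A small sketch using inspect.signature; the helper name accepts_keyword and the stand-in function are hypothetical, and it assumes the real method still exposes its original signature (true when its decorator uses functools.wraps):

```python
import inspect

def accepts_keyword(fn, name):
    """True if `fn` explicitly declares a parameter called `name` (a **kwargs catch-all does not count)."""
    return name in inspect.signature(fn).parameters

# Stand-in shaped like NoiseTunnel.attribute, so the sketch runs without captum
def attribute(inputs, nt_type="smoothgrad", nt_samples=5, **kwargs):
    return inputs

print(accepts_keyword(attribute, "nt_samples"))  # True
print(accepts_keyword(attribute, "n_samples"))   # False: it would land in **kwargs
```

Against the installed library, the same check would be accepts_keyword(NoiseTunnel.attribute, "nt_samples") after from captum.attr import NoiseTunnel. Because attribute takes **kwargs, a wrong keyword is presumably not rejected there but forwarded to the wrapped method, which is where the error surfaces.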

I used this notebook to reproduce the original error with a minimal example similar to your setup; the notebook also contains the workarounds:

RuntimeError: grid_sampler(): expected grid to have size 1 in last dimension, but got grid with sizes [3, 90, 160, 2]

This seems to be produced by the augmentations, in particular Flip. I don't know why, but I also don't think this augmentation is needed for the interpretation… so I just removed it from the transformation pipeline:

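# drop the after_batch transform at index 1 (Flip in this setup); Pipeline is fastcore.transform.Pipeline
learn.dls.after_batch = Pipeline(funcs=[t for i,t in enumerate(learn.dls.after_batch) if i!=1])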
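Dropping a transform by position depends on the order of after_batch in this particular setup. A slightly more robust variant filters by type instead. A sketch, assuming (as the traceback suggests) that the failing transform is an AffineCoordTfm subclass such as Flip; the helper name without_types and the stand-in classes are hypothetical:

```python
def without_types(transforms, excluded):
    """Keep only the transforms that are not instances of any class in `excluded`."""
    excluded = tuple(excluded)
    return [t for t in transforms if not isinstance(t, excluded)]

# Stand-ins so the sketch runs without fastai
class AffineCoordTfm: pass
class Flip(AffineCoordTfm): pass
class Normalize: pass

kept = without_types([Normalize(), Flip()], [AffineCoordTfm])
print([type(t).__name__ for t in kept])   # ['Normalize']
```

With fastai installed, the equivalent would be something like learn.dls.after_batch = Pipeline(without_types(learn.dls.after_batch.fs, [AffineCoordTfm])), importing Pipeline from fastcore.transform and AffineCoordTfm from fastai.vision.augment. Note that this removes all affine augmentations (not only Flip), which should be harmless if, as above, augmentations aren't needed for the interpretation.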

But this then produces the next error, which is the real issue IMO:

AssertionError: Tensor target dimension torch.Size([4000]) is not valid. torch.Size([200, 20])

I don't understand the library well enough to explain what happens here, but apparently the fastai callback passes the encoded MultiCategory target (a multi-hot tensor) to Captum, which in this case expects a single integer class index (I guess…). The numbers fit that: 4000 = 200 × 20, i.e. the 20-element target repeated for each of the 200 integration steps, while the model output has shape [200, 20].
According to this issue, MultiCategory classification is not supported directly by Captum, but you can get the interpretation for each individual class by passing a single target_idx, or combine multiple interpretations. You can find patches for both in my notebook. This is very "hacky", and taking the average of multiple attributions doesn't look like it does the right thing either, but at least you have something to look at :laughing:
If you want more convenience, such as passing a target_idx to .visualize instead of hardcoding it in the patch, or using the average for the other metrics, you'll have to add that yourself, I guess, since unfortunately I've procrastinated enough on this :laughing:
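The per-class workaround boils down to calling the attribution method once for every active label instead of passing the multi-hot target. A rough sketch, not the code from the notebook: it assumes an attribute callable with the (inputs, baselines, target=..., **kwargs) shape of captum.attr.IntegratedGradients.attribute, and the helper name per_class_attributions is hypothetical:

```python
def per_class_attributions(attribute, inputs, baselines, multi_hot, **kwargs):
    """Run `attribute` once per active label of a multi-hot target.

    Returns {class_index: attribution}. `multi_hot` may be a list or a 1-D tensor.
    """
    active = [i for i, on in enumerate(multi_hot) if on]
    return {i: attribute(inputs, baselines, target=i, **kwargs) for i in active}

# Stand-in for IntegratedGradients(model).attribute, so the sketch runs on its own
def fake_attribute(inputs, baselines, target, n_steps=50):
    return f"attr(target={target}, n_steps={n_steps})"

res = per_class_attributions(fake_attribute, "img", "baseline", [0, 1, 0, 1], n_steps=200)
print(res)   # {1: 'attr(target=1, n_steps=200)', 3: 'attr(target=3, n_steps=200)'}
```

With real tensors you could average the per-class maps, e.g. torch.stack(list(res.values())).mean(0), though, as said above, it is not obvious that an average is meaningful.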

I hope this gets you closer to the solution :slightly_smiling_face:


Dear Ben,

You are a superstar :clap:. Thank you so much. I eventually made it through by following your instructions. See the output below.

I can make those tweaks to target_idx to suit my use case as you suggested.

It looks incredible. I can investigate the areas informing the model's decision, and I'm excited to put it to good use. More power to your elbow.

Proud to be part of such a great community.

Many thanks again and
Best Regards,
Bilal


Yes, that looks great!
Happy that it worked, and even happier that you seem to be putting it to good use :blush:


Have you installed captum? I needed to install it along with flask and flask-compress, so run:

mamba install -c pytorch captum flask flask-compress

Hi Ben, thanks! But is this stated somewhere in the docs? I couldn't find anything… How are people supposed to know this without coming to this forum and seeing your great answer? :slight_smile: