Hey, thank you for your reply. I didn’t have the code at hand anymore, but you re totally right about adding code, below is an example. Just use any model written with " " around the name (therefore the model is then used from the timm library) and you have the problem as shown below. I “solved” it by using grad cam instead of using cam (class activation maps) for the wanted location of the output, but I still cannot use grad cam this way for “any” location within the body of the model (within the timm body).
Example from the fastbook I used for grad cam:
class Hook():
def __init__(self, m):
self.hook = m.register_forward_hook(self.hook_func)
def hook_func(self, m, i, o): self.stored = o.detach().clone()
def __enter__(self, *args): return self
def __exit__(self, *args): self.hook.remove()
class HookBwd():
def __init__(self, m):
self.hook = m.register_backward_hook(self.hook_func)
def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
def __enter__(self, *args): return self
def __exit__(self, *args): self.hook.remove()
for index, i in enumerate(range(len(dls.valid_ds))):
x, = first(dls.test_dl([dls.valid_ds[i][0]]))
cls = 1
with HookBwd(learn.model[0]) as hookg:
with Hook(learn.model[0]) as hook:
output = learn.model.eval()(x.cuda())
act = hook.stored
output[0,cls].backward()
grad = hookg.stored
w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)
x_dec = TensorImage(dls.valid.decode((x,))[0][0])
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
interpolation='bilinear', cmap='magma');
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
interpolation='bilinear', cmap='magma');
code example that creates the mentioned problem:
path = untar_data(URLs.PETS)/'images'
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2, seed=21,
label_func=is_cat, item_tfms=Resize(256))
learn = vision_learner(dls, "convnext_nano.in12k_ft_in1k", metrics=error_rate)
learn.fine_tune(0)
img =PILImage.create(get_image_files('/media/testset_images')[0])
x, = first(dls.test_dl([img]))
hook_output = Hook()
hook = learn.model[0][5].register_forward_hook(hook_output.hook_func)
with torch.no_grad(): output = learn.model.eval()(x)
act = hook_output.stored[0]
act.shape
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [8], line 6
2 x, = first(dls.test_dl([img]))
5 hook_output = Hook()
----> 6 hook = learn.model[0][5].register_forward_hook(hook_output.hook_func)
7 with torch.no_grad(): output = learn.model.eval()(x)
8 act = hook_output.stored[0]
TypeError: 'TimmBody' object is not subscriptable