Confused by output of "hook_output"

I’m confused about the number of outputs showing up in hook.stored. I would expect it to match the size of the validation set, but it doesn’t.

data = ImageDataBunch.from_folder(PATH, ds_tfms=get_transforms(), size=img_sz, bs=bsz).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet34, metrics=[accuracy])

nn_module = learn.model[-1][-3] # BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

hook = hook_output(nn_module)
learn.fit_one_cycle(1)

len(hook.stored), len(data.train_ds), len(data.valid_ds)
# => (71, 1822, 455)

Why is the number of outputs 71???


What is your batch size bsz?

Batch size = 64

SOLVED …

So @lesscomfortable was on the right track in asking about batch size: hooks only store the last batch’s worth of data in hook.stored.

Here are a couple of ways you could implement the hook described above, discussed in much more detail here.

Option 1

class StoreHook(HookCallback):
    def on_train_begin(self, **kwargs):
        super().on_train_begin(**kwargs)
        self.outputs = []
        
    def hook(self, m, i, o): 
        return o
    
    def on_batch_end(self, train, **kwargs): 
        if (not train): self.outputs.append(self.hooks.stored[0])

data = ImageDataBunch.from_folder(PATH, ds_tfms=get_transforms(), size=img_sz, bs=bsz).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet34, metrics=[error_rate])

# the last 2 layers are Dropout and Linear (for predictions)
nn_module = learn.model[-1][-3] # BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

learn.callbacks += [ StoreHook(learn, modules=flatten_model(nn_module)) ]
learn.fit_one_cycle(1)

cb = learn.callbacks[0]
torch.cat(cb.outputs).shape, len(data.valid_ds)
# => torch.Size([455, 512]), 455

Option 2

class StoreHook2(Callback):
    def __init__(self, module):
        super().__init__()
        self.custom_hook = hook_output(module)
        self.outputs = []
        
    def on_batch_end(self, train, **kwargs): 
        if (not train): self.outputs.append(self.custom_hook.stored)

data = ImageDataBunch.from_folder(PATH, ds_tfms=get_transforms(), size=img_sz, bs=bsz).normalize(imagenet_stats)
learn = create_cnn(data, models.resnet34, metrics=[error_rate])

# the last 2 layers are Dropout and Linear (for predictions)
nn_module = learn.model[-1][-3] # BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

learn.callbacks += [ StoreHook2(nn_module) ]
learn.fit_one_cycle(1)

cb = learn.callbacks[0]
torch.cat(cb.outputs).shape, len(data.valid_ds)
# => torch.Size([455, 512]), 455

Oh yeah, and since it is the validation set, the batch size is 2*bs, and the last batch is usually smaller than the preceding ones (since the length of the validation set is generally not exactly divisible by 2*bs), it makes sense that your last hook has length 71. I was puzzled by that number.
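The arithmetic checks out; here is a quick sketch, assuming fastai v1’s default of running validation at twice the training batch size:

```python
# Quick check of the batch arithmetic above (assumes fastai v1's default
# behavior of doubling the training batch size for the validation set).
train_bs = 64
valid_bs = 2 * train_bs                # 128
n_valid = 455                          # len(data.valid_ds) from the output above

full_batches, last_batch = divmod(n_valid, valid_bs)
print(full_batches, last_batch)        # 3 full batches of 128, then a final batch of 71
```

So hook.stored holds only that final batch of 71 activations, which is exactly the number in the original question.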


Useful. Thanks for sharing.

What are m, i, and o? The function just returns o, without making any use of the other two…

m = the module
i = the input Tensor
o = the output Tensor
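For anyone curious about the mechanics, here is a minimal pure-Python sketch (not fastai or PyTorch code — `FakeModule` is a stand-in) of the forward-hook calling convention, where each registered hook is called as `hook(m, i, o)` after the module runs:

```python
# Pure-Python sketch of the forward-hook calling convention:
# after a module computes its output, each hook is called as hook(m, i, o).
class FakeModule:
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)

    def __call__(self, x):
        out = [v * 2 for v in x]        # stand-in for the module's real computation
        for fn in self._hooks:
            res = fn(self, x, out)      # m = module, i = input, o = output
            if res is not None:         # a hook may replace the output
                out = res
        return out

captured = []
m = FakeModule()
m.register_forward_hook(lambda mod, inp, out: captured.append(out))
m([1, 2, 3])
print(captured)  # [[2, 4, 6]]
```

This is why the `hook` in StoreHook can just return `o`: the module and input are available if you need them, but storing the output is all this use case requires.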


Thank you @wgpubs. Your code is very useful. Can I ask you one question?

    def on_batch_end(self, train, **kwargs): 
        if (not train): self.outputs.append(self.custom_hook.stored)

What is the train argument in your function, and why if (not train)? I’ve seen it in several callbacks but don’t know what it is or how it is set.

Thank you in advance,

This is there to capture the outputs for just the validation set.

If you wanted to capture the outputs for just the training set, you would do the opposite… and if you wanted both, you wouldn’t need the if statement at all.
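As a hedged sketch of that idea (the `stored` argument below stands in for `self.custom_hook.stored` — the real fastai v1 callback reads it from the hook rather than receiving it as a parameter), a variant that keeps both could look like:

```python
# Sketch: keep train and valid activations in separate lists instead of
# filtering with `if (not train)`. `stored` stands in for the hook's
# stored output (an assumption for illustration, not the fastai signature).
class StoreBoth:
    def __init__(self):
        self.train_outputs, self.valid_outputs = [], []

    def on_batch_end(self, train, stored, **kwargs):
        (self.train_outputs if train else self.valid_outputs).append(stored)

cb = StoreBoth()
cb.on_batch_end(train=True, stored="batch0")
cb.on_batch_end(train=False, stored="batch1")
print(len(cb.train_outputs), len(cb.valid_outputs))  # 1 1
```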


Thanks for providing your examples! This helped a lot in my struggle to understand how to use hooks in the fastai library.

Hi,
I came across this error while using your code snippet:

AttributeError: 'super' object has no attribute '_StoreHook__init'

Could you point out what I am doing wrong here?
Is it necessary to call super().__init()?
I don’t get the error if I remove it.

Something has probably changed in the API since I wrote my post … take a look and if you can figure out a fix, post it here and I’ll be glad to update my instructions.
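For what it’s worth, that particular error is most likely Python name mangling rather than an API change: inside a class body, an identifier with two leading underscores and fewer than two trailing ones (like the typo `__init`, missing the trailing `__`) is rewritten to `_ClassName__name`, so `super().__init()` is looked up as `super()._StoreHook__init()`. A small sketch:

```python
# Why the AttributeError mentions _StoreHook__init: name mangling rewrites
# the typo __init (no trailing underscores) to _StoreHook__init, which
# super() cannot find. The correct call is super().__init__().
class Base:
    pass

class StoreHook(Base):
    def __init__(self):
        try:
            super().__init()       # typo: should be super().__init__()
        except AttributeError as e:
            self.err = str(e)      # message names _StoreHook__init

h = StoreHook()
print(h.err)
```

So calling `super().__init__()` (with the trailing underscores) is fine; it’s the mangled typo that raises the error.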