Callbacks are great for this: no messy globals, and they remove the hooks for you automatically. I’ve just added some little classes to notebook 005 so you can simply do this:
This is HookCallback BTW:
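```python
class HookCallback(LearnerCallback):
    "Registers the function returned by `self.hook(name, module)` (defined by subclasses) as a forward hook on every leaf module."
    def on_train_begin(self, **kwargs):
        self.hooks = []
        # Use the callback's own learner rather than a global `learn`
        for name,module in self.learn.model.named_modules():
            # Only hook leaf modules (those with no child modules)
            if list(module.children()) == []:
                func = self.hook(name,module)
                self.hooks.append(module.register_forward_hook(func))

    def on_train_end(self, **kwargs):
        # Remove every hook so the model is left unchanged after training
        if not self.hooks: return
        for hook in self.hooks: hook.remove()
        self.hooks=[]
```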
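To show what a subclass's `hook(name, module)` might look like, here is a minimal sketch in plain PyTorch, outside of fastai. `LeafHooks` and `ActivationMeans` are hypothetical names, and the callback methods are called by hand instead of by a training loop, so treat this as an illustration of the pattern rather than fastai's API. The key point is that `hook` returns a function with PyTorch's forward-hook signature `(module, input, output)`, and every handle from `register_forward_hook` gets `.remove()`d at the end.

```python
import torch
from torch import nn

class LeafHooks:
    "Hypothetical stand-in for HookCallback: hooks every leaf module, removes the hooks at the end."
    def __init__(self, model):
        self.model = model
        self.hooks = []

    def on_train_begin(self):
        self.hooks = []
        for name, module in self.model.named_modules():
            # Leaf modules have no children (e.g. Linear, ReLU)
            if not list(module.children()):
                self.hooks.append(module.register_forward_hook(self.hook(name, module)))

    def on_train_end(self):
        # Each handle returned by register_forward_hook can detach its hook
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def hook(self, name, module):
        # Subclasses return a function called as fn(module, input, output)
        raise NotImplementedError

class ActivationMeans(LeafHooks):
    "Hypothetical subclass: records the mean output of each leaf module, keyed by module name."
    def __init__(self, model):
        super().__init__(model)
        self.means = {}

    def hook(self, name, module):
        def _record(mod, inp, out):
            self.means.setdefault(name, []).append(out.detach().mean().item())
        return _record

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cb = ActivationMeans(model)
cb.on_train_begin()
model(torch.randn(3, 4))   # hooks fire once per leaf module
cb.on_train_end()
model(torch.randn(3, 4))   # hooks removed, nothing is recorded
print(sorted(cb.means))    # ['0', '1', '2']
```

Because the subclass only supplies `hook`, the registration and cleanup logic lives in one place, which is what makes the callback approach tidier than keeping hook handles in globals.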