Can I use a hook inside the forward function of a module?

Can we use the result of a hook in the forward method of a module?

For example, I have a Module to which I pass a pre-trained model named questions. questions is a TabularModel, and I want to use the output of a Linear layer inside it for further computations:

```python
from fastai.tabular.all import *

class ParentChildModel(Module):
    def __init__(self, questions):
        # results_emb_szs and results_cont_names are assumed to be defined elsewhere
        self.results = TabularModel(results_emb_szs, len(results_cont_names), 10, [200, 50],
                                    ps=[0.01, .1], embed_p=0.04, bn_final=False)
        self.questions = questions  # pre-trained TabularModel
        self.head = nn.Sequential(LinBnDrop(20, 1, p=0.1), SigmoidRange(-.1, 1.1))

    def forward(self, data, children):
        # [... some code here...]
        # self.questions.layers[-3][2] is a Linear layer
        with hook_output(self.questions.layers[-3][2], detach=False) as h:
            result = self.questions(cat, cont)
            mid = h.stored
```

I then use the mid variable in further calculations. I am not getting very good results, so I was wondering whether this is valid and whether the gradients flow through it properly.
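One generic way to check whether gradients actually reach a submodule is to run a single backward pass and list the trainable parameters whose `.grad` is still `None`. The sketch below is plain PyTorch; `find_params_without_grad` and the toy `Toy` model are hypothetical names used only for illustration, not part of fastai. The `cut` branch can be detached on purpose to show what a broken gradient path looks like.

```python
from typing import List

import torch
import torch.nn as nn


def find_params_without_grad(model: nn.Module) -> List[str]:
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


class Toy(nn.Module):
    """Hypothetical model: `cut` can be detached to simulate a broken gradient path."""

    def __init__(self, detach_cut: bool):
        super().__init__()
        self.inner = nn.Linear(4, 3)
        self.cut = nn.Linear(4, 3)
        self.head = nn.Linear(6, 1)
        self.detach_cut = detach_cut

    def forward(self, x):
        a = self.inner(x)
        b = self.cut(x)
        if self.detach_cut:
            b = b.detach()  # no gradient can flow back into self.cut
        return self.head(torch.cat([a, b], dim=1))


model = Toy(detach_cut=True)
model(torch.randn(8, 4)).sum().backward()
print(find_params_without_grad(model))  # ['cut.weight', 'cut.bias']
```

With a real model, you would call the helper after one loss.backward() and check whether the hooked layer's parameters appear in the returned list.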

The hook is removed as soon as you leave the context manager, so I'm not sure you should use one here. Follow the way DynamicUnet is implemented instead: register the hook once and keep it for the lifetime of the module.
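A minimal sketch of that register-once pattern in plain PyTorch (no fastai), assuming a forward hook on the intermediate layer is all you need: the hook is registered in `__init__`, the stored output is read in `forward`, and the handle is removed in `__del__`. `StoreOutput`, `Inner` and `Parent` are hypothetical names; `Inner` merely stands in for the pre-trained questions model. The hook does not detach the output, so `mid` stays attached to the autograd graph.

```python
import torch
import torch.nn as nn


class StoreOutput:
    """Forward hook that keeps a reference to the module output (not detached)."""

    def __init__(self, module: nn.Module):
        self.stored = None
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        self.stored = output  # keep the autograd graph: no .detach() here

    def remove(self):
        self.handle.remove()


class Inner(nn.Module):
    """Hypothetical stand-in for the pre-trained `questions` model."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.layers(x)


class Parent(nn.Module):
    def __init__(self, inner: Inner):
        super().__init__()
        self.inner = inner
        # Register once, on the intermediate Linear layer (index 0 in this toy model)
        self.hook = StoreOutput(self.inner.layers[0])
        self.head = nn.Linear(2 + 8, 1)

    def forward(self, x):
        result = self.inner(x)  # running the inner model fills self.hook.stored
        mid = self.hook.stored  # intermediate activation, still in the graph
        return self.head(torch.cat([result, mid], dim=1))

    def __del__(self):
        # Clean up the hook; guard in case __init__ failed before it was set
        if hasattr(self, "hook"):
            self.hook.remove()
```

Because the hook lives as long as the module, every forward pass refreshes `self.hook.stored`, and removal happens only when the module itself is deleted.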


OK, thanks, good to know! So instead of using with, I should do something like this: register the hook in the constructor, read .stored in the forward method, and then remove the hook in the __del__ method?

```python
from fastai.tabular.all import *

class ParentChildModel(Module):
    def __init__(self, questions):
        self.results = TabularModel(results_emb_szs, len(results_cont_names), 10, [200, 50],
                                    ps=[0.01, .1], embed_p=0.04, bn_final=False)
        self.questions = questions
        # Register the hook once; detach=False keeps the stored output in the autograd graph
        self.hook_questions = hook_output(self.questions.layers[-3][2], detach=False)
        self.head = nn.Sequential(LinBnDrop(20, 1, p=0.1), SigmoidRange(-.1, 1.1))

    def forward(self, data, children):
        # [... some code here...]
        result = self.questions(cat, cont)  # running questions fills hook_questions.stored
        mid = self.hook_questions.stored

    def __del__(self):
        # Remove the hook when the module is garbage-collected
        self.hook_questions.remove()
```

Thanks!

I think so, yes.
