Can I use a hook inside the forward function of a module?

Can we use the result of a hook in the forward of a module?

For example, I have a Module that takes a pre-trained model named questions. questions is a TabularModel, and I want to take the output of a Linear layer inside it and do more computations with it:

class ParentChildModel(Module):
    def __init__(self, questions):
        self.results = TabularModel(results_emb_szs, len(results_cont_names), 10, [200, 50], ps=[0.01, .1], embed_p=0.04, bn_final=False)
        self.questions = questions
        
        self.head = nn.Sequential(*[LinBnDrop(20, 1, p=0.1), SigmoidRange(*[-.1, 1.1])])

    def forward(self, data, children):
        [... some code here...]
        # self.questions.layers[-3][2] is a Linear layer
        with hook_output(self.questions.layers[-3][2], detach=False) as h:
            result = self.questions(cat, cont)
            mid = h.stored

I then use mid in other calculations later on. I am not getting very good results, so I was wondering whether this is valid and whether the gradients will flow properly.
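One way to sanity-check the gradient question is a tiny experiment in plain PyTorch. This is only a sketch under assumptions: `inner` is a toy stand-in for the pre-trained model (not a TabularModel), and it uses `register_forward_hook` directly rather than fastai's `hook_output`, so it only suggests how a non-detached hooked output is expected to behave.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the pre-trained model (hypothetical, not a TabularModel)
inner = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

stored = {}

def save_output(module, inputs, output):
    # Store the output as-is (no .detach()), so it stays in the autograd graph
    stored["mid"] = output

# Register the hook only for the duration of one forward pass
handle = inner[0].register_forward_hook(save_output)
try:
    x = torch.randn(3, 4)
    result = inner(x)
    mid = stored["mid"]
finally:
    handle.remove()

# Use the hooked output in a later computation, after the hook is gone
loss = mid.sum()
loss.backward()

grad_reaches_linear = inner[0].weight.grad is not None
grad_is_nonzero = bool(inner[0].weight.grad.abs().sum() > 0)
print(grad_reaches_linear, grad_is_nonzero)
```

If both flags come out true in a check like this on your own model, the poor results are more likely to come from somewhere other than the hook itself.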

The hook is removed as soon as you leave the context manager, so I'm not sure you should use one. Follow the way DynamicUnet is implemented instead.


Ok thanks, good to know! So instead of using with, I should do something like this: register the hook in the constructor, use .stored in the forward method, and then remove the hook in the __del__ method?

class ParentChildModel(Module):
    def __init__(self, questions):
        self.results = TabularModel(results_emb_szs, len(results_cont_names), 10, [200, 50], ps=[0.01, .1], embed_p=0.04, bn_final=False)

        self.questions = questions

        # Register the hook once; detach=False keeps the stored output in the autograd graph
        self.hook_questions = hook_output(self.questions.layers[-3][2], detach=False)

        self.head = nn.Sequential(*[LinBnDrop(20, 1, p=0.1), SigmoidRange(*[-.1, 1.1])])

    def forward(self, data, children):
        [... some code here...]

        result = self.questions(cat, cont)
        # Output of the hooked Linear layer from the forward pass just above
        mid = self.hook_questions.stored

    def __del__(self):
        # Remove the hook when the model is garbage-collected
        self.hook_questions.remove()

Thanks!

I think so, yes.
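For reference, here is roughly what that pattern looks like in plain PyTorch. This is a minimal sketch, not how fastai implements it: `OutputHook`, `Parent`, and `remove_hooks` are hypothetical names made up for illustration, the layer sizes are arbitrary, and an explicit `remove_hooks` method is used in place of `__del__`.

```python
import torch
import torch.nn as nn

class OutputHook:
    """Hypothetical helper: stores a module's output on every forward pass."""
    def __init__(self, module):
        self.stored = None
        self.handle = module.register_forward_hook(self._save)

    def _save(self, module, inputs, output):
        # No .detach(), so gradients can flow back through the stored output
        self.stored = output

    def remove(self):
        self.handle.remove()

class Parent(nn.Module):
    def __init__(self, questions, hooked_layer):
        super().__init__()
        self.questions = questions
        # Registered once here and kept for the lifetime of the model
        self.hook = OutputHook(hooked_layer)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        result = self.questions(x)   # fills self.hook.stored
        mid = self.hook.stored       # intermediate activation, still in the graph
        return self.head(mid)

    def remove_hooks(self):
        self.hook.remove()

torch.manual_seed(0)
# Toy stand-in for the pre-trained model (hypothetical)
questions = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model = Parent(questions, questions[0])

out = model(torch.randn(3, 4))
out.sum().backward()
print(out.shape, questions[0].weight.grad is not None)
```

Calling `model.remove_hooks()` once you are done with the model stops `stored` from being updated on later forward passes.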
