Understanding argument "detach" in Hook

Hello,

I’m having a bit of a hard time understanding the `detach` argument in `Hook()`. My current understanding is that with the default of `True`, the copied (hooked) version of the input tensor is detached from the computational graph, so gradients won’t be propagated through the skip connection. On the other hand, the original input tensor is still in the computational graph, so gradients can still be propagated through it without a problem… Am I missing something here?

class Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()
            output = (o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)
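
If it helps, here’s a minimal plain-PyTorch sketch (no fastai; `stored` and `hook_fn` are just names I made up) showing the difference between storing a detached copy and the original tensor inside a forward hook:

import torch
import torch.nn as nn

stored = {}

def hook_fn(module, input, output):
    stored['detached'] = output.detach()  # copy cut from the graph
    stored['attached'] = output           # reference still in the graph

m = nn.Linear(3, 3)
handle = m.register_forward_hook(hook_fn)

x = torch.randn(2, 3, requires_grad=True)
y = m(x)
y.sum().backward()  # gradients still flow through the original output

print(stored['detached'].requires_grad)  # False: nothing backprops through this
print(stored['attached'].requires_grad)  # True: still part of the graph

handle.remove()

So detaching only affects the stored copy; the forward pass itself, and gradients through the original tensor, are untouched.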

Thanks,


That’s exactly it.


Cool! Thanks 🙂