Trying to make ActivationsStats and Hooks work for text_classifier

(Vimarsh Chaturvedi) #1

Hi,

I have been trying to use the ActivationStats callback for text classification. I’m facing a bit of trouble.

When I try:

learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5,callback_fns=ActivationStats)`
learn.fit_one_cycle(1, 2e-2, moms=(0.8,0.7))

I get the following error:

/usr/local/lib/python3.6/dist-packages/fastai/torch_core.py in tensor(x, *rest)
79 # XXX: Pytorch bug in dataloader using num_workers>0; TODO: create repro and report
80 if is_listy(x) and len(x)==0: return tensor(0)
—> 81 res = torch.tensor(x) if is_listy(x) else as_tensor(x)
82 if res.dtype is torch.int32:
83 warn(‘Tensor is int32: upgrading to int64; for better performance use int64 input’)

RuntimeError: Could not infer dtype of NoneType

I’ve tried debugging. The problem seems to be self.stats in ActivationStats have None values. The classes for hooks are as follows:

class  Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()
            output = (o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model."
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

    def __getitem__(self,i:int)->Hook: return self.hooks[i]
    def __len__(self)->int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self): return [o.stored for o in self]

    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks: h.remove()

    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()

The problem, in my opinion, is that the Hook(s) classes expect the input and output to be Tensors. However, LSTM in PyTorch have their input and output in the following format:
Inputs: input, (h_0, c_0)
Outputs: output, (h_n, c_n)

I have thus tried removing the type enforcement in Hook, Hooks and ActivationStats and yet I’m still not able to get it to work. I’ve made the following changes:

class Hook():
    "Create a hook on `m` with `hook_func`."
    # Removed type enforcement
    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    # Removed type enforcement
    def hook_fn(self, module, input, output):
        "Applies `hook_func` to `module`, `input`, `output`."
        import pdb;pdb.set_trace()
        if self.detach:
            input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()
            output = (o.detach() for o in output) if is_listy(output) else output.detach()

        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model."
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()


class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    # Removed type enforcement
    def __init__(self, ms, hook_func, is_forward=True, detach=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

    def __getitem__(self,i:int)->Hook: return self.hooks[i]
    def __len__(self)->int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self): return [o.stored for o in self]

    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks: h.remove()

    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()


class HookCallback(LearnerCallback):
    "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`."
    def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove

    def on_train_begin(self, **kwargs):
        "Register the `Hooks` on `self.modules`."
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if isinstance(m, torch.nn.LSTM)]

        # Fetching only LSTMs
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        "Remove the `Hooks`."
        if self.do_remove: self.remove()

    def remove(self): 
        if getattr(self, 'hooks', None): self.hooks.remove()
    def __del__(self): self.remove()


class ActivationStats_LSTM(HookCallback):
    "Callback that record the mean and std of activations."

    def on_train_begin(self, **kwargs):
        "Initialize stats."
        super().on_train_begin(**kwargs)
        self.stats = []

    # Removed type
    def hook(self, m, i, o):
        "Take the mean and std of `o`."
        import pdb;pdb.set_trace()
        return o[1][0].mean().item(),o[1][0].std().item()
    def on_batch_end(self, train, **kwargs):
        "Take the stored results and puts it in `self.stats`"
        if train: self.stats.append(self.hooks.stored)
    def on_train_end(self, **kwargs):
        "Polish the final result."
        super().on_train_end(**kwargs)
        self.stats = tensor(self.stats).permute(2,1,0)

I’m still not able to hit the pdb I have added in the hook method.
I think a different ActivationStats is required for handling RNNs or a more abstract type needs to be added in the existing one.
Any thoughts would be appreciated.
I have created an issue on github tending to some aspect of this bug. I would like to take this up if it passes triage.

1 Like