Hi,
I have been trying to use the ActivationStats callback for text classification. I’m facing a bit of trouble.
When I try:
# Build the text classifier with `ActivationStats` passed as a callback factory, then call `fit_one_cycle`.
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, callback_fns=ActivationStats)
learn.fit_one_cycle(1, 2e-2, moms=(0.8,0.7))
I get the following error:
/usr/local/lib/python3.6/dist-packages/fastai/torch_core.py in tensor(x, *rest)
79 # XXX: Pytorch bug in dataloader using num_workers>0; TODO: create repro and report
80 if is_listy(x) and len(x)==0: return tensor(0)
---> 81 res = torch.tensor(x) if is_listy(x) else as_tensor(x)
82 if res.dtype is torch.int32:
83 warn('Tensor is int32: upgrading to int64; for better performance use int64 input')
RuntimeError: Could not infer dtype of NoneType
I’ve tried debugging. The problem seems to be that `self.stats` in `ActivationStats` contains `None` values. The classes for hooks are as follows:
class Hook():
    "Register `hook_func` on module `m` and keep its most recent return value in `self.stored`."
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func = hook_func
        self.detach = detach
        self.stored = None
        # `is_forward` selects which of the module's two registration methods is used.
        if is_forward: register = m.register_forward_hook
        else:          register = m.register_backward_hook
        self.hook = register(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
        "Call `hook_func` on `module`, `input`, `output` (detached first when `self.detach`) and store the result."
        if self.detach:
            # Listy values become lazy generators of detached items; anything else is detached directly.
            if is_listy(input): input = (t.detach() for t in input)
            else:               input = input.detach()
            if is_listy(output): output = (t.detach() for t in output)
            else:                output = output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Unregister the hook from the model; calls after the first do nothing."
        if self.removed: return
        self.hook.remove()
        self.removed = True

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()
class Hooks():
    "Hold one `Hook` per module in `ms`, all created with the same `hook_func`."
    def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        created = []
        for module in ms:
            created.append(Hook(module, hook_func, is_forward, detach))
        self.hooks = created

    def __getitem__(self,i:int)->Hook:
        return self.hooks[i]

    def __len__(self)->int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        "List of the `stored` value of every hook, in hook order."
        return [hook.stored for hook in self]

    def remove(self):
        "Unregister every hook from the model."
        for hook in self.hooks:
            hook.remove()

    def __enter__(self, *args):
        return self

    def __exit__ (self, *args):
        self.remove()
The problem, in my opinion, is that the Hook(s) classes expect the input and output to be Tensors. However, an LSTM in PyTorch has its input and output in the following format:
Inputs: input, (h_0, c_0)
Outputs: output, (h_n, c_n)
I have thus tried removing the type enforcement in Hook, Hooks and ActivationStats and yet I’m still not able to get it to work. I’ve made the following changes:
class Hook():
    "Create a hook on `m` with `hook_func`; the latest return value of `hook_func` is kept in `self.stored`."
    # Type annotations are intentionally omitted: inputs/outputs may be nested
    # tuples (e.g. an LSTM returns `output, (h_n, c_n)`), not only tensors.
    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    @staticmethod
    def _detach(x):
        "Detach `x`, recursing into lists/tuples (returned as tuples); objects without `detach` pass through unchanged."
        if isinstance(x, (list, tuple)): return tuple(Hook._detach(o) for o in x)
        return x.detach() if hasattr(x, 'detach') else x

    def hook_fn(self, module, input, output):
        "Applies `hook_func` to `module`, `input`, `output` and stores what it returns."
        if self.detach:
            # Materialize into (nested) tuples rather than a lazy generator so that
            # `hook_func` can index the values, e.g. `output[1][0]`.
            input,output = self._detach(input),self._detach(output)
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model; safe to call more than once."
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
class Hooks():
    "Build a `Hook` with `hook_func` for each module in `ms` and manage them as a group."
    # Removed type enforcement
    def __init__(self, ms, hook_func, is_forward=True, detach=True):
        built = []
        for mod in ms:
            built.append(Hook(mod, hook_func, is_forward, detach))
        self.hooks = built

    def __getitem__(self,i:int)->Hook:
        return self.hooks[i]

    def __len__(self)->int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        "The `stored` value of each hook, collected into a list."
        return [h.stored for h in self]

    def remove(self):
        "Take every hook off the model."
        for hook in self.hooks:
            hook.remove()

    def __enter__(self, *args):
        return self

    def __exit__ (self, *args):
        self.remove()
class HookCallback(LearnerCallback):
    "Callback that registers hooks on `modules`; subclasses supply the hook function as `self.hook`."
    def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules = modules
        self.do_remove = do_remove

    def on_train_begin(self, **kwargs):
        "Create `self.hooks` on `self.modules`, defaulting to the model's `torch.nn.LSTM` modules."
        if not self.modules:
            # Fetching only LSTMs
            lstms = []
            for layer in flatten_model(self.learn.model):
                if isinstance(layer, torch.nn.LSTM): lstms.append(layer)
            self.modules = lstms
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        "Remove the `Hooks` when `do_remove` is set."
        if not self.do_remove: return
        self.remove()

    def remove(self):
        "Remove the hooks if any were created."
        hooks = getattr(self, 'hooks', None)
        if hooks: hooks.remove()

    def __del__(self):
        self.remove()
class ActivationStats_LSTM(HookCallback):
    "Callback that records the mean and std of `o[1][0]` for every hooked module."
    def on_train_begin(self, **kwargs):
        "Let the parent class register the hooks, then start with an empty stats list."
        super().on_train_begin(**kwargs)
        self.stats = []

    # Removed type
    def hook(self, m, i, o):
        "Return `(mean, std)` of `o[1][0]` as Python numbers."
        import pdb
        pdb.set_trace()
        # presumably `o[1][0]` is `h_n`, given an `output, (h_n, c_n)` layout -- TODO confirm
        hidden = o[1][0]
        return hidden.mean().item(), hidden.std().item()

    def on_batch_end(self, train, **kwargs):
        "Append the hooks' stored results to `self.stats` (only when `train` is truthy)."
        if not train: return
        self.stats.append(self.hooks.stored)

    def on_train_end(self, **kwargs):
        "Run the parent clean-up, then turn `self.stats` into a tensor with its three axes reversed."
        super().on_train_end(**kwargs)
        stacked = tensor(self.stats)
        self.stats = stacked.permute(2,1,0)
I’m still not able to hit the pdb I have added in the hook method.
I think a different ActivationStats is required for handling RNNs or a more abstract type needs to be added in the existing one.
Any thoughts would be appreciated.
I have created an issue on GitHub addressing one aspect of this bug. I would like to take this up if it passes triage.