I find it easiest to understand something by trying to implement it. For example, I wrote this Callback to assist in implementing the Stanford ML Group’s MRNet (see this thread) via the fastai library.
TL;DR – MRNet works by stacking a series of MRI images into a single batch, then squeezing that “batch” down to the shape the network expects as input.
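To make the shape change concrete, here is a rough sketch (the slice count and image size are invented; MRNet’s real inputs may differ):

import torch

# One exam = a stack of slices loaded as a "batch" of one:
# shape [1, slices, channels, height, width]
stacked = torch.randn(1, 24, 3, 256, 256)

# Squeeze away the leading batch dim so the slices themselves act as the batch:
# shape [slices, channels, height, width]
squeezed = torch.squeeze(stacked, dim=0)
print(squeezed.shape)  # torch.Size([24, 3, 256, 256])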
So, I decided to implement a Callback that does this in on_batch_begin. It took digging through the fastai library – reviewing code for the Callback-related classes and looking at examples – to understand how a Callback is invoked via CallbackHandler.
In fastai==1.0.50, the CallbackHandler maintains a state_dict that facilitates manipulation of variables within the training loop. I found it helpful to instantiate CallbackHandler in my notebook to see how state_dict evolved after a call to cbh.on_batch_begin() in order to understand how to implement the change I wanted to make.
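Something along these lines (a rough sketch of the 1.0.50 API from memory; the tensor shapes are made up):

import torch
from fastai.callback import CallbackHandler

cbh = CallbackHandler(callbacks=[])        # bare handler, no callbacks attached yet
xb = torch.randn(1, 24, 3, 256, 256)       # hypothetical stacked-slice "batch"
yb = torch.tensor([1])

print(cbh.state_dict)                      # initial training-loop state
cbh.on_batch_begin(xb, yb)                 # records last_input / last_target in state_dict
print(cbh.state_dict['last_input'].shape)  # torch.Size([1, 24, 3, 256, 256])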
The end result was this:
import torch
from fastai.callback import Callback

class MRNetCallback(Callback):
    def on_batch_begin(self, last_input, **kwargs):
        # Drop the leading batch dimension so the stacked slices become the batch.
        x = torch.squeeze(last_input, dim=0)
        return dict(last_input=x)
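A quick sanity check of the callback in isolation (shapes again invented):

cb = MRNetCallback()
fake_batch = torch.randn(1, 24, 3, 256, 256)
out = cb.on_batch_begin(last_input=fake_batch)
print(out['last_input'].shape)  # torch.Size([24, 3, 256, 256])

During training, the instance just gets passed to the Learner, e.g. Learner(data, model, callbacks=[MRNetCallback()]).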
I haven’t yet tried to implement this via the new Callbacks system covered in recent lessons, but I imagine things will change slightly and perhaps be even easier to understand/implement.