Callback discussion from lesson 9

I am creating a separate topic to discuss callbacks. To get us started, here is the link explaining what callbacks are that was posted in lesson 9.

This is a wiki post, so please edit it to add any useful resources you find.


Thanks for opening this subject - for me it was by far the most difficult part of the lecture, and I’ll probably have some questions once I read through the materials.

It’s also worth mentioning the fastai docs about callbacks here.


Ok, I’m reading through the callback docs and I just don’t understand what is happening here:

So there is supposedly a callback called LRFinder, but they never actually show that callback. My assumption is that it’s used inside learn.lr_find(). But wouldn’t it make more sense to show this code explicitly if this section is about callbacks?

Just seeing if I’m understanding this correctly at all or if I am totally missing the callback in this section.

The docs there say that you generally want to use this callback via the lr_find method. Take a look at the source of that method and tell us what you find! :slight_smile:

We often provide convenience methods like this, so you don’t have to actually instantiate the callback yourself most of the time. If you’re interested in knowing what the callback actually implements, click the “LRFinder” link in the page you showed, and you’ll see this info:

I was completely lost in the second half of the lecture, but after asking Jeremy, I don’t feel too bad. Apparently I should give myself a good year to understand :slight_smile:

I made the top post a wiki so we can use it to track useful resources.

I just did a quick google search for callback tutorials, but the only decent ones I found were in JavaScript. Hopefully others have better google-fu than me!

So, diving into the GitHub source code now, I’m finding the LRFinder callback here:

from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import Learner, LearnerCallback

class LRFinder(LearnerCallback):
    "Causes `learn` to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations."
    def __init__(self, learn:Learner, start_lr:float=1e-7, end_lr:float=10, num_it:int=100, stop_div:bool=True):
        super().__init__(learn)
        self.data,self.stop_div = learn.data,stop_div
        self.sched = Scheduler((start_lr, end_lr), num_it, annealing_exp)
        #To avoid validating if the train_dl has less than num_it batches, we put aside the valid_dl and remove it
        #during the call to fit.
        self.valid_dl = learn.data.valid_dl
        self.data.valid_dl = None

    def on_train_begin(self, pbar, **kwargs:Any)->None:
        "Initialize optimizer and learner hyperparameters."
        setattr(pbar, 'clean_on_interrupt', True)
        self.learn.save('tmp')  # snapshot the weights so on_train_end can restore them
        self.opt = self.learn.opt
        self.opt.lr = self.sched.start
        self.stop,self.best_loss = False,0.

    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:
        "Determine if loss has runaway and we should stop."
        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss
        self.opt.lr = self.sched.step()  # move on to the next, exponentially larger, lr
        if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):
            #We use the smoothed loss to decide on the stopping since it's less shaky.
            return {'stop_epoch': True, 'stop_training': True}

    def on_train_end(self, **kwargs:Any)->None:
        "Cleanup learn model weights disturbed during LRFind exploration."
        # restore the valid_dl we turned off on `__init__`
        self.data.valid_dl = self.valid_dl
        self.learn.load('tmp', purge=False)
        if hasattr(self.learn.model, 'reset'): self.learn.model.reset()
        print('LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.')
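To see what the schedule and the stopping rule in LRFinder do on their own, here is a small pure-Python sketch. It is not fastai’s code: `mock_lr_sweep` and its `loss_fn` argument are hypothetical stand-ins for “train one batch at this learning rate and report the loss”, and it uses the raw loss where LRFinder uses the smoothed loss. `annealing_exp` is written here to compute start * (end/start) ** pct, which is what the name in the fastai code suggests; treat that as an assumption about the library.

```python
import math

def annealing_exp(start, end, pct):
    "Exponentially interpolate from `start` (pct=0.0) to `end` (pct=1.0)."
    return start * (end / start) ** pct

def mock_lr_sweep(loss_fn, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True):
    """Try num_it learning rates spaced evenly on a log scale, recording the loss
    at each, and stop early once the loss diverges. `loss_fn(lr)` is a hypothetical
    stand-in for training one batch at `lr` and returning its loss."""
    lrs, losses, best_loss = [], [], None
    for i in range(num_it):
        lr = annealing_exp(start_lr, end_lr, i / num_it)
        loss = loss_fn(lr)
        lrs.append(lr)
        losses.append(loss)
        if best_loss is None or loss < best_loss:
            best_loss = loss
        # Same rule as LRFinder.on_batch_end: stop at 4x the best loss so far, or NaN.
        if stop_div and (loss > 4 * best_loss or math.isnan(loss)):
            break
    return lrs, losses

# A fake loss curve that is flat until lr reaches 1e-2, then blows up.
lrs, losses = mock_lr_sweep(lambda lr: 1.0 if lr < 1e-2 else 5.0)
print(len(lrs), lrs[-1])  # the sweep stops at the first lr at or above 1e-2
```

Because the learning rates are spaced on a log scale, most of the 100 iterations are spent on small values, and the divergence check keeps the sweep from wasting time once training has clearly blown up.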

So then I see that it inherits (correct term?) from LearnerCallback, which is here:

class LearnerCallback(Callback):
    "Base class for creating callbacks for a `Learner`."
    def __init__(self, learn):
        self._learn = weakref.ref(learn)
        self.exclude,self.not_min = ['_learn'],[]
        setattr(self.learn, self.cb_name, self)

    def __getattr__(self,k): return getattr(self.learn, k)
    def __setstate__(self,data:Any): self.__dict__.update(data)

    # `learn` is a property backed by a weak reference to the Learner
    @property
    def learn(self) -> Learner: return self._learn()
    @learn.setter
    def learn(self, learn: Learner) -> None: self._learn = weakref.ref(learn)

    @property
    def cb_name(self): return camel2snake(self.__class__.__name__)
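The two tricks in LearnerCallback, the weak reference and `__getattr__`, are easier to see in isolation. In this toy version the `Learner` class is a hypothetical stand-in holding a single `lr` attribute, not fastai’s Learner. It shows that any attribute the callback can’t find on itself is looked up on the learner, and that the weak reference does not keep the learner alive.

```python
import gc
import weakref

class Learner:
    "Hypothetical stand-in for a fastai Learner; it only holds a learning rate."
    def __init__(self, lr):
        self.lr = lr

class DelegatingCallback:
    "Toy version of the LearnerCallback pattern."
    def __init__(self, learn):
        # A weak reference lets the callback reach the learner without keeping it alive.
        # (In fastai the learner also points back at the callback via setattr, so a
        # strong reference here would create a reference cycle.)
        self._learn = weakref.ref(learn)

    @property
    def learn(self):
        return self._learn()  # None once the learner has been garbage-collected

    def __getattr__(self, k):
        # Python only calls __getattr__ when normal lookup fails,
        # so `cb.lr` is forwarded to the learner.
        return getattr(self.learn, k)

learn = Learner(lr=0.01)
cb = DelegatingCallback(learn)
print(cb.lr)     # 0.01, read from the learner
del learn
gc.collect()
print(cb.learn)  # None: the weak reference did not keep the learner alive
```

This is why, inside a fastai callback, you can write things like `self.opt` or `self.model` even though the callback never set them itself.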

This in turn inherits from the standard Callback class, here:

class Callback():
    "Base class for callbacks that want to record values, dynamically change learner params, etc."
    def on_train_begin(self, **kwargs:Any)->None:
        "To initialize constants in the callback."
    def on_epoch_begin(self, **kwargs:Any)->None:
        "At the beginning of each epoch."
    def on_batch_begin(self, **kwargs:Any)->None:
        "Set HP before the output and loss are computed."
    def on_loss_begin(self, **kwargs:Any)->None:
        "Called after forward pass but before loss has been computed."
    def on_backward_begin(self, **kwargs:Any)->None:
        "Called after the forward pass and the loss has been computed, but before backprop."
    def on_backward_end(self, **kwargs:Any)->None:
        "Called after backprop but before optimizer step. Useful for true weight decay in AdamW."
    def on_step_end(self, **kwargs:Any)->None:
        "Called after the step of the optimizer but before the gradients are zeroed."
    def on_batch_end(self, **kwargs:Any)->None:
        "Called at the end of the batch."
    def on_epoch_end(self, **kwargs:Any)->None:
        "Called at the end of an epoch."
    def on_train_end(self, **kwargs:Any)->None:
        "Useful for cleaning up things and saving files/models."
    def jump_to_epoch(self, epoch)->None:
        "To resume training at `epoch` directly."

    def get_state(self, minimal:bool=True):
        "Return the inner state of the `Callback`, `minimal` or not."
        to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
        if minimal: to_remove += getattr(self, 'not_min', []).copy()
        return {k:v for k,v in self.__dict__.items() if k not in to_remove}

    def __repr__(self):
        attrs = func_args(self.__init__)
        to_remove = getattr(self, 'exclude', [])
        list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]
        return '\n'.join(list_repr)

This finally looks familiar. This is what Jeremy showed us during lesson 9. So it is a few layers of inheritance away, but it uses this as its base.

So what I now think I sort of slightly understand is that each method in the base Callback class is called automatically by FastAI’s callback system at a particular point in training. So the only thing you need to understand from all of this is when each of them gets called, and the Callback class docstrings tell you that:

Here they are in a more compact form (taken from the Callback class shown above):
on_train_begin - To initialize constants in the callback.
on_epoch_begin - At the beginning of each epoch.
on_batch_begin - Set HP before the output and loss are computed.
on_loss_begin - Called after forward pass but before loss has been computed.
on_backward_begin - Called after the forward pass and the loss has been computed, but before backprop.
on_backward_end - Called after backprop but before optimizer step. Useful for true weight decay in AdamW.
on_step_end - Called after the step of the optimizer but before the gradients are zeroed.
on_batch_end - Called at the end of the batch.
on_epoch_end - Called at the end of an epoch.
on_train_end - Useful for cleaning up things and saving files/models.
jump_to_epoch - To resume training at `epoch` directly.
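To make that order concrete, here is a toy training loop. It is not fastai’s `fit()`: the model, loss and optimizer are faked with plain numbers, and `fit`, `fire` and `EventLogger` are hypothetical names. It fires the hooks above at the points their docstrings describe, and the `EventLogger` callback simply records every hook it receives. `jump_to_epoch` is left out since it only matters when resuming training.

```python
class EventLogger:
    "Hypothetical callback that records the name of every hook it receives, in order."
    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Answer any on_* hook (normal lookup failed) with a function that logs its name.
        if not name.startswith("on_"):
            raise AttributeError(name)
        return lambda **kwargs: self.events.append(name)

def fit(epochs, batches, callbacks):
    "Toy loop showing where each hook fires; the 'training' is just arithmetic."
    def fire(event, **state):
        for cb in callbacks:
            getattr(cb, event)(**state)

    fire("on_train_begin")
    for epoch in range(epochs):
        fire("on_epoch_begin", epoch=epoch)
        for xb in batches:
            fire("on_batch_begin", xb=xb)    # e.g. set hyperparameters here
            out = xb * 2                      # stand-in for the forward pass
            fire("on_loss_begin", out=out)
            loss = out ** 2                   # stand-in for the loss function
            fire("on_backward_begin", loss=loss)
            # loss.backward() would go here
            fire("on_backward_end")
            # opt.step() would go here
            fire("on_step_end")
            # opt.zero_grad() would go here
            fire("on_batch_end", loss=loss)
        fire("on_epoch_end", epoch=epoch)
    fire("on_train_end")

logger = EventLogger()
fit(epochs=1, batches=[1.0], callbacks=[logger])
print(logger.events)
```

With one epoch and one batch, the printed list is exactly the hooks above in order, from `on_train_begin` to `on_train_end`.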

Right - or just read the fit() and loss_batch() source code, which is what I do when implementing a callback. (Well, actually I normally just copy another callback that does something vaguely similar, and modify it.)

And to complete the picture, here’s the lr_find method, which is what you’d generally use in practice:


Ok, I am finally understanding (slightly). These are the points where you can inject your own code into the training process, and that is what makes it so flexible: small modifications don’t require touching the training loop itself. So if a paper says you need to increase the learning rate by 10% after every epoch, you can just create a new callback class with an on_epoch_end method that takes the learning rate and multiplies it by 1.1.
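A minimal sketch of that idea, assuming a toy loop that keeps the learning rate in a plain dict rather than fastai’s real optimizer wrapper. `IncreaseLR`, `fit` and the `state` dict are hypothetical names, not fastai API.

```python
class Callback:
    "Base class: every hook is a no-op, so a subclass overrides only what it needs."
    def on_train_begin(self, state): pass
    def on_epoch_end(self, state): pass

class IncreaseLR(Callback):
    "Hypothetical callback: multiply the learning rate by `factor` after every epoch."
    def __init__(self, factor=1.1):
        self.factor = factor

    def on_epoch_end(self, state):
        state["lr"] *= self.factor

def fit(epochs, state, callbacks):
    "Toy loop that fires only the two hooks this example needs."
    for cb in callbacks:
        cb.on_train_begin(state)
    for epoch in range(epochs):
        state["lrs_used"].append(state["lr"])  # ... train one epoch at this lr ...
        for cb in callbacks:
            cb.on_epoch_end(state)
    return state

state = fit(3, {"lr": 0.1, "lrs_used": []}, [IncreaseLR()])
print(state["lrs_used"])  # lr used in each epoch: roughly 0.1, 0.11, 0.121
```

The training loop never changes; the paper’s tweak lives entirely in the callback, and `IncreaseLR` inherits the no-op `on_train_begin` because it has nothing to do at that point.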


Hi, you should look at Sylvain’s talk, An Infinitely Customizable Training Loop (from the NYC PyTorch meetup), and the slides that go with it.

Callbacks**, on their own, are quite easy to grasp if you have some coding background. What I think is worth noting here is why we need them, and the talk above explains that. Unfortunately, the audio quality is really bad :frowning:

** Meant as the general concept in computer programming.


Right - spend some time looking at the ParamScheduler code, and step thru it in a debugger, to see this happening.
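Before stepping through the real ParamScheduler, it may help to see the idea stripped down. This is a hypothetical sketch, not the notebook’s code: before every batch it sets a hyperparameter in each optimizer param group from a schedule function of how far through training we are. PyTorch optimizers do keep their param groups as a list of dicts in `optimizer.param_groups`; here a plain list of dicts stands in for it, and `sched_cos` is a cosine schedule written just for this example.

```python
import math

def sched_cos(start, end):
    "Return a function of pos in [0, 1] that cosine-anneals from `start` to `end`."
    return lambda pos: start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

class ParamScheduler:
    "Sketch: set `pname` in every param group from `sched_func(pos)` before each batch."
    def __init__(self, pname, sched_func):
        self.pname, self.sched_func = pname, sched_func

    def on_batch_begin(self, param_groups, pos):
        for pg in param_groups:
            pg[self.pname] = self.sched_func(pos)

def fit(n_batches, param_groups, callbacks):
    "Toy loop: returns the lr each batch would have trained with."
    used = []
    for i in range(n_batches):
        for cb in callbacks:
            cb.on_batch_begin(param_groups=param_groups, pos=i / n_batches)
        used.append(param_groups[0]["lr"])  # forward/backward/step would use this lr
    return used

param_groups = [{"lr": 0.0}, {"lr": 0.0}]  # stand-in for optimizer.param_groups
print(fit(4, param_groups, [ParamScheduler("lr", sched_cos(1.0, 0.0))]))
```

Swapping the schedule is just a matter of passing a different function, which is the flexibility the callback design buys you.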

I remember when I first learned about callbacks it took me a long time to understand. I’m not sure I’d call them “quite easy” when you first come across them. The idea of passing a function as a parameter can be quite confusing until you get the hang of it.
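For anyone still getting the hang of it, here is the bare idea in Python: a function handed to another function as an ordinary argument, to be called later. `train` and `remember` are made-up names for illustration.

```python
def train(n_epochs, on_epoch_end=None):
    "Toy 'training' loop that accepts an optional function to call after each epoch."
    for epoch in range(n_epochs):
        # ... the real work for this epoch would happen here ...
        if on_epoch_end is not None:
            on_epoch_end(epoch)  # call back into the caller's code

seen = []

def remember(epoch):
    seen.append(epoch)

train(3, on_epoch_end=remember)  # pass the function itself: `remember`, not `remember()`
print(seen)  # [0, 1, 2]

# A lambda works too, for one-off callbacks.
train(2, on_epoch_end=lambda epoch: print(f"finished epoch {epoch}"))
```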


This might be a bit unrelated here.
To get a broader understanding of patterns, I would start here:

These are the most common programming/design patterns, and they recur across most programming languages. If we go through them and understand at a high level how they work, most code becomes much more intuitive to read and understand.

Here is an idea for a minor additional functionality for the Runner: give it the option to do a dry run (without training) and only output the names of the callbacks in the order they would be run after each type of event that would trigger them. This can be useful for testing simple dependencies and debugging complex training workflows. (Maybe it’s more appropriate for the Marathon Runner, rather than the typical Runner).
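A rough sketch of how such a dry run could work, assuming the Runner keeps a list of callbacks, sorts them by an `_order` attribute (defaulting to 0) and calls the method whose name matches each event, which is how the lesson 9 Runner appears to work. `dry_run` is a hypothetical helper, and the callback classes and event names below are toy stand-ins; substitute whatever events your Runner actually fires.

```python
def dry_run(callbacks, events):
    """Hypothetical helper: for each event, list the callbacks that would run, in order,
    without training anything. Assumes callbacks are sorted by an `_order` attribute
    (default 0) and matched to an event by method name."""
    ordered = sorted(callbacks, key=lambda cb: getattr(cb, "_order", 0))  # stable sort
    return {event: [type(cb).__name__ for cb in ordered if callable(getattr(cb, event, None))]
            for event in events}

# Toy stand-in callbacks; only the method names and `_order` matter here.
class TrainEvalCB:
    _order = 0
    def begin_epoch(self): ...
    def begin_validate(self): ...

class RecorderCB:
    _order = 1
    def begin_epoch(self): ...
    def after_batch(self): ...

class SchedulerCB:
    _order = 1
    def begin_batch(self): ...

plan = dry_run([RecorderCB(), SchedulerCB(), TrainEvalCB()],
               ["begin_epoch", "begin_batch", "after_batch", "begin_validate"])
for event, names in plan.items():
    print(f"{event}: {names}")
```

Because Python’s `sorted` is stable, callbacks with the same `_order` keep the order they were passed in, which is exactly the kind of dependency a dry run would help you spot.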


Can someone edit in a link to the latest Runner class? Or is the latest one the one in the lesson 9 notebook?


I didn’t mean to be impolite or nasty towards anybody. What I meant was that if we abstract away the implementation details, the concept of a callback is not very hard: it’s just a new piece inserted into the program’s execution flow. But from there to mastering it, I agree, is a long way.

I liked your “if you do something a lot of times, make it small”. It’s a nice extension/analogy of the concept of entropy in information theory :wink:

There was a time in my life I invested quite heavily in studying design patterns - and personally I found that time entirely unhelpful and actually rather confusing.

I’m sure some people find them useful. But just mentioning my experience so folks don’t feel like this is a topic that everyone must master! :slight_smile:


When I first read the design patterns book I obsessed over it. I think it should mostly be used as a quick read to get an idea of various ways to abstract concepts; once you have coded for a while, I think you pick up design patterns as you need them.

The real reason to read about design patterns is to have a quick and easy way to describe coding concepts, pass interviews, and make up interview questions to stump interviewees.

I find that setting up a project with automated tools to check the code is one of the most important things. Code Climate-like tools, linting (if it’s a decent implementation), hinting, automated tests, and working with other people are the best ways to improve.

If you can’t easily write an automated test for your code, it’s probably over-complicated. Hinting gives hints on how to improve code. Linting helps you become consistent and notice the style of the project you’re working on.

Idk, I feel like there are so many automated tools in the coding world that will get you off the ground; I think it’s better to just install them and learn as they correct you.


I had asked my friends to help me understand callbacks better.
@suvash was kind enough to share an explanation that I found really helpful
(shared here word for word, before seeking permission):

They’re implemented in various ways in different languages/frameworks, but the core idea is usually the same.

Say there is some code (X) that runs and goes through a bunch of ‘stages’, the different stages usually being some sort of internal state update behind the scenes that X doesn’t usually expose (because it’s internal stuff, etc.).

Now, it’s usually the case that whenever these stage changes happen (in X), one would like to run some ad-hoc (read: arbitrary) piece of other code (Y). To be able to do this, X is initially called with a reference to Y (e.g. by passing classes/objects/functions as an argument). To tie these two things together, the function/method call is agreed upon, i.e. the code in Y implements the function/method that will later be called by X.

Now, whenever the code in X goes through these stages, it can execute the code in Y, provided the names match up. In fact, this now looks like the code in X is “calling back” the code in Y that you passed in earlier when starting X.

Why make it this complicated, you might ask? This idea lets you hook into the state changes of X and run ad-hoc code Y without knowing or rewriting the internals of X, thus decoupling them.

At the end of the day, the ‘callback’ idea makes it possible to call back any code Y that you provide to X before it starts running, without having to know or rewrite X. It’s a Very Good Idea when building extensible libraries, state machines that let you run something when the state changes, etc. Each language/library/framework adds its own twist to it, but this is really the core idea.

Let’s hope this confusion of X and Y didn’t make it messier for you.
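To tie the X/Y description back to code, here is a minimal Python sketch of it. `ProcessX`, `LoggerY` and the stage names are made up for illustration; the agreed-upon convention in this sketch is simply that Y’s methods are named `on_<stage>`, similar in spirit to the `on_...` method names in the Callback class above.

```python
class ProcessX:
    "Plays the role of code X: it goes through internal stages it doesn't expose."
    STAGES = ["connect", "transfer", "close"]

    def __init__(self, listener):
        self.listener = listener  # the reference to code Y, handed over before X starts

    def run(self):
        for stage in self.STAGES:
            # ... X's internal work for this stage would happen here ...
            hook = getattr(self.listener, "on_" + stage, None)  # the agreed-upon name
            if hook is not None:
                hook()  # X "calls back" into Y

class LoggerY:
    "Plays the role of code Y: it only implements the stages it cares about."
    def __init__(self):
        self.log = []

    def on_connect(self):
        self.log.append("connected")

    def on_close(self):
        self.log.append("closed")

y = LoggerY()
ProcessX(y).run()
print(y.log)  # ['connected', 'closed']
```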