New optimizer, training loop and callbacks

To continue documenting our development of the library, here is where we’re at in terms of the optimizer, training loop and callbacks (all in notebook 004).

Optimizer and easy hyper-parameters access

Using a very good idea that I stole from @mcskinner, the optimizer is now wrapped inside a class called HPOptimizer. The four most commonly used hyper-parameters are properties of this HPOptimizer named lr, mom, wd and beta, and each has its own custom setter that puts the value in the right place in the internal optimizer’s param dictionary.

  • lr stands for learning rate and is the same in every optimizer
  • wd stands for weight decay and is also the same in every optimizer. It implements L2 regularization, not true weight decay (see here for more details on the difference).
  • mom stands for momentum, which is the momentum in SGD and RMSProp, and the first element in the betas tuple in Adam.
  • beta is the second element of the betas tuple in Adam, or the alpha in RMSProp.
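To make the L2 vs. true weight decay difference concrete, here is the standard way of writing the two for a weight $w$, learning rate $\eta$ (lr), weight decay factor $\lambda$ (wd) and gradient $g = \nabla_w L$. The notation $u(\cdot)$ for the optimizer’s update direction is introduced only for this illustration: $u(g) = g$ for plain SGD, while for Adam $u$ divides a running average of gradients by the square root of a running average of squared gradients.

```latex
\text{L2 regularization:}\quad w \leftarrow w - \eta\, u\!\left(g + \lambda w\right)
\qquad\qquad
\text{true weight decay:}\quad w \leftarrow w - \eta\, u(g) - \eta \lambda w
```

With plain SGD both lines expand to $w - \eta g - \eta \lambda w$, so they coincide. With Adam, the $\lambda w$ term in the L2 version goes through $u$ and gets rescaled per parameter, which is why the two behave differently.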

For the end user, you create this HPOptimizer by passing the model parameters, an opt_fn (like in fastai) and an initial lr, like this:

opt = HPOptimizer(model.parameters(), opt_fn, init_lr)

Then you can access or change any of those hyper-parameters by simply typing something like opt.lr = 1e-2, which is far more convenient than in fastai.
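As an illustration of how those setters could map onto the internal param dictionaries, here is a minimal sketch. It is not the notebook’s code: it assumes the torch.optim convention, where an optimizer exposes a `param_groups` list of dicts with keys such as `'lr'`, `'weight_decay'`, `'momentum'` (SGD, RMSprop), `'betas'` (Adam) and `'alpha'` (RMSprop). `FakeAdam` is a hypothetical stand-in so the snippet runs without PyTorch.

```python
class HPOptimizer:
    """Sketch: wrap an optimizer and expose lr/mom/wd/beta as properties.

    Assumes torch.optim-style `param_groups` (a list of dicts); every setter
    writes the new value into all the groups.
    """
    def __init__(self, params, opt_fn, init_lr):
        self.opt = opt_fn(params, init_lr)

    # Delegate the usual optimizer calls to the wrapped optimizer.
    def step(self): self.opt.step()
    def zero_grad(self): self.opt.zero_grad()

    def _set(self, key, val):
        for pg in self.opt.param_groups: pg[key] = val

    @property
    def lr(self): return self.opt.param_groups[0]['lr']
    @lr.setter
    def lr(self, val): self._set('lr', val)

    @property
    def wd(self): return self.opt.param_groups[0]['weight_decay']
    @wd.setter
    def wd(self, val): self._set('weight_decay', val)

    @property
    def mom(self):
        # Adam: first element of betas; SGD/RMSprop: the momentum key
        pg = self.opt.param_groups[0]
        return pg['betas'][0] if 'betas' in pg else pg['momentum']
    @mom.setter
    def mom(self, val):
        for pg in self.opt.param_groups:
            if 'betas' in pg: pg['betas'] = (val, pg['betas'][1])
            else: pg['momentum'] = val

    @property
    def beta(self):
        # Adam: second element of betas; RMSprop: alpha (None if absent)
        pg = self.opt.param_groups[0]
        return pg['betas'][1] if 'betas' in pg else pg.get('alpha')
    @beta.setter
    def beta(self, val):
        for pg in self.opt.param_groups:
            if 'betas' in pg: pg['betas'] = (pg['betas'][0], val)
            elif 'alpha' in pg: pg['alpha'] = val


class FakeAdam:
    """Hypothetical stand-in with Adam-style param_groups (no PyTorch needed)."""
    def __init__(self, params, lr):
        self.param_groups = [{'params': list(params), 'lr': lr,
                              'betas': (0.9, 0.999), 'weight_decay': 0.0}]
    def step(self): pass
    def zero_grad(self): pass


opt = HPOptimizer([], FakeAdam, 1e-3)
opt.lr = 1e-2     # writes param_groups[i]['lr']
opt.mom = 0.95    # writes the first element of 'betas'
opt.beta = 0.99   # writes the second element of 'betas'
opt.wd = 1e-4     # writes 'weight_decay'
```

With PyTorch installed, something like `HPOptimizer(model.parameters(), torch.optim.Adam, 1e-3)` should behave the same way, since the torch.optim optimizers take the learning rate as their second argument.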

This doesn’t support differential learning rates yet, but that will be an easy change once we have decided how to represent the different groups of layers: we’ll just have to adapt the setter for lr (and wd if we want) to handle it.

The training loop

In the current fastai library, the training loop has progressively become a huge mess. This is because each time someone added something new to the library that affects training, they added a few more lines (or extra arguments) to the training loop. So when beginners now look at the core of training, they can’t see the important steps. On the other hand, each of those additions (like the training API, half-precision training, SWA or stuff specific to the LMs) has bits of code in several places, which also makes them difficult to understand.

For fastai_v1, we have decided to be more pedagogical, and the training loop won’t change from its version in notebook 004:

import numpy as np
import torch
from tqdm import tnrange

def fit(epochs, model, loss_fn, opt, data, callbacks=None, metrics=None):
    # CallbackHandler (defined in notebook 004) dispatches each event to every callback
    cb_handler = CallbackHandler(callbacks)
    for epoch in tnrange(epochs):
        model.train()
        for xb,yb in data.train_dl:
            xb, yb = cb_handler.on_batch_begin(xb, yb)
            loss,_ = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)
            if cb_handler.on_batch_end(loss): break
        if hasattr(data,'valid_dl') and data.valid_dl is not None:
            model.eval()
            with torch.no_grad():
                # loss_batch returns (loss, *metrics, batch_size) for each batch
                *val_metrics,nums = zip(*[loss_batch(model, xb, yb, loss_fn, metrics=metrics)
                                for xb,yb in data.valid_dl])
            # average each value over the batches, weighted by batch size
            val_metrics = [np.sum(np.multiply(val,nums)) / np.sum(nums) for val in val_metrics]
        else: val_metrics=None
        if cb_handler.on_epoch_end(val_metrics): break

def loss_batch(model, xb, yb, loss_fn, opt=None, cb_handler=None, metrics=None):
    out = model(xb)
    if cb_handler is not None: out = cb_handler.on_loss_begin(out)
    loss = loss_fn(out, yb)
    mets = [f(out,yb).item() for f in metrics] if metrics is not None else []
    if opt is not None:
        # training only: backward pass, optimizer step, then reset the gradients
        if cb_handler is not None: loss = cb_handler.on_backward_begin(loss)
        loss.backward()
        if cb_handler is not None: cb_handler.on_backward_end()
        opt.step()
        if cb_handler is not None: cb_handler.on_step_end()
        opt.zero_grad()
    return (loss.item(),) + tuple(mets) + (len(xb),)

If you skip all those annoying calls to this cb_handler object, you have the simple training loop:

  1. range over epochs
  2. go through the training batches xb, yb
  3. call the model on xb
  4. compute the loss between the output and the target yb
  5. compute the gradients (loss.backward())
  6. do the optimizer step
  7. zero the gradient for the next step
  8. (optional) compute the loss on the validation set and some metrics
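To see those eight steps without the callbacks (or PyTorch), here is a dependency-free sketch on a toy problem: a single weight `w` fitted to `y = 3x` with a mean squared error loss. The gradient is written by hand where PyTorch would call `loss.backward()`, and `plain_fit` and `batches` are hypothetical helper names.

```python
def plain_fit(epochs, w, train_dl, valid_dl, lr):
    """Sketch: the eight steps for a one-weight model (out = w * x) with an MSE loss."""
    grad, val_loss = 0.0, None  # as in PyTorch, the gradient accumulates until zeroed
    for epoch in range(epochs):                                          # 1. range over epochs
        for xb, yb in train_dl:                                          # 2. training batches
            out = [w * x for x in xb]                                    # 3. call the model on xb
            loss = sum((o - y) ** 2 for o, y in zip(out, yb)) / len(xb)  # 4. compute the loss
            # 5. gradient of the loss with respect to w (what loss.backward() would compute)
            grad += sum(2 * (o - y) * x for o, y, x in zip(out, yb, xb)) / len(xb)
            w -= lr * grad                                               # 6. optimizer step
            grad = 0.0                                                   # 7. zero the gradient
        # 8. validation loss, averaged over all validation samples
        sq_err = sum((w * x - y) ** 2 for xb, yb in valid_dl for x, y in zip(xb, yb))
        val_loss = sq_err / sum(len(xb) for xb, _ in valid_dl)
    return w, val_loss


def batches(xs, ys, bs):
    """Split two lists into (xb, yb) mini-batches of size bs."""
    return [(xs[i:i + bs], ys[i:i + bs]) for i in range(0, len(xs), bs)]


xs = [i / 10 for i in range(20)]
train_dl = batches(xs, [3 * x for x in xs], bs=4)
xv = [0.05 + i / 10 for i in range(10)]
valid_dl = batches(xv, [3 * x for x in xv], bs=5)
w, val_loss = plain_fit(epochs=20, w=0.0, train_dl=train_dl, valid_dl=valid_dl, lr=0.1)
```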

So why add all those lines with cb_handler? Well, since we don’t want to change the training loop, so that it stays clear and simple, we have to code everything that affects it (or uses information from inside it) elsewhere. That’s why there is a callback call between every line of the fit function: it lets us code what should happen between those lines in a separate object.
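The mechanism behind those calls can be sketched in a few lines. This is a simplified, hypothetical handler (not the one in notebook 004): it forwards each event to every callback in order, lets `on_batch_begin` replace the batch, and reports a stop request if any callback returns `True` from `on_batch_end` or `on_epoch_end`. `DoubleInputs` and `StopAfter` are made-up example callbacks.

```python
class Callback:
    """Base class: default hooks change nothing and never ask to stop."""
    def on_batch_begin(self, xb, yb): return xb, yb
    def on_batch_end(self, loss): return False
    def on_epoch_end(self, val_metrics): return False


class CallbackHandler:
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def on_batch_begin(self, xb, yb):
        # each callback sees (and may replace) the batch returned by the previous one
        for cb in self.callbacks: xb, yb = cb.on_batch_begin(xb, yb)
        return xb, yb

    def on_batch_end(self, loss):
        # a list, not a generator, so every callback runs even if one asks to stop
        return any([cb.on_batch_end(loss) for cb in self.callbacks])

    def on_epoch_end(self, val_metrics):
        return any([cb.on_epoch_end(val_metrics) for cb in self.callbacks])


class DoubleInputs(Callback):
    def on_batch_begin(self, xb, yb): return [2 * x for x in xb], yb


class StopAfter(Callback):
    def __init__(self, n): self.n, self.seen = n, 0
    def on_batch_end(self, loss):
        self.seen += 1
        return self.seen >= self.n  # True interrupts the epoch


handler = CallbackHandler([DoubleInputs(), StopAfter(2)])
seen_batches = []
for xb, yb in [([1, 2], [0, 0]), ([3, 4], [0, 0]), ([5, 6], [0, 0])]:
    xb, yb = handler.on_batch_begin(xb, yb)
    seen_batches.append(xb)
    if handler.on_batch_end(loss=0.0): break
```

The loop only ever sees the doubled inputs, and it stops after two batches without the loop itself knowing why.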


The guideline for adding new functionality related to training in the fastai_v1 library will be to implement it entirely in a Callback. This will also help clarity, since all the code for a given feature will be in one place. To make that possible, callbacks have access to everything that happens inside the training loop, can interrupt an epoch or the training at any time, and can modify everything (even the data or the optimizer). Just look at the example called EyeOfSauron to see how.

As examples for now, you can see how the LRFinder is completely implemented inside a callback (it’s just missing saving the model at the beginning and loading it at the end, but that’s an easy addition once those functions exist), as is a 1cycle schedule.
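A learning rate finder fits naturally into this shape: it only needs to reset the lr when training begins, then record the loss and raise the lr after every batch. The sketch below is not the notebook’s LRFinder; it assumes an exponential schedule from `start_lr` to `end_lr` over `num_it` batches and an arbitrary stopping rule (loss above four times the best loss seen), and it only needs an `opt` object with a settable `lr`, like the HPOptimizer described earlier. `fake_loss` is a made-up loss curve for the demo.

```python
from types import SimpleNamespace


class LRFinder:
    """Sketch of an LR range test written as a callback (hypothetical)."""
    def __init__(self, opt, start_lr=1e-7, end_lr=10, num_it=100):
        self.opt, self.start_lr, self.end_lr, self.num_it = opt, start_lr, end_lr, num_it

    def on_train_begin(self, **kwargs):
        self.iteration, self.best_loss = 0, float('inf')
        self.lrs, self.losses = [], []
        self.opt.lr = self.start_lr

    def on_batch_end(self, loss, **kwargs):
        self.lrs.append(self.opt.lr)
        self.losses.append(loss)
        self.best_loss = min(self.best_loss, loss)
        self.iteration += 1
        # stop (return True) when the loss blows up or the whole range has been swept
        if loss > 4 * self.best_loss or self.iteration >= self.num_it:
            return True
        # exponential schedule: lr = start_lr * (end_lr / start_lr) ** (i / num_it)
        ratio = self.end_lr / self.start_lr
        self.opt.lr = self.start_lr * ratio ** (self.iteration / self.num_it)
        return False


def fake_loss(lr):
    # made-up loss curve: flat until lr reaches 0.5, then it diverges
    return 1.0 if lr < 0.5 else 10.0


opt = SimpleNamespace(lr=None)  # stand-in for an optimizer wrapper with a settable lr
finder = LRFinder(opt, start_lr=1e-5, end_lr=10, num_it=60)
finder.on_train_begin()
for _ in range(1000):
    if finder.on_batch_end(fake_loss(opt.lr)): break
```

Here the lr grows by a factor of 10 every 10 batches, so the run stops at the first lr above 0.5, and `finder.lrs`/`finder.losses` hold the curve you would plot.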


Do you think we’ll be able to incorporate true weight decay into the HPOptimizer directly, so that stuff like wd annealing works with either Adam or AdamW without duplicate code? Or do you have some other plan for AdamW?

@jeremy you could probably put a toggle in the HPOptimizer constructor, and then do a simple dispatch under the hood in the getters and setters.

For both of you: what are your thoughts on preventing bad interactions? Putting new features into callbacks instead of the training loop itself makes a lot of sense for modularity, but it also hides a lot of potential bugs in order-dependent interactions.

Those are more likely now that the callbacks can manipulate so much of the training loop, e.g. trying to do multiple things to xb and yb for each batch. Ordering becomes important, and might not even be the same across callback hooks. If we modify the data twice in on_batch_begin, it seems plausible that we’d want to modify the output in on_loss_begin in the reverse order, to unwind those transforms before the loss calculation.
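One way to get that reverse-order unwinding is for the handler to call `on_batch_begin` in the callbacks’ order and `on_loss_begin` in reverse order, the same way nested context managers exit in reverse. The sketch below is hypothetical (`AddOne`, `Double` and `OrderedHandler` are illustration names, and the "model" is the identity) and shows two input transforms being undone correctly when unwound last-in, first-out.

```python
class AddOne:
    def on_batch_begin(self, xb, yb): return [x + 1 for x in xb], yb
    def on_loss_begin(self, out): return [o - 1 for o in out]   # undoes +1


class Double:
    def on_batch_begin(self, xb, yb): return [x * 2 for x in xb], yb
    def on_loss_begin(self, out): return [o / 2 for o in out]   # undoes *2


class OrderedHandler:
    def __init__(self, callbacks): self.callbacks = list(callbacks)

    def on_batch_begin(self, xb, yb):
        # apply the transforms in registration order
        for cb in self.callbacks: xb, yb = cb.on_batch_begin(xb, yb)
        return xb, yb

    def on_loss_begin(self, out):
        # unwind in reverse order: last transform applied is undone first
        for cb in reversed(self.callbacks): out = cb.on_loss_begin(out)
        return out


handler = OrderedHandler([AddOne(), Double()])
xb, yb = handler.on_batch_begin([1.0, 2.0, 3.0], [0, 0, 0])  # (x + 1) * 2
out = xb                                   # identity "model": output = transformed input
restored = handler.on_loss_begin(out)      # (out / 2) - 1
```

Unwinding in the forward order instead would compute `(out - 1) / 2` and give `[1.5, 2.5, 3.5]`, which is exactly the kind of silent bug the question is about.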

Yes, the weight decay update can be done inside the optimizer, since it has the model’s parameters. That’s probably best anyway since, once mixed-precision training comes along, the weight decay step is probably something we’ll want to do in FP32.

I’m not too concerned about bad interactions, as there are very few callbacks that will actually need to change the input/output/loss. The only two I can think of are half-precision training and regularization functions for RNNs. In any case, we will be very careful when accepting PRs with new callbacks that make changes like this, and yes, order will obviously matter, especially with the half-precision training callback coming along.

Remember, though, that the end user won’t see the callbacks much, and the functions of the fastai library will put them in the order the user needs (the only situation where I see this becoming a problem is half-precision training mixed with RNNs, which isn’t working properly in fastai now anyway).

Interesting, you just read my thoughts…almost.
My more general solution for a PyTorch training loop:
and a simple example:

So far, it works quite nicely in my everyday tasks.

For the on_epoch_begin and on_epoch_end methods, what do you think about passing the epoch number as the first argument? Sure, you can implement something in the callback itself to increment every time one of those methods is called, but if you have a couple of callbacks that want that functionality, and you already have the epoch variable in the loop, it might be helpful to just pass it in. The Keras callback API does it that way.


From my experience, the best solution is to pass the callbacks a frozen key-value store (no new keys during the training loop) containing all the necessary statistics.

Good point, and so is @scitator’s response. I will play around with having the CallbackHandler keep track of all the statistics and pass them as **kwargs to the callbacks; that way any callback can unpack the bits it needs (epoch, batch number, iteration number, last loss, last output, last input, last target…) whenever required (and not necessarily just at on_epoch_begin, if we’re talking about the epoch number).
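A minimal sketch of that design, assuming the handler keeps a dict of statistics and passes all of it as keyword arguments to every hook. The key names (`epoch`, `iteration`, `num_batch`, `last_loss`) and the class names are assumptions for the illustration, and callbacks that don’t define a hook are simply skipped. Each callback names the arguments it needs and absorbs the rest with `**kwargs`.

```python
class StatefulHandler:
    """Sketch: track statistics in state_dict and pass them all to each hook."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.state_dict = {'epoch': 0, 'iteration': 0, 'num_batch': 0, 'last_loss': None}

    def _call(self, hook):
        # call the hook on every callback that defines it; stop if any returns True
        return any([getattr(cb, hook)(**self.state_dict)
                    for cb in self.callbacks if hasattr(cb, hook)])

    def on_epoch_begin(self):
        self.state_dict['num_batch'] = 0
        return self._call('on_epoch_begin')

    def on_batch_end(self, loss):
        self.state_dict['last_loss'] = loss
        self.state_dict['iteration'] += 1
        self.state_dict['num_batch'] += 1
        return self._call('on_batch_end')

    def on_epoch_end(self):
        stop = self._call('on_epoch_end')
        self.state_dict['epoch'] += 1
        return stop


class LossLogger:
    """Only asks for what it needs; num_batch ends up in **kwargs."""
    def __init__(self): self.log = []
    def on_batch_end(self, epoch, iteration, last_loss, **kwargs):
        self.log.append((epoch, iteration, last_loss))


logger = LossLogger()
handler = StatefulHandler([logger])
for epoch in range(2):
    handler.on_epoch_begin()
    for loss in [0.5, 0.4]:
        handler.on_batch_end(loss)
    handler.on_epoch_end()
```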


What about unit tests? Are you considering them?

Pushed a new version that allows you to access anything stored in the CallbackHandler state_dict, as long as you know its name. Hope that helps!

I love the callbacks and I don’t know if any of these ideas are any good but I’ll throw them out there. My apologies if these are “out of left field”.

So one thing I didn’t like about the prior training loop was that when it got to the validation phase the progress bar would disappear and I was in the dark about the progress on those batches. This was probably a bigger deal for me since my computer isn’t the fastest. Anyway, it still seems like it would be nice to know the progress of the validation phase.

Also, it seems to me (and I could be wrong) that going through the entire validation phase is a waste of time if the accuracy, or whatever metric you are using, isn’t close to what you are trying to obtain. So it would be nice to be able to stop the validation phase prematurely, just like we now can with the training phase. For example, if you are trying to obtain 98% or higher accuracy but after a few validation batches you are at 88%, you probably aren’t going to get to 98% during this validation phase, and the time would be better spent moving on to the next training epoch.

Lastly, in the spirit of keeping the training loop as simple as possible, it doesn’t seem right to put metrics like accuracy into the training loop; they should be handled by callbacks just like everything else.

Here are my suggested changes:

  1. Move the on_batch_begin and on_batch_end calls into the loss_batch function so that they are called in both the training and validation phases.
  2. Allow the validation phase to be stopped just like the training phase, via the return value of on_batch_end.
  3. Move the metrics into their own callbacks so that they can calculate metrics and also stop the validation phase early if desired. Validation loss would stay in the fit function.
  4. Add an on_valid_begin to signal the beginning of the validation phase, and a batch_type variable in the state_dict, which is either ‘train’ or ‘valid’, to let callbacks know whether a batch is a training or a validation batch.
  5. Perhaps add a MetricCallback class which extends Callback and provides name() and metric() methods, to be implemented by all metric callbacks. Then a callback placed after all the others could gather all metric names and values for display/reporting purposes.
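Point 5 could look roughly like the sketch below. Everything here is hypothetical: `MetricCallback`, `Accuracy`, `MetricReporter` and the `on_valid_batch_end(preds, targets)` hook signature are illustration names, not library code. Each metric accumulates over the validation batches, and a reporter placed last collects the `name()`/`metric()` pairs.

```python
class MetricCallback:
    """Interface every metric callback implements (hypothetical)."""
    def name(self): raise NotImplementedError
    def metric(self): raise NotImplementedError


class Accuracy(MetricCallback):
    def __init__(self): self.correct = self.total = 0
    def on_epoch_begin(self, **kwargs): self.correct = self.total = 0
    def on_valid_batch_end(self, preds, targets, **kwargs):
        # count matching predictions in this validation batch
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)
    def name(self): return 'accuracy'
    def metric(self): return self.correct / self.total if self.total else None


class MetricReporter:
    """Placed after all other callbacks: gathers name/value pairs for display."""
    def __init__(self, callbacks):
        self.metric_cbs = [cb for cb in callbacks if isinstance(cb, MetricCallback)]
    def on_epoch_end(self, **kwargs):
        self.report = {cb.name(): cb.metric() for cb in self.metric_cbs}


acc = Accuracy()
acc.on_valid_batch_end(preds=[1, 0, 1], targets=[1, 1, 1])
acc.on_valid_batch_end(preds=[0], targets=[0])
reporter = MetricReporter([acc])
reporter.on_epoch_end()
```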

I’ve created a sample update of the 004 notebook with these proposed suggestions, which can be found at []. However, I haven’t been able to run these changes, as I’m having difficulty upgrading to Python 3.7 and there seems to be a dependency on dataclasses, which seems to be a 3.7 addition, so excuse me if the Python code isn’t 100% accurate. I’m still learning my way around Python (I’m a Java/C++/Swift dev).

Here are also some screenshots of the proposed ideas.

Hi Stephen, thanks a lot for all those ideas! Even if I’m going to argue against most of them, we share the developing process with the community to get your advice and see how to make things better, so please keep them coming.

  1. For the progress bar during the validation phase, Jeremy beat you to it already. The validation progress bar is now integrated in the dataloader, and every dataloader has its own, which appears automatically.

  2. I’m not really keen on interrupting the validation phase: if you do that, you won’t have an exact validation loss, and that’s really key to figuring out when you started overfitting, so you can adapt your next choice of hyperparameters. One thing that is doable is to skip validation for the first few epochs (the same way the Learner callback currently skips validation, by suppressing the validation dataloader).

  3. For the metrics inside the training loop, I debated a lot. The one thing that made me decide to put them inside the training loop is that they’re passed to the callbacks (like early stopping). So if we delegate them to a callback, we’ll need other callbacks to either subclass it or communicate with it, which seems a bit messy. Since it only adds a few lines, and is part of core training, I put it there.

  4. Having on_batch_begin and on_batch_end inside loss_batch is also something I tried. As you noticed, it requires an on_valid_begin to switch everything into validation mode (the switch back being in on_epoch_begin). What I didn’t like about that is that pretty much every Callback has an on_batch_begin/on_batch_end method, and it required an additional test in all of them, since 99% of the stuff we do is only for training. We can still have that, but in that case I’d keep the train/valid switch inside the CallbackHandler, which would call on_valid_batch_begin and on_valid_batch_end if we are inside the validation loop. Not sure those are entirely necessary though.
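Keeping the train/valid switch in the CallbackHandler, instead of in every callback, might look like this minimal sketch. The class and callback names are hypothetical; the point is that a training-only callback needs no phase test at all, because the handler simply never calls its `on_batch_end` during validation.

```python
class TrainValidHandler:
    """Sketch: the handler tracks the phase and routes batch events accordingly."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.in_valid = False

    def on_epoch_begin(self):
        self.in_valid = False   # back to training at the start of each epoch

    def on_valid_begin(self):
        self.in_valid = True    # every batch after this is a validation batch

    def on_batch_begin(self, xb, yb):
        hook = 'on_valid_batch_begin' if self.in_valid else 'on_batch_begin'
        for cb in self.callbacks:
            if hasattr(cb, hook): xb, yb = getattr(cb, hook)(xb, yb)
        return xb, yb

    def on_batch_end(self, loss):
        hook = 'on_valid_batch_end' if self.in_valid else 'on_batch_end'
        # a list, not a generator, so every callback is called even if one asks to stop
        return any([getattr(cb, hook)(loss) for cb in self.callbacks if hasattr(cb, hook)])


class CountTrainBatches:
    """Training-only callback: no phase test, no validation hooks."""
    def __init__(self): self.n = 0
    def on_batch_end(self, loss):
        self.n += 1
        return False


class ValidLosses:
    """Validation-only callback: records the loss of each validation batch."""
    def __init__(self): self.losses = []
    def on_valid_batch_end(self, loss):
        self.losses.append(loss)
        return False


counter, valid_log = CountTrainBatches(), ValidLosses()
handler = TrainValidHandler([counter, valid_log])
handler.on_epoch_begin()
for loss in [0.3, 0.2, 0.1]: handler.on_batch_end(loss)
handler.on_valid_begin()
for loss in [0.25, 0.15]: handler.on_batch_end(loss)
```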

Last thing! You can have dataclasses in Python 3.6: there is a backport library for it (called dataclasses).

Hope that explains my choices a bit. @jeremy I would love to have your thoughts on this too.


I figured these had most likely all been considered, but it’s great to know all the reasoning and how they are being handled in other ways, like the validation progress bar. Thanks for the detailed explanation!

Yup we’re looking to have good test coverage.



I played around a little bit with the new Callbacks and HoloViews, so here’s a Callback for live-updating curves of lr, losses, momentum and val_losses.

If interested:

# for jupyterlab: you have to install the bokeh and holoviews extensions first
import holoviews as hv

def plot(stats):
    return hv.Layout(
        [(hv.Curve(stat, label=name).redim.label(x='Iterations', y=name)).options(
            framewise=True, shared_axes=False) for name, stat in stats.items()]).cols(2)

class HVLiveUpdate(Callback):
    def __init__(self, learn: Learner):
        self.learn = learn
    def on_train_begin(self, **kwargs):
        self.stats_stream = hv.streams.Stream.define(
            'stats', stats={'lrs':[], 'losses':[], 'val_losses':[], 'moms':[]})()
        self.dmap = hv.DynamicMap(plot, streams=[self.stats_stream])
    def on_batch_end(self, last_loss, **kwargs):
        new_stats = {

I really like that new callback api, but now I have some thoughts/questions:

  1. I don’t really understand why we need an “epoch loop”. IMO this is less flexible than it could be. For example, what about live data coming in, or a single epoch that is just too big? Epochs don’t make sense then. Why not give every callback a frequency parameter (like calculating every n batches)?

  2. I really like the idea of an intelligent Dataloader/Sampler: if we store some information about which samples were classified wrongly, we could train more frequently on the tough ones (like AdaBoost). That would probably save time and increase performance (maybe it also helps with unbalanced classes?). Again, epochs don’t make too much sense here IMO.

  3. It would be nice if there were an easy way of doing stacking.

So these are just some thoughts; I’d like to hear some opinions.

Here are my thoughts:

  1. If you don’t like the epoch loop, you can forget it and just pass 1 as the epochs parameter :wink: Inside a callback, you can pop the iteration parameter from the kwargs to do something every n batches.

  2. This is already possible if you pass the learner object to the callback you write. It will then have the dataloaders, so you have a pipeline to communicate the samples classified wrongly. Careful though, as it seems the technique you describe would lead to massive overfitting (not a big deal for decision trees, but more annoying in deep learning).

  3. Stacking isn’t really linked to the training loop, it’s more in the design of the model.
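As a sketch of the first suggestion, a callback can read the iteration counter from its keyword arguments and act every `n` batches, regardless of epochs. This assumes the handler passes the counter under the name `iteration`, as in the reply above; `EveryN` is a hypothetical name.

```python
class EveryN:
    """Hypothetical callback: call `action(iteration)` every n batches."""
    def __init__(self, n, action):
        self.n, self.action = n, action

    def on_batch_end(self, **kwargs):
        iteration = kwargs.pop('iteration')  # the handler's running batch counter
        if iteration % self.n == 0:
            self.action(iteration)
        return False  # never asks to stop training


calls = []
cb = EveryN(3, calls.append)
for it in range(1, 11):  # ten batches, e.g. a single long "epoch"
    cb.on_batch_end(iteration=it, last_loss=0.0)
```

Since `**kwargs` is a fresh dict on every call, popping from it doesn’t affect what other callbacks receive.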

Thanks for your fast answer. So the picture I have in mind for a sampler is like this:
(from wiki)

The higher the box number, the lower the probability that the image gets loaded. If you get additional training data, put it in one of the boxes.
Sure, it could overfit, so one should be careful with the probability of “box 5”.
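The box scheme can be sketched as a small sampler. This is a hypothetical illustration, not fastai code: it assumes five boxes, sampling weights that halve from one box to the next, a correct prediction moving a sample up one box, a wrong one sending it back to box 1, and new data going into box 1. The weight of the last box (1/16 here) is the knob the post says to be careful with.

```python
import random


class BoxSampler:
    """Hypothetical sampler: each sample lives in a box 1..n_boxes, and samples
    in higher boxes are drawn less often."""
    def __init__(self, n_samples, n_boxes=5, seed=0):
        self.boxes = [1] * n_samples
        self.n_boxes = n_boxes
        self.rng = random.Random(seed)

    def weights(self):
        # weight halves with each box: box 1 -> 1, box 2 -> 0.5, ..., box 5 -> 1/16
        return [0.5 ** (b - 1) for b in self.boxes]

    def sample(self, k):
        # k indices drawn with replacement, proportionally to their box weight
        return self.rng.choices(range(len(self.boxes)), weights=self.weights(), k=k)

    def update(self, idxs, correct):
        # correct prediction: move up one box; wrong one: back to box 1
        for i, ok in zip(idxs, correct):
            self.boxes[i] = min(self.boxes[i] + 1, self.n_boxes) if ok else 1

    def add_samples(self, n):
        # new training data starts in box 1, so it is drawn often at first
        self.boxes.extend([1] * n)


sampler = BoxSampler(n_samples=4)
sampler.update([0, 1, 2], [True, True, False])  # boxes: [2, 2, 1, 1]
sampler.update([0], [True])                     # boxes: [3, 2, 1, 1]
batch = sampler.sample(8)                       # favours the samples in boxes 1 and 2
sampler.add_samples(2)                          # boxes: [3, 2, 1, 1, 1, 1]
```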

Anyway, it is nice that this is probably doable with what you’ve written.
So thanks again.

@sgugger, not sure if you are the right person to tell, but when running the 004_callbacks notebook, the CallbackHandler passes the number of epochs to on_train_begin as “n_epochs”, while the OneCycleScheduler names that parameter “epochs”, resulting in

TypeError: on_train_begin() missing 1 required positional argument: 'epochs'

Oopsie, forgot to change that there. Thanks for telling me!
