What is the v2 equivalent of "AccumulateScheduler"?

In v1, I subclassed AccumulateScheduler in order to implement gradient accumulation.

What would be the v2 approach to do the same?

Perhaps everything I need is available via callbacks, or gradient accumulation is already implemented somewhere in the framework? Either way, I'd definitely appreciate knowing which direction to go in doing this with the v2 bits.

@wgpubs this was done in a very recent commit :wink: (like 12 hours ago).

Notebook: https://github.com/fastai/fastai2/blob/master/nbs/17_callback.training.ipynb
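To give a feel for what that notebook's callback does (as described later in the thread, it counts *samples*, not batches), here is a tiny pure-Python mock of the idea. This is **not** the fastai2 code; `SampleCountAccumulator` and `should_step` are hypothetical names for illustration only.

```python
# Hypothetical mock (not fastai2 code): accumulate gradients until a target
# number of samples has been seen, then allow an optimizer step and reset.

class SampleCountAccumulator:
    def __init__(self, n_acc):
        self.n_acc, self.count = n_acc, 0

    def should_step(self, batch_size):
        "Call after backward; returns True when it is time to step."
        self.count += batch_size
        if self.count < self.n_acc:
            return False      # keep the gradients, skip the optimizer step
        self.count = 0        # time to step; reset the sample counter
        return True

acc = SampleCountAccumulator(n_acc=32)
steps = [acc.should_step(8) for _ in range(8)]  # 8 batches of 8 samples
print(steps)  # steps fire every 4 batches (32 samples)
```

With batches of 8 samples and `n_acc=32`, the step fires on every 4th batch.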


Sweet! I knew it was wise to ask first :slight_smile:



Wow … so much simpler in v2! Below, for posterity's sake, so folks know what this looked like in the old days, is what I had to do:

class GradientAccumulation(AccumulateScheduler):
    _order = -40  # needs to run before the recorder

    def __init__(self, learn:Learner, n_step:int=1, drop_last:bool=False):
        super().__init__(learn, n_step=n_step, drop_last=drop_last)
        self.acc_samples = 0
        self.acc_batches = 0

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "accumulate samples and batches"
        self.acc_samples += last_input[0].shape[0]
        self.acc_batches += 1

    def on_backward_end(self, **kwargs):
        "accumulated step and reset samples; returning True results in no stepping"
        if (self.acc_batches % self.n_step) == 0:
            for p in self.learn.model.parameters():
                # wtg - not all params have a gradient here, so check for p.grad != None
                if p.requires_grad and p.grad is not None:
                    p.grad.div_(self.acc_samples)
            self.acc_samples = 0
        else:
            return {'skip_step': True, 'skip_zero': True}

    def on_epoch_end(self, **kwargs):
        "step the rest of the accumulated grads if not perfectly divisible"
        for p in self.learn.model.parameters():
            # wtg - not all params have a gradient here, so check for p.grad != None
            if p.requires_grad and p.grad is not None:
                p.grad.div_(self.acc_samples)
        if not self.drop_last: self.learn.opt.step()
        self.learn.opt.zero_grad()
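The accumulate-then-divide idea in the callback above can be checked with plain arithmetic: summing per-sample gradients across micro-batches and dividing by the total sample count gives exactly the gradient of the full batch under a mean-reduced loss. This toy sketch (all function names are made up for illustration) uses a model where each sample's gradient is `w - x_i`:

```python
# Toy illustration (not fastai code): accumulating sum-reduced gradients
# across micro-batches and dividing by the total sample count reproduces
# the full-batch mean gradient. Model: loss_i = 0.5*(w - x_i)**2, so
# dloss_i/dw = (w - x_i).

def grad_sum(w, batch):
    """Sum-reduced gradient over one micro-batch."""
    return sum(w - x for x in batch)

def accumulated_step(w, micro_batches, lr):
    """Accumulate sum-gradients, divide by total samples, step once."""
    acc_grad, acc_samples = 0.0, 0
    for mb in micro_batches:           # no optimizer step in between
        acc_grad += grad_sum(w, mb)
        acc_samples += len(mb)
    return w - lr * (acc_grad / acc_samples)

def full_batch_step(w, batch, lr):
    """One step on the whole batch with a mean-reduced loss."""
    return w - lr * (grad_sum(w, batch) / len(batch))

data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w_acc  = accumulated_step(0.0, [data[:2], data[2:4], data[4:]], lr=0.1)
w_full = full_batch_step(0.0, data, lr=0.1)
print(abs(w_acc - w_full) < 1e-12)  # True: identical updates
```

This is also why the v1 `AccumulateScheduler` docs suggest `reduction='sum'` for the loss: the division by `acc_samples` then restores the mean.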

Let me know how it works for you.
I just pushed it recently but I’m having a few issues on my project. There may be some incompatibility with fit_one_cycle and its scheduler but it could also be an issue with my own project.


  • the callback in v2 is currently defined by number of samples needed (not number of batches)
  • you will need to adjust your learning rate accordingly, i.e. if gradients are accumulated for 10 steps, you may want to divide your base learning rate (the one you'd use with no accumulation) by 10

@sgugger I did the test you suggested previously.

  • Baseline: no gradient accumulation
  • Gradient Accumulation: batch size divided by 10, update every 10 batches, learning rate divided by 10
  • Training loop: regular fit (no fit_one_cycle for now to avoid potential issues with scheduler)
  • 5 runs for baseline + 5 runs for gradient accumulation (we show mean and min/max for each group)

Results are consistently worse in the gradient accumulation variant (blue are the runs with gradient accumulation, orange is the baseline).

As you can see, learning rate is fixed and has been divided by 10 while other optimizer parameters are the same.

When looking at gradients, they are about 10 times higher with GradientAccumulation (as expected) and weight parameters remain pretty much the same.

You can see the full results comparison here.

Do you have any idea of another parameter I should adjust when doing gradient accumulation?
I was thinking it could be due to the weight decay but it is not called until all gradients have been accumulated…
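That reasoning about weight decay can be made concrete with a step count. Assuming decoupled weight decay is applied once per optimizer step (as the post above suggests), the accumulated run sees 10x more batches but the same number of optimizer steps, so decay fires equally often in both setups (`count_decay_steps` is a made-up helper for illustration):

```python
# Toy count (assumes decoupled weight decay applied once per optimizer
# step). With n_step=10 the accumulation run processes 10x more batches
# but performs the same number of optimizer steps, so weight decay is
# applied equally often in both setups.

def count_decay_steps(n_batches, n_step):
    "Number of optimizer steps (hence decay applications) over a run."
    return sum(1 for b in range(1, n_batches + 1) if b % n_step == 0)

print(count_decay_steps(100, 1))    # baseline: 100 steps
print(count_decay_steps(1000, 10))  # accumulation: also 100 steps
```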

Note: I’m going to propose a PR to WandbCallback, which I modified so it now automatically logs some config parameters to help me make these graphs.

Worked fine. Thanks much!

Just saw your post below … going to spend more time looking at your results. My use case was to fine-tune an abstractive summarization model where gradient accumulation was used by the paper’s authors (and not using one cycle). Will probably give the good ol’ fastai way a go to see how well it works.
