Building dropout scheduler callback - but 'begin_epoch' is always called 2x?

I’m finishing up a dropout scheduler callback that implements the curriculum dropout paper.

However, using the Callback infrastructure from the notebooks (from exp.nb_09 import *), the “begin_epoch” event is always called twice.

None of the other events seem to have this double call. In any case, I need to know the total length of the run (epochs * batches per epoch) up front to smoothly schedule the curriculum dropout, and this double firing of begin_epoch looks like a bug.
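For context, here is a minimal sketch of the kind of scheduler I mean. The class name DropoutScheduler, the gamma parameter, and the exponential schedule p(t) = p_max * (1 - exp(-gamma * t)) are my own illustrative choices in the spirit of the curriculum dropout paper (dropout annealed up from 0 toward a target rate), not the notebook’s API:

```python
import math

class DropoutScheduler:
    """Hypothetical sketch of a curriculum-dropout callback.

    Anneals the dropout probability from 0 toward p_max over the run
    using an exponential schedule; after_batch ticks the counter so the
    schedule can be evaluated per iteration.
    """
    def __init__(self, p_max=0.5, gamma=1e-3):
        self.p_max, self.gamma = p_max, gamma
        self.iter = 0

    def current_p(self):
        # Starts at 0 and approaches p_max as training progresses.
        return self.p_max * (1 - math.exp(-self.gamma * self.iter))

    def after_batch(self):
        self.iter += 1
```

The rate gamma would be chosen from the total iteration count, which is why knowing epochs * batches per epoch up front matters.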

I verified it’s happening not just in my callback but also in the AverageStats callback implemented in the notebook:

Example:

begin_fit
self.n_epochs 0.0
538 batch size
breakout of sets: warmup 53 middle 379 final 106
begin epoch - avg stats
begin epoch - dp sched
1 current epoch - dp sched
begin epoch - avg stats
begin epoch - dp sched
2 current epoch - dp sched

You can see that “begin_epoch” is in fact called twice per epoch for each callback.

Any clarification on what’s happening here would be appreciated!

The bug is in notebook 9b - the set of callback events has ‘begin_epoch’ listed twice:

ALL_CBS = {'begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step',
           'after_cancel_batch', 'after_batch', 'after_cancel_epoch', 'begin_fit',
           'begin_epoch', 'begin_epoch', 'begin_validate', 'after_epoch',
           'after_cancel_train', 'after_fit'}
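As an aside, a duplicate entry in a set literal can’t by itself cause a double call, because Python sets deduplicate on construction. A quick check:

```python
# A set literal silently drops duplicates, so the repeated
# 'begin_epoch' string in ALL_CBS has no effect at runtime.
cbs = {'begin_epoch', 'begin_epoch', 'begin_fit'}
print(len(cbs))  # the duplicate collapses, leaving 2 elements
```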

Actually that’s not it - the double call comes from the last line here:

def do_begin_epoch(self, epoch):
    self.epoch,self.dl = epoch,self.data.train_dl
    return self('begin_epoch')

def fit(self, epochs, cbs=None, reset_opt=False):
    # NEW: pass callbacks to fit() and have them removed when done
    self.add_cbs(cbs)
    # NEW: create optimizer on fit(), optionally replacing existing
    if reset_opt or not self.opt: self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        
    try:
        self.do_begin_fit(epochs)
        for epoch in range(epochs):
            self.do_begin_epoch(epoch)
            if not self('begin_epoch'): self.all_batches()   # <---- second 'begin_epoch' dispatch
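To make the double dispatch concrete, here is a minimal standalone reproduction along with the one-line fix. MiniLearner and its method names are stubs of my own, not the notebook’s real Learner; I’m assuming the notebook’s convention that self(event) returns False when no callback cancels the stage:

```python
class MiniLearner:
    """Minimal stand-in for the notebook Learner, just enough to show
    the double dispatch of 'begin_epoch'."""
    def __init__(self):
        self.calls = []

    def __call__(self, event):
        # Record every event dispatch; False means "not cancelled".
        self.calls.append(event)
        return False

    def do_begin_epoch(self, epoch):
        self.epoch = epoch
        return self('begin_epoch')       # first dispatch

    def fit_buggy(self, epochs):
        for epoch in range(epochs):
            self.do_begin_epoch(epoch)                 # fires 'begin_epoch'
            if not self('begin_epoch'): pass           # fires it AGAIN

    def fit_fixed(self, epochs):
        for epoch in range(epochs):
            # Use do_begin_epoch's return value instead of dispatching twice.
            if not self.do_begin_epoch(epoch): pass    # single dispatch
```

With fit_buggy, one epoch records two 'begin_epoch' events; with fit_fixed, each epoch records exactly one.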

Yes, it’s a bug in 9b. Note that those notebooks are just there for teaching and don’t provide a maintained library (that will be v2, coming in a few weeks), so those kinds of little bugs should be expected.


Thanks @sgugger for the clarification!

I’ll roll back to v1.0.x for this then and wait for the v2 release.