Developer chat

Breaking change

The Callback API has been changed slightly to make it more flexible and (hopefully) less opaque. Callbacks can now return a dictionary that updates the state of the CallbackHandler, so:

  • instead of returning last_loss in on_backward_begin, you can return {"last_loss": my_value} from any event of any callback
  • you can change internal state values (for instance, the epoch number displayed when resuming training)
  • in a custom metric written as a callback, instead of passing last_metric=my_value in on_epoch_end, you now write:
    def on_epoch_end(self, last_metrics, **kwargs):
        # my_value is the metric computed for this epoch
        return {'last_metrics': last_metrics + [my_value]}

The CallbackHandler will raise an error if a returned key is not already part of its state (for instance, because of a typo), so you can't use this to add new attributes.
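To make the mechanism concrete, here is a minimal sketch of how a handler could merge the returned dictionaries into its state and reject unknown keys. `MiniCallbackHandler`, `ConstantMetric` and `Typo` are hypothetical names made up for this illustration; this is not fastai's actual CallbackHandler code.

```python
from typing import Any, Dict, List, Optional


class Callback:
    "Base class: every event does nothing and returns None (no state update)."
    def on_epoch_begin(self, **kwargs: Any) -> Optional[Dict[str, Any]]:
        return None

    def on_epoch_end(self, **kwargs: Any) -> Optional[Dict[str, Any]]:
        return None


class MiniCallbackHandler:
    "Hypothetical handler: merges the dicts returned by callbacks into its state."
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks
        self.state: Dict[str, Any] = {'epoch': 0, 'last_loss': None, 'last_metrics': []}

    def __call__(self, event: str) -> None:
        for cb in self.callbacks:
            # Each callback receives the current state as keyword arguments.
            updates = getattr(cb, event)(**self.state)
            for key, value in (updates or {}).items():
                # Only existing keys may be updated, so a typo fails loudly.
                if key not in self.state:
                    raise KeyError(f"{key} is not a valid key of the state")
                self.state[key] = value


class ConstantMetric(Callback):
    "Toy metric written as a callback: appends a fixed value to last_metrics."
    def __init__(self, value: float):
        self.value = value

    def on_epoch_end(self, last_metrics, **kwargs):
        return {'last_metrics': last_metrics + [self.value]}


class Typo(Callback):
    "Returns a misspelled key ('last_metric' instead of 'last_metrics')."
    def on_epoch_end(self, **kwargs):
        return {'last_metric': [0.0]}


handler = MiniCallbackHandler([ConstantMetric(0.9), ConstantMetric(0.5)])
handler('on_epoch_end')
print(handler.state['last_metrics'])  # [0.9, 0.5]

try:
    MiniCallbackHandler([Typo()])('on_epoch_end')
except KeyError as err:
    print('rejected:', err)
```

Because callbacks run in order and each one sees the state left by the previous one, the second metric receives `[0.9]` and appends to it, which is why metrics are added with `last_metrics + [value]` rather than by overwriting the list.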

The old approach of returning True from some events to stop training or skip a step has been replaced by the following flags in the state:

  • skip_step: to skip the next optimizer step
  • skip_zero: to skip the next grad zeroing
  • stop_epoch: to stop the current epoch
  • stop_training: to stop the training at the end of the current epoch
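As an illustration of where each flag could be checked, here is a toy loop. `fit`, `demo_cb` and the exact points at which the flags are reset are assumptions made for this sketch, not fastai's `fit`; there is no model or optimizer, and the loop only logs `step`/`zero` where it would call the optimizer.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical toy: a callback is a function (event, state) -> optional dict of updates.
CallbackFn = Callable[[str, Dict[str, Any]], Optional[Dict[str, Any]]]


def fit(n_epochs: int, n_batches: int, callbacks: List[CallbackFn]) -> List[str]:
    "Toy training loop that only records which optimizer actions it performed."
    state: Dict[str, Any] = {'epoch': 0, 'iteration': 0,
                             'skip_step': False, 'skip_zero': False,
                             'stop_epoch': False, 'stop_training': False}
    log: List[str] = []

    def fire(event: str) -> None:
        # Merge every callback's returned dict into the state (known keys only).
        for cb in callbacks:
            for key, value in (cb(event, state) or {}).items():
                if key not in state:
                    raise KeyError(key)
                state[key] = value

    for epoch in range(n_epochs):
        state['epoch'], state['stop_epoch'] = epoch, False
        for _ in range(n_batches):
            # In this sketch the per-batch flags are reset before every batch.
            state['skip_step'] = state['skip_zero'] = False
            fire('on_backward_end')        # a callback may set skip_step here
            if not state['skip_step']:
                log.append(f"step {state['iteration']}")
            fire('on_step_end')            # a callback may set skip_zero here
            if not state['skip_zero']:
                log.append(f"zero {state['iteration']}")
            fire('on_batch_end')           # a callback may set stop_epoch here
            state['iteration'] += 1
            if state['stop_epoch']:
                break
        fire('on_epoch_end')
        if state['stop_training']:         # only checked once the epoch is over
            break
    return log


def demo_cb(event: str, state: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    "Skips the optimizer step at iteration 1, then stops everything at iteration 2."
    if event == 'on_backward_end' and state['iteration'] == 1:
        return {'skip_step': True}
    if event == 'on_batch_end' and state['iteration'] == 2:
        return {'stop_epoch': True, 'stop_training': True}
    return None


print(fit(n_epochs=5, n_batches=10, callbacks=[demo_cb]))
# ['step 0', 'zero 0', 'zero 1', 'step 2', 'zero 2']
```

Since `stop_training` only takes effect once the current epoch has ended, a callback that wants to stop right away also sets `stop_epoch`, which is what the new LR Finder code does.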

For instance, in the old LR Finder callback, the code that stops training once the loss gets too high was:

    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:
        "Determine if loss has runaway and we should stop."
        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss
        self.opt.lr = self.sched.step()
        if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):
            #We use the smoothed loss to decide on the stopping since it's less shaky.
            self.stop=True
            return True

    def on_epoch_end(self, **kwargs:Any)->None:
        "Tell Learner if we need to stop."
        return self.stop

and the new one is:

    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:
        "Determine if loss has runaway and we should stop."
        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss
        self.opt.lr = self.sched.step()
        if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):
            #We use the smoothed loss to decide on the stopping since it's less shaky.
            return {'stop_epoch': True, 'stop_training': True}
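The stopping rule in the new `on_batch_end` can also be pulled out and checked without a Learner. The sketch below is a hypothetical helper (`check_divergence` is not part of fastai): it takes plain floats, uses `math.isnan` in place of `torch.isnan` so it runs without PyTorch, and `sched_done` stands in for `self.sched.is_done`.

```python
import math
from typing import Dict, Optional, Tuple


def check_divergence(iteration: int, smooth_loss: float, best_loss: float,
                     sched_done: bool, stop_div: bool = True
                     ) -> Tuple[float, Optional[Dict[str, bool]]]:
    "Return the updated best loss and, if training should stop, the flags to return."
    # Track the best smoothed loss seen so far (reset on the first iteration).
    if iteration == 0 or smooth_loss < best_loss:
        best_loss = smooth_loss
    # Diverged: loss more than 4x the best so far, or NaN.
    diverged = stop_div and (smooth_loss > 4 * best_loss or math.isnan(smooth_loss))
    if sched_done or diverged:
        return best_loss, {'stop_epoch': True, 'stop_training': True}
    return best_loss, None


best = float('inf')
for i, loss in enumerate([2.0, 1.5, 1.2, 3.0, 6.0]):
    best, flags = check_divergence(i, loss, best, sched_done=False)
    if flags:
        print(f'stop at iteration {i}: {flags}')
        break
# stop at iteration 4: {'stop_epoch': True, 'stop_training': True}
```

Here the best loss reaches 1.2, so the run stops at the first loss above 4.8.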