Hi there.
Sometimes I want to disable a callback during some epochs. For example, it doesn’t make sense to use `SaveModelCallback` in the first epochs, since you will have a better model later on. This is just one example.
Here is a mixin to accomplish that. Note that you can pass any positional and keyword arguments to the original callback and, optionally, the mixin simulates the callback setup when the callback gets enabled (`simulate_cb_setup=True`).
```python
class DisableCallbackMixin:
    """Callback is disabled (it won't receive any callback event) until `disable_cb_until_pct_train` is reached."""
    def __init__(self, *args, disable_cb_until_pct_train=.5, simulate_cb_setup=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_cb_until_pct_train, self.simulate_cb_setup = disable_cb_until_pct_train, simulate_cb_setup
        self._cb_disabled = True

    def _perform_event(self, event):
        # Only forward the event to the wrapped callback if it is enabled
        if not self._cb_disabled and hasattr(super(), event): getattr(super(), event)()

    def _initialize_cb(self):
        # Replay the setup events that were skipped while the callback was disabled
        if hasattr(super(), 'begin_fit'): super().begin_fit()
        if hasattr(super(), 'begin_epoch'): super().begin_epoch()
        if hasattr(super(), 'begin_train'): super().begin_train()

    # Intercept every callback event so it can be dropped while disabled
    def begin_fit(self): self._perform_event('begin_fit')
    def begin_epoch(self): self._perform_event('begin_epoch')
    def begin_train(self): self._perform_event('begin_train')

    def begin_batch(self):
        # Enable the callback once enough of the training has elapsed
        if self._cb_disabled and self.pct_train >= self.disable_cb_until_pct_train:
            self._cb_disabled = False
            if self.simulate_cb_setup: self._initialize_cb()
        self._perform_event('begin_batch')

    def after_pred(self): self._perform_event('after_pred')
    def after_loss(self): self._perform_event('after_loss')
    def after_backward(self): self._perform_event('after_backward')
    def after_step(self): self._perform_event('after_step')
    def after_batch(self): self._perform_event('after_batch')
    def after_train(self): self._perform_event('after_train')
    def begin_validate(self): self._perform_event('begin_validate')
    def after_validate(self): self._perform_event('after_validate')
    def after_epoch(self): self._perform_event('after_epoch')

    def after_fit(self):
        self._perform_event('after_fit')
        self._cb_disabled = True  # reset so the next fit starts disabled again
```
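The interception relies on Python’s MRO: because the mixin is listed first, its event methods shadow the wrapped callback’s, and `super()` inside the mixin reaches the real implementations. Here is a stripped-down, fastai-free sketch of the same idea so it can be run on its own. `ToyCallback`, `Recorder`, `ToyDisableMixin` and `run_fake_loop` are hypothetical stand-ins (in fastai, `pct_train` comes from the Learner), and only three events are intercepted.

```python
# Fastai-free sketch of the interception idea. All names here are hypothetical
# stand-ins, not fastai APIs.

class ToyCallback:
    pct_train = 0.0  # a real fastai callback reads this from its Learner

class Recorder(ToyCallback):
    """Stands in for a real callback: it only records the events it receives."""
    def __init__(self):
        self.seen = []
    def begin_fit(self): self.seen.append('begin_fit')
    def begin_batch(self): self.seen.append(f'begin_batch@{self.pct_train:.2f}')
    def after_fit(self): self.seen.append('after_fit')

class ToyDisableMixin:
    def __init__(self, *args, disable_until=.5, simulate_setup=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_until, self.simulate_setup = disable_until, simulate_setup
        self._disabled = True
    def _perform(self, event):
        # Forward to the wrapped callback only once enabled
        if not self._disabled and hasattr(super(), event): getattr(super(), event)()
    def begin_fit(self): self._perform('begin_fit')
    def begin_batch(self):
        if self._disabled and self.pct_train >= self.disable_until:
            self._disabled = False
            # Replay the setup event that was dropped earlier
            if self.simulate_setup and hasattr(super(), 'begin_fit'): super().begin_fit()
        self._perform('begin_batch')
    def after_fit(self):
        self._perform('after_fit')
        self._disabled = True

class DisabledRecorder(ToyDisableMixin, Recorder): pass

def run_fake_loop(cb, n_batches=4):
    """Dispatch events the way a training loop would, updating pct_train."""
    cb.begin_fit()
    for i in range(n_batches):
        cb.pct_train = i / n_batches
        cb.begin_batch()
    cb.after_fit()

cb = DisabledRecorder(disable_until=.5)
run_fake_loop(cb)
print(cb.seen)  # → ['begin_fit', 'begin_batch@0.50', 'begin_batch@0.75', 'after_fit']
```

The first two batches (and the real `begin_fit`) are swallowed; when `pct_train` reaches 0.5, the dropped `begin_fit` is replayed just before the first forwarded `begin_batch`.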
Here is a helper function that wraps any existing callback class so that it starts out disabled:
```python
def create_cb_disabled(cb_class: type, *args, **kwargs):
    """Delegate `args` and `kwargs` to the `DisableCallbackMixin` and `cb_class` constructors."""
    class _inner_cls(DisableCallbackMixin, cb_class): pass
    return _inner_cls(*args, **kwargs)
```
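A possible variation builds the subclass with the built-in `type()` so that it gets a readable name (e.g. `DisabledSaveModelCallback` instead of `_inner_cls`), which makes reprs and debugging easier. This is a self-contained sketch with hypothetical `ToyDisableMixin` and `ToySaveModel` classes standing in for the real mixin and callback:

```python
class ToyDisableMixin:
    """Hypothetical stand-in for DisableCallbackMixin."""
    def __init__(self, *args, disable_until=.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_until = disable_until

class ToySaveModel:
    """Hypothetical stand-in for a callback such as SaveModelCallback."""
    def __init__(self, monitor='valid_loss', min_delta=0.):
        self.monitor, self.min_delta = monitor, min_delta

def create_cb_disabled(cb_class, *args, **kwargs):
    # Build the subclass at runtime; the mixin comes first in the bases tuple
    # so its methods take precedence in the MRO.
    new_cls = type(f'Disabled{cb_class.__name__}', (ToyDisableMixin, cb_class), {})
    return new_cls(*args, **kwargs)

cb = create_cb_disabled(ToySaveModel, 'train_loss', disable_until=.5, min_delta=.2)
print(type(cb).__name__, cb.monitor, cb.min_delta, cb.disable_until)
# → DisabledToySaveModel train_loss 0.2 0.5
```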
Example: disable `SaveModelCallback` during the first 50% of the training iterations, monitoring `'train_loss'` with a `min_delta` of .2:
```python
# Here, the order of the base classes matters: the mixin must come first
class DisabledSaveModelCallback(DisableCallbackMixin, SaveModelCallback): pass
cb = DisabledSaveModelCallback('train_loss', disable_cb_until_pct_train=.5, min_delta=.2)
```

or

```python
cb = create_cb_disabled(SaveModelCallback, 'train_loss', disable_cb_until_pct_train=.5, min_delta=.2)
```
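Why the mixin has to come first: Python resolves methods left to right along the MRO, so whichever base is listed first provides `begin_batch` and friends. A minimal sketch with hypothetical `Tracker` and `DropAll` classes:

```python
class Tracker:
    """Hypothetical stand-in for a callback that reacts to every batch."""
    def __init__(self): self.calls = 0
    def begin_batch(self): self.calls += 1

class DropAll:
    """Hypothetical minimal mixin that swallows begin_batch."""
    def begin_batch(self): pass

class Good(DropAll, Tracker): pass   # mixin first: DropAll.begin_batch wins
class Bad(Tracker, DropAll): pass    # mixin last: Tracker.begin_batch wins

good, bad = Good(), Bad()
for cb in (good, bad):
    for _ in range(3): cb.begin_batch()
print(good.calls, bad.calls)                # → 0 3
print([c.__name__ for c in Good.__mro__])   # → ['Good', 'DropAll', 'Tracker', 'object']
```

With the bases swapped, the mixin is silently bypassed and the callback runs from the very first batch.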
The only limitation I’m aware of is that if two callbacks modify the same attribute, you need to ensure the order is: normal callback --> disabled callback, so that the callback clean-up is performed correctly.
For example, the MixUp and CutMix callbacks both change the loss function. With the order DisabledCutMix --> MixUp, MixUp’s `begin_fit` actually runs before CutMix’s (CutMix’s is only simulated once it gets enabled), but fastai assumes the opposite. As a result, the `after_fit` events are called in the wrong order.
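The clean-up problem boils down to nested save/restore: each callback saves whatever value it finds and restores it at the end, so the original value only comes back if teardown runs in the reverse order of the setup that actually happened. A generic toy, with hypothetical `Holder` and `SwapLoss` classes standing in for the Learner and for MixUp/CutMix (this does not model fastai’s actual dispatch):

```python
# Generic illustration, not fastai code. `loss_func` is just a string here.

class Holder:
    def __init__(self): self.loss_func = 'original'

class SwapLoss:
    """Saves the current value on setup and restores it on teardown."""
    def __init__(self, holder, new_value):
        self.holder, self.new_value = holder, new_value
    def setup(self):
        self.old, self.holder.loss_func = self.holder.loss_func, self.new_value
    def teardown(self):
        self.holder.loss_func = self.old

def run(teardown_reversed):
    h = Holder()
    mixup, cutmix = SwapLoss(h, 'mixup'), SwapLoss(h, 'cutmix')
    setup_order = [mixup, cutmix]  # order in which setup actually ran
    for cb in setup_order: cb.setup()
    teardown_order = setup_order[::-1] if teardown_reversed else setup_order
    for cb in teardown_order: cb.teardown()
    return h.loss_func

print(run(teardown_reversed=True))   # → original
print(run(teardown_reversed=False))  # → mixup
```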
I don’t know how to guarantee that a callback is always at the end of the callbacks list. Any suggestions are welcome.
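One generic way to pin a callback to the end is to give each callback a numeric priority and have the dispatcher sort by it before every event; the disabled wrapper could then set a very large value. A minimal sketch where the `order` attribute, `ToyCB` and `dispatch` are all hypothetical, not fastai’s Learner:

```python
# Hypothetical toy dispatcher: callbacks are sorted by a numeric `order`
# class attribute (default 0) before each event is dispatched.

class ToyCB:
    order = 0
    def __init__(self, log): self.log = log
    def after_fit(self): self.log.append(type(self).__name__)

class MixUpLike(ToyCB): pass
class NormalizeLike(ToyCB): order = -10      # wants to run early
class DisabledCutMixLike(ToyCB): order = 1_000  # large value: always sorted last

def dispatch(cbs, event):
    # sorted() is stable, so callbacks with equal `order` keep insertion order
    for cb in sorted(cbs, key=lambda c: c.order):
        getattr(cb, event, lambda: None)()

log = []
cbs = [DisabledCutMixLike(log), MixUpLike(log), NormalizeLike(log)]
dispatch(cbs, 'after_fit')
print(log)  # → ['NormalizeLike', 'MixUpLike', 'DisabledCutMixLike']
```

Even though the disabled callback was added first, it is always called last.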
Do you think fastai2 should have this built in? I’d be happy to do a PR.