Code snippet to disable any callback during the first X training iterations

Hi there.

Sometimes I want to disable a callback during the first epochs. For example, it doesn’t make sense to use SaveModelCallback in the first epochs, since a better model will come later. That’s just one example.

Here is a mixin to accomplish that. Note that you can pass any positional and keyword arguments to the original callback and, optionally, the mixin simulates the callback’s setup events when it becomes enabled (simulate_cb_setup=True).

```python
class DisableCallbackMixin:
    """Callback will be disabled (it won't receive any callback event) until `disable_cb_until_pct_train` is reached."""

    def __init__(self, *args, disable_cb_until_pct_train=.5, simulate_cb_setup=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_cb_until_pct_train, self.simulate_cb_setup = disable_cb_until_pct_train, simulate_cb_setup
        self._cb_disabled = True

    def _perform_event(self, event):
        # Forward the event to the wrapped callback (next class in the MRO) only if enabled
        if not self._cb_disabled and hasattr(super(), event): getattr(super(), event)()

    def _initialize_cb(self):
        # Simulate the setup events that were skipped while disabled
        if hasattr(super(), 'begin_fit'): super().begin_fit()
        if hasattr(super(), 'begin_epoch'): super().begin_epoch()
        if hasattr(super(), 'begin_train'): super().begin_train()

    # Intercept every callback event so it can be dropped while disabled
    def begin_fit(self): self._perform_event('begin_fit')

    def begin_epoch(self): self._perform_event('begin_epoch')

    def begin_train(self): self._perform_event('begin_train')

    def begin_batch(self):
        # Enable the callback on the first batch at or past the threshold
        if self._cb_disabled and self.pct_train >= self.disable_cb_until_pct_train:
            self._cb_disabled = False
            if self.simulate_cb_setup: self._initialize_cb()

        self._perform_event('begin_batch')

    def after_pred(self): self._perform_event('after_pred')

    def after_loss(self): self._perform_event('after_loss')

    def after_backward(self): self._perform_event('after_backward')

    def after_step(self): self._perform_event('after_step')

    def after_batch(self): self._perform_event('after_batch')

    def after_train(self): self._perform_event('after_train')

    def begin_validate(self): self._perform_event('begin_validate')

    def after_validate(self): self._perform_event('after_validate')

    def after_epoch(self): self._perform_event('after_epoch')

    def after_fit(self):
        self._perform_event('after_fit')
        # Start disabled again on the next fit
        self._cb_disabled = True
```
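To check which events actually reach the wrapped callback without installing fastai, here is a self-contained sketch of the same mixin in which the per-event interceptors are generated in a loop rather than written out by hand. `Callback`, `Recorder` and `run_fit` are hypothetical stand-ins for fastai’s classes and training loop (the toy loop simply sets `pct_train` on each callback), so treat this as an illustration of the technique under those assumptions, not as fastai’s API.

```python
# Event names as used by fastai2 at the time of the post.
EVENTS = ['begin_fit', 'begin_epoch', 'begin_train', 'begin_batch', 'after_pred',
          'after_loss', 'after_backward', 'after_step', 'after_batch', 'after_train',
          'begin_validate', 'after_validate', 'after_epoch', 'after_fit']
SETUP_EVENTS = ('begin_fit', 'begin_epoch', 'begin_train')


class Callback:
    """Hypothetical stand-in for fastai's Callback; the toy loop sets `pct_train`."""
    pct_train = 0.0


class DisableCallbackMixin:
    """Drops every event until `pct_train` reaches `disable_cb_until_pct_train`."""

    def __init__(self, *args, disable_cb_until_pct_train=.5, simulate_cb_setup=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_cb_until_pct_train = disable_cb_until_pct_train
        self.simulate_cb_setup = simulate_cb_setup
        self._cb_disabled = True

    def _perform_event(self, event):
        # Forward to the wrapped callback (next class in the MRO) only while enabled.
        parent = super()
        if not self._cb_disabled and hasattr(parent, event):
            getattr(parent, event)()

    def _handle(self, event):
        if (event == 'begin_batch' and self._cb_disabled
                and self.pct_train >= self.disable_cb_until_pct_train):
            self._cb_disabled = False
            if self.simulate_cb_setup:
                # Replay the setup events the wrapped callback missed.
                for setup_event in SETUP_EVENTS:
                    self._perform_event(setup_event)
        self._perform_event(event)
        if event == 'after_fit':
            self._cb_disabled = True  # start disabled again on the next fit


def _make_interceptor(event):
    def interceptor(self):
        self._handle(event)
    interceptor.__name__ = event
    return interceptor


# One generated interceptor per event instead of fourteen hand-written methods.
for _event in EVENTS:
    setattr(DisableCallbackMixin, _event, _make_interceptor(_event))


class Recorder(Callback):
    """Hypothetical wrapped callback that logs the events it actually receives."""
    def __init__(self, name):
        self.name, self.seen = name, []

    def begin_fit(self): self.seen.append('begin_fit')
    def begin_batch(self): self.seen.append(('begin_batch', self.pct_train))
    def after_fit(self): self.seen.append('after_fit')


def run_fit(cbs, n_batches=4):
    """Hypothetical training loop: calls each event on every callback in list order."""
    def call(event):
        for cb in cbs:
            getattr(cb, event)()
    for event in SETUP_EVENTS:
        call(event)
    for i in range(n_batches):
        for cb in cbs:
            cb.pct_train = i / n_batches
        call('begin_batch')
        call('after_batch')
    for event in ('after_train', 'after_epoch', 'after_fit'):
        call(event)


class DisabledRecorder(DisableCallbackMixin, Recorder): pass

cb = DisabledRecorder('demo', disable_cb_until_pct_train=.5)
run_fit([cb])
print(cb.seen)
```

With the threshold at .5 and four batches, `Recorder` receives nothing until the third batch (`pct_train == .5`); it then gets the replayed `begin_fit`, the two remaining `begin_batch` events and `after_fit`. Events it doesn’t define (e.g. `begin_epoch`) are skipped by the `hasattr` check. Generating the methods keeps the event list in one place, at the cost of being less explicit than the hand-written version.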

Helper function to convert any existing callback so that it’s disabled at the beginning:

```python
def create_cb_disabled(cb_class:type, *args, **kwargs):
    """Delegate args and kwargs to `DisableCallbackMixin` and `cb_class` constructors"""
    class _inner_cls(DisableCallbackMixin, cb_class): pass
    return _inner_cls(*args, **kwargs)
```
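Why the mixin has to come first can be seen from the method resolution order. The toy below uses hypothetical `Base` and `Mixin` classes (not fastai’s) and builds the combined class with the three-argument `type()` so it gets a readable name; it also shows that a misspelled keyword argument is forwarded through the mixin to the wrapped class and rejected there with a `TypeError`.

```python
class Base:
    """Hypothetical wrapped callback (stands in for e.g. SaveModelCallback)."""
    def __init__(self, monitor):
        self.monitor = monitor

    def begin_fit(self):
        return 'Base.begin_fit ran'


class Mixin:
    """Hypothetical interceptor that swallows begin_fit while disabled."""
    def __init__(self, *args, disabled=True, **kwargs):
        super().__init__(*args, **kwargs)  # forwards the remaining args to Base
        self.disabled = disabled

    def begin_fit(self):
        return 'dropped' if self.disabled else super().begin_fit()


def create_cb_disabled(cb_class, *args, **kwargs):
    # Same idea as the helper above, but the generated class gets a readable name.
    cls = type(f'Disabled{cb_class.__name__}', (Mixin, cb_class), {})
    return cls(*args, **kwargs)


good = create_cb_disabled(Base, 'train_loss')
print([c.__name__ for c in type(good).__mro__])  # Mixin comes before Base
print(good.begin_fit())  # the mixin's begin_fit is found first and drops the event


class WrongOrder(Base, Mixin): pass  # Base first: its methods win

bad = WrongOrder('train_loss')
print(bad.begin_fit())           # Base.begin_fit runs; nothing is intercepted
print(hasattr(bad, 'disabled'))  # Mixin.__init__ never ran (Base.__init__ doesn't call super)

# A misspelled keyword (missing the `disable_` prefix) ends up in Base.__init__, which rejects it.
try:
    create_cb_disabled(Base, 'train_loss', cb_until_pct_train=.5)
except TypeError as e:
    print('TypeError:', e)
```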

Example: disable SaveModelCallback during the first 50% of training iterations, monitoring ‘train_loss’ with a min_delta of .2:

```python
# Mixin order matters: DisableCallbackMixin must come before the callback class so it intercepts the events
class DisabledSaveModelCallback(DisableCallbackMixin, SaveModelCallback): pass
cb = DisabledSaveModelCallback('train_loss', disable_cb_until_pct_train=.5, min_delta=.2)
```

or

```python
cb = create_cb_disabled(SaveModelCallback, 'train_loss', disable_cb_until_pct_train=.5, min_delta=.2)
```

The only limitation I’m aware of is that if two callbacks modify the same attribute, you need to ensure the order is Normal callback --> Disabled callback so that the callback cleanup is performed correctly.

For example, the MixUp and CutMix callbacks both change the loss function. With the order DisabledCutMix --> MixUp, MixUp’s begin_fit is called before CutMix’s begin_fit (which is deferred until the callback is enabled), whereas fastai assumes the opposite. As a result, the after_fit events are called in the wrong order.
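The underlying constraint is that callbacks which save and restore the same attribute only leave it intact when they tear down in the reverse order of their setup. Here is a toy illustration; `Learner`, `SwapLoss` and `final_loss` are hypothetical and only mimic a save-in-begin_fit / restore-in-after_fit pattern, not the actual MixUp or CutMix code.

```python
class Learner:
    """Hypothetical stand-in that only holds the attribute both callbacks swap."""
    def __init__(self):
        self.loss_func = 'original_loss'


class SwapLoss:
    """Hypothetical callback: replaces loss_func in begin_fit, restores it in after_fit."""
    def __init__(self, learn, new_loss):
        self.learn, self.new_loss = learn, new_loss

    def begin_fit(self):
        self.old_loss, self.learn.loss_func = self.learn.loss_func, self.new_loss

    def after_fit(self):
        self.learn.loss_func = self.old_loss


def final_loss(setup_order, teardown_order):
    """Run begin_fit/after_fit in the given orders and return the leftover loss_func."""
    learn = Learner()
    cbs = {'mixup': SwapLoss(learn, 'mixup_loss'), 'cutmix': SwapLoss(learn, 'cutmix_loss')}
    for name in setup_order:
        cbs[name].begin_fit()
    for name in teardown_order:
        cbs[name].after_fit()
    return learn.loss_func


print(final_loss(['mixup', 'cutmix'], ['cutmix', 'mixup']))  # reverse order: original restored
print(final_loss(['mixup', 'cutmix'], ['mixup', 'cutmix']))  # same order: mixup_loss is left behind
```

In the second call MixUp restores the original loss first, and CutMix then "restores" the MixUp loss it had saved, so the learner ends up with the wrong loss function after training.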

I don’t know how to make sure that a callback is always at the end of the callbacks list. Any suggestions are welcome.

Do you think fastai2 should have this built in? If so, I would open a PR.

You can simply set callback.run=False to deactivate it for a while (and set it back to True when you want it to run).
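A toy sketch of this suggestion, where a small helper callback flips the target’s `run` flag based on training progress. `Callback`, `call_event`, `Recorder`, `EnableAfter` and `run_fit` are hypothetical stand-ins, not fastai classes; the only behaviour carried over is that the loop skips any callback whose `run` is False.

```python
class Callback:
    """Hypothetical minimal callback with a fastai-like `run` flag."""
    run = True
    pct_train = 0.0


def call_event(cbs, event):
    # Skip any callback whose `run` flag is currently False.
    for cb in cbs:
        if cb.run:
            getattr(cb, event, lambda: None)()


class Recorder(Callback):
    """Hypothetical callback that records the batches it sees."""
    def __init__(self):
        self.batches = []

    def begin_batch(self):
        self.batches.append(self.pct_train)


class EnableAfter(Callback):
    """Hypothetical helper: keeps `target.run` False until pct_train reaches `pct`."""
    def __init__(self, target, pct):
        self.target, self.pct = target, pct

    def begin_fit(self):
        self.target.run = False

    def begin_batch(self):
        self.target.run = self.pct_train >= self.pct

    def after_fit(self):
        self.target.run = True  # leave the target enabled afterwards


def run_fit(cbs, n_batches=4):
    """Hypothetical training loop that sets pct_train and fires a few events."""
    call_event(cbs, 'begin_fit')
    for i in range(n_batches):
        for cb in cbs:
            cb.pct_train = i / n_batches
        call_event(cbs, 'begin_batch')
    call_event(cbs, 'after_fit')


rec = Recorder()
run_fit([EnableAfter(rec, pct=.5), rec])  # the toggler must come before its target
print(rec.batches)
```

Unlike the mixin, nothing replays the skipped setup events here: `rec` never receives `begin_fit`, so a callback that initializes state there would need that handled some other way.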