Help understanding and writing custom Callbacks

I’m writing a little toy example to better understand custom callbacks and have run into a few questions. I wrote the following callback, which saves the model weights after each epoch:

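from fastai.text import *
import datetime  # imported after the star import so the module name is not shadowed

class SaveModel(LearnerCallback):
    """Save the latest model weights at the end of every epoch."""
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_epoch_end(self, epoch:int, **kwargs):
        # Timestamp the file name so each epoch gets its own checkpoint
        dt = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.learn.save('my_model_' + dt)
        return False  # False: do not stop training

learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103, callback_fns=SaveModel)
learn.fit(2, 1e-2)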

I’m getting the expected output: the model weights are saved to disk after each epoch. But my questions are:

  1. Is it possible to pass arguments to SaveModel when the callback is created? As a trivial example, what if I wanted to pass a model-name prefix, to be used in on_epoch_end, when creating learn? Where would I do this?
  2. What is the difference between passing callback_fns and callbacks into the learner object?

If you need to pass parameters to your callback, do it in __init__:

def __init__(self, learn:Learner, arg):
    super().__init__(learn)
    self.arg = arg

Then you can access the argument as self.arg in every other method.
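To make this concrete, here is a minimal sketch of SaveModel taking a file-name prefix. It runs without fastai: ToyLearner and ToyLearnerCallback are hypothetical stand-ins for Learner and LearnerCallback, and the epoch-numbered names replace the timestamp only so the output is predictable.

```python
from typing import List


class ToyLearner:
    """Hypothetical stand-in for Learner: save() just records the file name."""
    def __init__(self) -> None:
        self.saved: List[str] = []

    def save(self, name: str) -> None:
        self.saved.append(name)


class ToyLearnerCallback:
    """Hypothetical stand-in for LearnerCallback: keeps a reference to the learner."""
    def __init__(self, learn: ToyLearner) -> None:
        self.learn = learn


class SaveModel(ToyLearnerCallback):
    def __init__(self, learn: ToyLearner, prefix: str = 'my_model') -> None:
        super().__init__(learn)
        self.prefix = prefix  # stored once here, read in every later call

    def on_epoch_end(self, epoch: int, **kwargs) -> bool:
        self.learn.save(f'{self.prefix}_epoch{epoch}')
        return False  # same return value as the original callback


learn = ToyLearner()
cb = SaveModel(learn, prefix='wt103_lm')
for epoch in range(2):  # stands in for the epoch loop inside fit(2, ...)
    cb.on_epoch_end(epoch)
print(learn.saved)  # ['wt103_lm_epoch0', 'wt103_lm_epoch1']
```

The argument only has to be supplied once, when the callback object is built; every hook then reads it from self.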

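A callback function (an entry in callback_fns) is something that takes a Learner and returns a Callback, usually just the Callback class itself. It is called to create the callback when you call fit, and the resulting object also becomes an attribute of the Learner, like learn.recorder, which is useful if you want to access it later. callbacks, by contrast, expects Callback objects you have already created.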
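A rough pure-Python model of the difference, assuming fit builds callbacks as described above. ToyLearner, ToyLearnerCallback, EpochLogger and snake() are hypothetical stand-ins, not fastai code. It also uses functools.partial, which is a common way to get extra arguments through callback_fns; in fastai v1 that would look like callback_fns=[partial(SaveModel, arg='some arg')].

```python
from functools import partial
from typing import Callable, List, Optional


def snake(name: str) -> str:
    # Simplified CamelCase -> snake_case: 'EpochLogger' -> 'epoch_logger'
    return ''.join('_' + c.lower() if c.isupper() and i else c.lower()
                   for i, c in enumerate(name))


class ToyLearner:
    """Hypothetical stand-in for Learner."""
    def __init__(self, callback_fns: Optional[List[Callable]] = None) -> None:
        self.callback_fns = callback_fns or []

    def fit(self, epochs: int, callbacks: Optional[list] = None) -> list:
        # callback_fns: each one is called with the learner, here, at fit time
        # callbacks: instances the caller already built, used as given
        cbs = [fn(self) for fn in self.callback_fns] + (callbacks or [])
        for epoch in range(epochs):
            for cb in cbs:
                cb.on_epoch_end(epoch)
        return cbs


class ToyLearnerCallback:
    """Hypothetical stand-in for LearnerCallback: registers itself on the learner."""
    def __init__(self, learn: ToyLearner) -> None:
        self.learn = learn
        setattr(learn, snake(type(self).__name__), self)

    def on_epoch_end(self, epoch: int) -> None:
        pass


class EpochLogger(ToyLearnerCallback):
    def __init__(self, learn: ToyLearner, tag: str = 'run') -> None:
        super().__init__(learn)
        self.tag = tag
        self.lines: List[str] = []

    def on_epoch_end(self, epoch: int) -> None:
        self.lines.append(f'{self.tag}: epoch {epoch} done')


# partial fixes tag, leaving a callable that only needs the learner
learn = ToyLearner(callback_fns=[partial(EpochLogger, tag='lm')])
print(hasattr(learn, 'epoch_logger'))  # False: nothing is built before fit
learn.fit(2)
print(learn.epoch_logger.lines)        # ['lm: epoch 0 done', 'lm: epoch 1 done']

# callbacks path: build the instance yourself, then hand it to fit
learn2 = ToyLearner()
my_cb = EpochLogger(learn2, tag='manual')
learn2.fit(1, callbacks=[my_cb])
print(my_cb.lines)                     # ['manual: epoch 0 done']
```

In this model, an attribute coming from callback_fns only exists after fit has run, while an instance you build yourself is registered as soon as you create it.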


Sorry, I’m still not quite following. Where would I pass in arg? I’ve tried this:

learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103, callback_fns=SaveModel, arg='arg')

which gives me TypeError: __init__() got an unexpected keyword argument 'model_name'

And this:

learn.fit(2, 1e-2, arg='arg')

and I get TypeError: fit() got an unexpected keyword argument 'model_name'

I’m not sure whether this is the best way to do it, but it does work if I do this instead:

learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103)
my_cb = SaveModel(learn, arg='some arg')
learn.fit(2, 1e-2, callbacks=my_cb)

Yes, that is the right way to do it.

The docs mention that:

Every callback that is passed to Learner with the callback_fns parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance ActivationStats will appear as learn.activation_stats (assuming your object is named learn).
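A sketch of that snake-casing, applied to the class names in this thread. The two regexes are my recollection of fastai v1's camel2snake helper, so treat them as an approximation rather than the library's exact code.

```python
import re

# Assumed to mirror fastai v1's camel2snake (reproduced from memory)
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')


def camel2snake(name: str) -> str:
    s1 = _camel_re1.sub(r'\1_\2', name)          # 'ActivationStats' -> 'Activation_Stats'
    return _camel_re2.sub(r'\1_\2', s1).lower()  # split any remaining aB pairs, lowercase


for cls_name in ['ActivationStats', 'ShowGraph', 'SimpleRecorder', 'NewRecorder']:
    print(f'{cls_name} -> {camel2snake(cls_name)}')
```

So the expected attribute names are learn.show_graph, learn.simple_recorder and learn.new_recorder.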

I created a simple custom callback like this:

@dataclass
class SimpleRecorder(LearnerCallback):
    learn:Learner
    def on_train_begin(self, **kwargs:Any)->None:
        self.losses = []
    def on_step_end(self, iteration: int, last_loss, **kwargs ):
        self.losses.append(last_loss)
    def on_epoch_end(self, last_loss, smooth_loss, **kwargs):
        print('Epoch ended', last_loss, smooth_loss)
    def plot(self, **kwargs):
        losses = self.losses
        iterations = range(len(losses))
        fig, ax = plt.subplots(1,1)
        ax.set_ylabel('Loss')
        ax.set_xlabel('Iteration')
        ax.plot(iterations, losses)

I add it via the callback_fns parameter:

learn = cnn_learner(data, 
                    models.resnet34, 
                    metrics=error_rate, 
                    callback_fns=[
                        ShowGraph,
                        SimpleRecorder
                    ])

But I can’t access it via learn.simple_recorder, even though learn.show_graph works. What am I missing?

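LearnerCallback is not a dataclass anymore, and by using this syntax you bypass LearnerCallback.__init__, which is where the attribute is set: @dataclass generates its own __init__ for the learn field, and that generated __init__ never calls super().__init__. Just do: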

class SimpleRecorder(LearnerCallback):
    def on_train_begin(self, **kwargs:Any)->None:
        ...

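A minimal pure-Python reproduction of why the @dataclass version never shows up as an attribute. ToyLearner and ToyLearnerCallback are hypothetical stand-ins whose __init__ registers the callback on the learner, and snake() is a simplified CamelCase converter, not fastai's.

```python
from dataclasses import dataclass


def snake(name: str) -> str:
    # Simplified CamelCase -> snake_case (enough for these class names)
    return ''.join('_' + c.lower() if c.isupper() and i else c.lower()
                   for i, c in enumerate(name))


class ToyLearner:
    """Hypothetical stand-in for Learner."""


class ToyLearnerCallback:
    """Hypothetical stand-in for LearnerCallback: its __init__ does the registration."""
    def __init__(self, learn: ToyLearner) -> None:
        self.learn = learn
        setattr(learn, snake(type(self).__name__), self)


@dataclass
class DataclassRecorder(ToyLearnerCallback):
    learn: ToyLearner  # generated __init__(self, learn) never calls super().__init__


class PlainRecorder(ToyLearnerCallback):
    pass  # inherits ToyLearnerCallback.__init__, so registration runs


learn = ToyLearner()
DataclassRecorder(learn)
PlainRecorder(learn)
print(hasattr(learn, 'dataclass_recorder'))  # False
print(hasattr(learn, 'plain_recorder'))      # True
```

Both objects end up with self.learn set, which is why the dataclass version looks like it works until you look for the attribute.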
OK. I removed the @dataclass decorator and the learn field but still get an AttributeError.

I also tried subclassing the Recorder:

class NewRecorder(Recorder):
    def on_epoch_end(self, **kwargs):
        print('Epoch ended')

But I also can’t access the learn.new_recorder attribute.