On_batch_begin callback

As an exercise in using callbacks, I’d like to write an on_batch_begin callback that simply passes the batch through unchanged.

@dataclass
class MyCallback(Callback):
    learn:Learner

    def on_batch_begin(self, last_input, last_target, **kwargs):
        return last_input, last_target

learn = Learner(data, model, callbacks=[MyCallback])
learn.fit(1, lr=1e-06)
# TypeError: on_train_begin() missing 1 required positional argument: 'self'

I’m probably missing something very basic. What is it?

You have to instantiate your callback: MyCallback(learn), not just MyCallback.
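The error message makes more sense once you see that the training loop is calling the method on the class itself, so there is no instance to bind `self` to. A minimal sketch of the failure mode in plain Python (no fastai needed, stub classes for illustration):

```python
class Callback:
    def on_train_begin(self, **kwargs): pass

class MyCallback(Callback):
    def on_batch_begin(self, last_input, last_target, **kwargs):
        return last_input, last_target

# Passing the class itself, the way callbacks=[MyCallback] does:
cb = MyCallback  # a class, not an instance
try:
    cb.on_train_begin()  # TypeError: missing required argument 'self'
except TypeError as e:
    print(e)

# Passing an instance, the way callbacks=[MyCallback(learn)] does:
cb = MyCallback()
cb.on_train_begin()  # works fine
```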


I tried to follow the GradientClipping example. It seems to work. Critique?

@dataclass
class MyCallback(LearnerCallback):

    def on_batch_begin(self, last_input, last_target, **kwargs):
        AssertionError("Let's see if it gets called") # Don't see this when calling fit(), why?
        return last_input, last_target


learn = Learner(data, model, callback_fns=[partial(MyCallback)])
learn.callback_fns # to check if there
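On the question in the code comment: the AssertionError never shows up because it is only *constructed*, never raised; Python evaluates the expression and discards the exception object. A minimal sketch (plain Python, no fastai needed):

```python
# Constructing an exception has no effect on its own:
AssertionError("never seen")  # the object is created, then discarded

# To actually interrupt fit(), raise it (or use a bare assert):
def on_batch_begin(last_input, last_target, **kwargs):
    raise AssertionError("Let's see if it gets called")

try:
    on_batch_begin(None, None)
except AssertionError as e:
    print(e)  # -> Let's see if it gets called
```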

You just learned the difference between callbacks and callback_fns :wink:

I hope so : )
Why do both callbacks and callback_fns exist? Learner.fit() seems to simply supply both to the fit() function.

callback_fns is there to easily pass callbacks that need the Learner object: instead of having to do:

learn = ...
cb = MyCallback(learn)
learn.callbacks.append(cb)

you can just pass MyCallback in callback_fns.

callbacks is still there for the callbacks that don’t use the Learner object.
