I tried to follow the `GradientClipping` example, and it seems to work. Any critique?
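```python
from dataclasses import dataclass
from fastai.basic_train import Learner, LearnerCallback

@dataclass
class MyCallback(LearnerCallback):
    def on_batch_begin(self, last_input, last_target, **kwargs):
        AssertionError("Let's see if it gets called")  # I don't see this when calling fit(). Why?
        return last_input, last_target

learn = Learner(data, model, callback_fns=[MyCallback])  # the class itself works; partial() with no args adds nothing
learn.callback_fns  # check that MyCallback was registered
```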
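One likely reason nothing shows up: `AssertionError("...")` on its own only constructs an exception object and throws it away. Without `raise`, that line is silent whether or not `on_batch_begin` runs. The sketch below illustrates this without needing fastai. `RecordingCallback`, `LoudCallback` and `mini_fit` are hypothetical stand-ins written for this example; `mini_fit` is a toy loop, not fastai's actual training loop, and only mimics the idea of calling `on_batch_begin` once per batch.

```python
from dataclasses import dataclass, field
from typing import Any, List


def construct_only() -> str:
    # Creating an exception without `raise` is a no-op expression; execution continues.
    AssertionError("never raised")
    return "continued"


@dataclass
class RecordingCallback:
    # Hypothetical stand-in for a LearnerCallback subclass: records which batches it saw.
    calls: List[int] = field(default_factory=list)

    def on_batch_begin(self, last_input: Any, last_target: Any, **kwargs: Any):
        self.calls.append(kwargs.get("iteration", -1))
        return last_input, last_target


class LoudCallback:
    def on_batch_begin(self, last_input: Any, last_target: Any, **kwargs: Any):
        # `raise` is what actually throws the exception.
        raise AssertionError("on_batch_begin was called")


def mini_fit(batches, callbacks) -> int:
    # Toy loop (NOT fastai's): each callback may modify the batch before it is used.
    for i, (xb, yb) in enumerate(batches):
        for cb in callbacks:
            xb, yb = cb.on_batch_begin(xb, yb, iteration=i)
    return len(batches)


if __name__ == "__main__":
    batches = [(1, 0), (2, 1), (3, 0)]
    print(construct_only())  # continued
    cb = RecordingCallback()
    print(mini_fit(batches, [cb]), cb.calls)  # 3 [0, 1, 2]
    try:
        mini_fit(batches, [LoudCallback()])
    except AssertionError as e:
        print("raised:", e)  # raised: on_batch_begin was called
```

Applied to the real callback, writing `raise AssertionError(...)` (or a plain `print(...)`) in `on_batch_begin` should make it obvious whether fastai is calling it during `fit()`.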