Move registration of `callback_fns` with `Learner` from `.fit()` to `__post_init__()`

Note: this is an issue that I opened on GitHub. I thought this would be a straightforward change, but it isn't. I am not sure how to proceed, because I have made only one minor contribution to fastai before. I would like to work on this, but I need some guidance and advice. So far I have identified two problems:

  1. Callback tests assume that callbacks will be registered by the `.fit()` or `.fit_one_cycle()` functions, as in the example below. Note that the callback is appended to `callback_fns` after the learner is created; it is registered with the learner only when `.fit_one_cycle()` is called:
```python
def test_logger():
    learn = fake_learner()
    learn.metrics = [accuracy, error_rate]
    learn.callback_fns.append(callbacks.CSVLogger)
    this_tests(callbacks.CSVLogger)
    with CaptureStdout() as cs: learn.fit_one_cycle(3)
    csv_df = learn.csv_logger.read_logged_file()
    stdout_df = convert_into_dataframe(cs.out)
    pd.testing.assert_frame_equal(csv_df, stdout_df, check_exact=False, check_less_precise=2)
    recorder_df = create_metrics_dataframe(learn)
    # XXX: there is a bug in pandas:
    # https://github.com/pandas-dev/pandas/issues/25068#issuecomment-460014120
    # which quite often fails on CI.
    # once it's resolved can change the setting back to check_less_precise=True (or better =3), until then using =2 as it works, but this check is less good.
    csv_df_notime = csv_df.drop(['time'], axis=1)
    pd.testing.assert_frame_equal(csv_df_notime, recorder_df, check_exact=False, check_less_precise=2)
```
  2. The `Recorder` callback needs to access the learner's `opt` property, which is not yet available when the learner is first initialized, so moving this callback's registration to learner initialization breaks the code. When the callback is registered in the `.fit()` function, the optimizer already exists:
```python
callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
```
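One possible way around both obstacles, sketched on a toy learner rather than the real fastai `Learner` (`ToyLearner`, `ToyRecorder` and `_register_pending` are hypothetical names): instantiate the callback functions already in `callback_fns` in `__post_init__`, let `fit()` instantiate any callback function appended later (the pattern the tests rely on), and have callbacks that need the optimizer read `learn.opt` only when training begins instead of in their constructor.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


class ToyRecorder:
    """Hypothetical stand-in for Recorder: it needs learn.opt, but only once training starts."""

    def __init__(self, learn):
        self.learn = learn
        self.opt = None  # do not read learn.opt here: it may not exist yet

    def on_train_begin(self):
        self.opt = self.learn.opt  # safe: fit() creates the optimizer first


@dataclass
class ToyLearner:
    callback_fns: List[Callable] = field(default_factory=list)
    opt: Any = None

    def __post_init__(self):
        # Register every callback function known when the learner is created.
        self._registered: Dict[Callable, Any] = {}
        self._register_pending()

    def _register_pending(self):
        # Instantiate only callback functions not seen before, so one appended
        # to callback_fns after creation (as the tests do) is still picked up.
        for fn in self.callback_fns:
            if fn not in self._registered:
                self._registered[fn] = fn(self)

    @property
    def callbacks(self):
        return list(self._registered.values())

    def fit(self):
        self.opt = object()  # placeholder for creating the real optimizer
        self._register_pending()
        for cb in self.callbacks:
            if hasattr(cb, "on_train_begin"):
                cb.on_train_begin()
```

Keying the registry by callback function means the same function is never instantiated twice. The real `fit()` also accepts per-call `callbacks`, which this toy leaves out.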

Issue as it appears on GitHub

TL;DR: when running `.get_preds()` right after creating a learner, `callback_fns` do not work, because they are registered with the learner only in the `.fit()` function. I propose moving the registration of `callback_fns` to the learner's initialization, so that callbacks work even without training first.
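The proposal can be sketched in plain Python, without fastai (`ProposedLearner` and `PrintBatch` are hypothetical stand-ins for the real `Learner` and the `TestHook` below): callback functions are instantiated in `__post_init__`, so a prediction loop can fire them even though `fit()` was never called.

```python
from dataclasses import dataclass, field
from typing import Callable, List


class PrintBatch:
    """Hypothetical callback mirroring TestHook: reports every finished batch."""

    def __init__(self, learn):
        self.learn = learn
        self.batches_seen = 0

    def on_batch_end(self, **kwargs):
        self.batches_seen += 1
        print("Batch Done!")


@dataclass
class ProposedLearner:
    data: List[int]
    callback_fns: List[Callable] = field(default_factory=list)

    def __post_init__(self):
        # The proposal: instantiate callback_fns when the learner is created,
        # instead of waiting for fit().
        self.callbacks = [fn(self) for fn in self.callback_fns]

    def get_preds(self):
        preds = []
        for batch in self.data:
            preds.append(batch * 2)  # placeholder for model(batch)
            for cb in self.callbacks:
                cb.on_batch_end(train=False)
        return preds
```

For example, `ProposedLearner(data=[1, 2, 3], callback_fns=[PrintBatch]).get_preds()` prints "Batch Done!" three times with no training step.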

I am trying to use a pretrained model and analyze activations for some data; in other words, I am not using training, only predicting from an existing model.

I create the data, then the learner, and a test hook callback that prints when a batch is done:

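```python
from fastai.vision import *                      # cnn_learner, models
from fastai.callbacks.hooks import HookCallback

class TestHook(HookCallback):
    def hook(self, m, i, o):
        pass
    def on_batch_end(self, train, **kwargs):
        print("Batch Done!")

# `data` is the DataBunch created earlier
learner = cnn_learner(data, models.resnet18, callback_fns=TestHook)
```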

When I run `learner.fit_one_cycle(1)`, the callback is invoked after each batch and "Batch Done!" is printed correctly.

However, when I run `learner.get_preds(data.train_ds)`, nothing is printed. After digging into the source code, I found the reason: callbacks are registered with the learner in the `fit()` function. This means that right after the learner is created, its callbacks are not yet registered, so they do not run when `.get_preds()` is called.
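To make the mechanism concrete, here is a stripped-down, hypothetical model of the current behavior in plain Python (not fastai code; `CurrentLearner` and `CountBatches` are invented for illustration): only `fit()` turns `callback_fns` into callback instances, so a prediction loop that reads just `self.callbacks` never runs them.

```python
class CountBatches:
    """Hypothetical callback that counts finished batches."""

    def __init__(self, learn):
        self.learn = learn
        self.count = 0

    def on_batch_end(self, **kwargs):
        self.count += 1


class CurrentLearner:
    """Toy model of the behavior described above: callback_fns are built only in fit()."""

    def __init__(self, data, callback_fns=None):
        self.data = data
        self.callback_fns = list(callback_fns or [])
        self.callbacks = []  # nothing from callback_fns ends up here

    def _loop(self, callbacks):
        for _batch in self.data:
            for cb in callbacks:
                cb.on_batch_end()

    def fit(self):
        # Callbacks are created from callback_fns here, on every call.
        callbacks = [fn(self) for fn in self.callback_fns] + self.callbacks
        self._loop(callbacks)
        return callbacks

    def get_preds(self):
        # This path never looks at callback_fns, so CountBatches is never created.
        self._loop(self.callbacks)
```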

I was able to work around this as follows, and the callback works correctly, but this looks too hacky:

```python
class TestHook(HookCallback):
    def hook(self, m, i, o):
        pass
    def on_batch_end(self, train, **kwargs):
        print("Batch Done!")

learner = cnn_learner(data, models.resnet18)
# Instantiate the callback by hand and attach it, bypassing callback_fns
test_hook_callback = TestHook(learner)
learner.callbacks += [test_hook_callback]

learner.get_preds(data.train_ds)
```
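A somewhat tidier form of this workaround, assuming only what the snippet above already relies on (that `learner.callbacks` is a plain list of callback instances that `get_preds()` consults): a hypothetical `temporary_callbacks` context manager, not part of fastai, that instantiates the given callback functions, appends them, and removes them again when the block exits.

```python
from contextlib import contextmanager


@contextmanager
def temporary_callbacks(learner, *callback_fns):
    """Hypothetical helper: attach callbacks to learner.callbacks for one block only."""
    added = [fn(learner) for fn in callback_fns]
    learner.callbacks.extend(added)
    try:
        yield added
    finally:
        # Remove exactly the instances added above (by identity), even on error,
        # so they do not leak into a later fit() or get_preds() call.
        learner.callbacks[:] = [
            cb for cb in learner.callbacks if all(cb is not a for a in added)
        ]
```

With the objects from the post, usage would be `with temporary_callbacks(learner, TestHook): learner.get_preds(data.train_ds)`. Removing by identity rather than with `list.remove` avoids surprises if callback objects define their own equality.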

Please let me know if this is worth pursuing. If yes, I can try to change the code and create a pull request.

More context is in this post on the fastai forums.
