How to report the best performing model?

I’ve been struggling to perform a very basic task for weeks now.

I want to train a model with some loss function, but pick the best epoch based on a different metric evaluated on the validation set. I also want that best metric value to be logged to Weights & Biases.

Here’s what I’ve got so far:

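from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

cbs = [TrackerCallback(monitor='valid_performance'), # TODO: is this one needed if I also have a SaveModelCallback?
    SaveModelCallback(monitor='valid_performance'),
    WandbCallback()]

learn = vision_learner(dls, resnet18, metrics=[performance], cbs=cbs, ...)

learn.recorder.train_metrics = True
learn.recorder.valid_metrics = True # This way WandB will record a valid_performance

# cbs are already registered via vision_learner; passing them to fine_tune again
# would add the same callback objects a second time during each fit
learn.fine_tune(num_epochs)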

(The docs of the WandbCallback say: “If used in combination with SaveModelCallback, the best model is saved as well (can be deactivated with log_model=False).” )
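The monitor name has to match one of the Recorder's metric_names (TrackerCallback asserts this in before_fit), which is why train_metrics matters here. As far as I can tell from fastai's Recorder.before_fit, the train_/valid_ prefixes are only added when both train_metrics and valid_metrics are True; with the defaults the metric keeps its bare name, so monitor='performance' should work without turning on train_metrics. Below is a standalone re-implementation of that naming logic for illustration only: recorder_metric_names is a hypothetical helper, not part of fastai, so check it against your installed version.

```python
from typing import List, Sequence


def recorder_metric_names(
    metrics: Sequence[str],
    train_metrics: bool = False,
    valid_metrics: bool = True,
    add_time: bool = True,
) -> List[str]:
    """Sketch of how fastai's Recorder appears to build `metric_names`.

    Hypothetical standalone helper, not part of fastai.
    `metrics` are the metric names, e.g. ['performance'].
    """
    names = list(metrics)
    if train_metrics and valid_metrics:
        # Both splits are tracked, so every name gets a split prefix
        names = ['loss'] + names
        names = [f'train_{n}' for n in names] + [f'valid_{n}' for n in names]
    elif valid_metrics:
        # Default: metrics only on the validation set, names stay bare
        names = ['train_loss', 'valid_loss'] + names
    else:
        names = ['train_loss'] + names
    if add_time:
        names.append('time')
    return ['epoch'] + names


# Default Recorder settings
print(recorder_metric_names(['performance']))
# ['epoch', 'train_loss', 'valid_loss', 'performance', 'time']

# train_metrics=True, as set in the question
print(recorder_metric_names(['performance'], train_metrics=True))
# ['epoch', 'train_loss', 'train_performance', 'valid_loss', 'valid_performance', 'time']
```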

I could spend some time writing custom callbacks, but I can’t imagine that such basic functionality is not natively supported in the library.

I would expect fastai to support this out of the box and automatically output something like a ‘best_valid_performance’, so that I can sort all of my runs by the best model each one found rather than by the last model in the run.

How can we upload only the best performing model to W&B and record a ‘best_valid_performance’?

A potentially useful PR would be to modify SaveModelCallback to save all the metrics from the best epoch so they can easily be grabbed and logged later :smile:.
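Until something like that exists, a similar result can be approximated after training from the Recorder's history. A minimal sketch, assuming (as TrackerCallback itself does when it indexes recorder.values[-1] by a position in recorder.metric_names[1:]) that each row of learn.recorder.values lines up with learn.recorder.metric_names[1:]. The helper metrics_at_best_epoch is hypothetical; it ignores min_delta, and if I read Recorder correctly its values list is cleared at the start of every fit, so after fine_tune it only covers the unfrozen phase.

```python
import operator
from typing import Callable, Dict, List, Sequence


def metrics_at_best_epoch(
    metric_names: Sequence[str],
    values: List[Sequence[float]],
    monitor: str,
    better: Callable[[float, float], bool] = operator.lt,
) -> Dict[str, float]:
    """Return all metrics from the epoch where `monitor` was best.

    Hypothetical helper. `metric_names` is meant to be recorder.metric_names[1:]
    and `values` recorder.values (one row per epoch). `better(a, b)` is True
    when `a` improves on `b`; the default treats lower as better.
    """
    idx = list(metric_names).index(monitor)
    best_epoch, best_row = None, None
    for epoch, row in enumerate(values):
        # Strict comparison: ties keep the earlier epoch
        if best_row is None or better(row[idx], best_row[idx]):
            best_epoch, best_row = epoch, row
    if best_row is None:
        return {}
    # zip stops at the shorter sequence, so a trailing name without a value (e.g. 'time') is skipped
    result = {f'best_{name}': value for name, value in zip(metric_names, best_row)}
    result['best_epoch'] = best_epoch
    return result


names = ['train_loss', 'valid_loss', 'performance', 'time']
rows = [[0.9, 0.8, 0.30], [0.7, 0.6, 0.25], [0.5, 0.65, 0.28]]
print(metrics_at_best_epoch(names, rows, 'performance'))
# {'best_train_loss': 0.7, 'best_valid_loss': 0.6, 'best_performance': 0.25, 'best_epoch': 1}
```

In a fastai script the resulting dict could then be passed to wandb.log once training has finished.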

I’ve implemented a simple wrapper callback that logs the best value of an existing tracker callback. That way you can log the best value of any TrackerCallback, not just SaveModelCallback.

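import wandb
from fastai.callback.core import Callback
from fastai.callback.tracker import TrackerCallback

class WandbTrackerCallback(Callback):
    "Log the best value seen by `tracker` to W&B at the end of training."
    def __init__(self, tracker: TrackerCallback):
        self.tracker = tracker

    def after_fit(self):
        # `best` stays None until the tracker's before_fit has run
        if self.tracker.best is not None:
            wandb.log({f'best_{self.tracker.monitor}': self.tracker.best})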

You can use it like this:

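import numpy as np
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.wandb import WandbCallback

# reset_on_fit=False keeps the best value across both fits that fine_tune runs;
# comp=np.less means lower is better (use np.greater for e.g. accuracy)
save_model_cb = SaveModelCallback(monitor='valid_performance', comp=np.less, reset_on_fit=False)
cbs = [
    save_model_cb,
    WandbCallback(log_model=True, log_preds=True),
    WandbTrackerCallback(save_model_cb),
]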
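The reset_on_fit=False setting matters because fine_tune calls fit twice (a frozen phase, then an unfrozen phase), and a TrackerCallback with the default reset_on_fit=True resets its best value at the start of each fit. The toy simulation below shows the effect; ToyTracker and fake_fine_tune are made-up names, and the sketch ignores min_delta and model saving. If the frozen epoch happened to be the best one, the default setting would report (and SaveModelCallback would reload) the best of the unfrozen phase instead.

```python
import math
from typing import List


class ToyTracker:
    """Toy stand-in for TrackerCallback's bookkeeping (lower is better)."""

    def __init__(self, reset_on_fit: bool = True):
        self.reset_on_fit = reset_on_fit
        self.best = None

    def before_fit(self) -> None:
        # Reset when asked to, or on the very first fit when nothing is tracked yet
        if self.reset_on_fit or self.best is None:
            self.best = math.inf

    def after_epoch(self, value: float) -> None:
        if value < self.best:
            self.best = value


def fake_fine_tune(tracker: ToyTracker, frozen: List[float], unfrozen: List[float]) -> float:
    """Mimic fine_tune: one fit over the frozen epochs, then one over the unfrozen epochs."""
    for epochs in (frozen, unfrozen):
        tracker.before_fit()
        for value in epochs:
            tracker.after_epoch(value)
    return tracker.best


frozen, unfrozen = [0.20], [0.35, 0.30]
print(fake_fine_tune(ToyTracker(reset_on_fit=True), frozen, unfrozen))   # 0.3
print(fake_fine_tune(ToyTracker(reset_on_fit=False), frozen, unfrozen))  # 0.2
```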

PS: The downside is that you need two callback objects where one should be enough. I wasn’t able to write a wrapper that forwards all callback events to the wrapped callback, like WandbTrackerCallback(SaveModelCallback(...)). It seems to break on the recorder attribute, which my wrapper class then doesn’t have.
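A plausible cause of the recorder error: fastai callbacks look up unknown attributes such as recorder through self.learn, and only callbacks registered with the Learner get learn set. A wrapped callback that was never registered therefore has learn=None, and its first self.recorder lookup fails. The toy below reproduces that mechanism without fastai (every class is a simplified, hypothetical stand-in, not fastai's implementation) and shows one possible fix: the wrapper hands its own learn to the inner callback before forwarding an event. A real wrapper would need to do this for every event the inner callback implements, not just before_fit.

```python
class ToyCallback:
    """Simplified stand-in for fastai's Callback: unknown attributes are looked up on `self.learn`."""
    learn = None

    def __getattr__(self, name):
        # Only reached when normal lookup fails, e.g. for `recorder`
        if self.learn is None:
            raise AttributeError(name)
        return getattr(self.learn, name)


class ToyTrackerCB(ToyCallback):
    """Like TrackerCallback, it reads `self.recorder` in before_fit."""
    def __init__(self, monitor):
        self.monitor = monitor

    def before_fit(self):
        self.idx = self.recorder.metric_names.index(self.monitor)


class ToyWrapper(ToyCallback):
    """Forwards events to an inner callback that was never registered itself."""
    def __init__(self, inner, share_learn=True):
        self.inner = inner
        self.share_learn = share_learn

    def before_fit(self):
        if self.share_learn:
            # Hand the inner callback our learner so its attribute lookups work
            self.inner.learn = self.learn
        self.inner.before_fit()


class ToyRecorder:
    metric_names = ['epoch', 'train_loss', 'valid_loss', 'performance']


class ToyLearner:
    def __init__(self):
        self.recorder = ToyRecorder()


learn = ToyLearner()

ok = ToyWrapper(ToyTrackerCB('performance'))
ok.learn = learn  # registration sets `learn` only on the outer callback
ok.before_fit()
print(ok.inner.idx)  # 3

broken = ToyWrapper(ToyTrackerCB('performance'), share_learn=False)
broken.learn = learn
try:
    broken.before_fit()
except AttributeError as err:
    print('AttributeError:', err)  # AttributeError: recorder
```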