Is there any way to call learn.get_preds() without triggering any of the callbacks?

I know I can do something like this …

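import torch

learn.model.eval()  # no dropout; batchnorm uses its running statistics
with torch.no_grad():
    # assumes each batch from dl is an (inputs, targets) pair; keep every batch's output
    preds = torch.cat([learn.model(xb) for xb, _ in dl])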

… but I would like to use learn.get_preds() if there is a way to do so.

Why do I want to do this?

Because I need to get the predictions on the validation set inside a Callback.after_validate method in order to compute a key value that several metrics use. Right now, calling learn.get_preds() there results in a recursive loop.

If you want to call get_preds inside a Callback, you need to use the FetchPreds callback, which is designed for that (and lets you disable some callbacks). I’m not too sure what your use case is, though.
Note that you have a learn.removed_cbs() context manager that can be useful too.
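To make the recursion and the fix concrete, here is a toy model in plain Python so it runs without fastai. ToyLearner, NaiveCallback and ThreshCallback are hypothetical stand-ins, not fastai classes; the only idea borrowed from the thread is a removed_cbs context manager that detaches callbacks for the duration of a with block and yields the learner. Calling get_preds from inside after_validate fires after_validate again, so the naive callback recurses until Python raises RecursionError, while the callback that removes itself first does not.

```python
from contextlib import contextmanager


class ToyLearner:
    """Hypothetical stand-in for a Learner: holds callbacks and fires events."""

    def __init__(self, cbs):
        self.cbs = list(cbs)
        for cb in self.cbs:
            cb.learn = self

    def fire(self, event):
        # Iterate over a snapshot so callbacks can detach/re-attach callbacks safely.
        for cb in list(self.cbs):
            handler = getattr(cb, event, None)
            if handler is not None:
                handler()

    def get_preds(self):
        # A real get_preds runs a validation pass, which fires the same events.
        self.fire("after_validate")
        return [0.2, 0.7, 0.9]  # dummy predictions

    @contextmanager
    def removed_cbs(self, cbs):
        # Detach `cbs` for the body of the with block and re-attach them
        # afterwards, even if the body raises.
        for cb in cbs:
            self.cbs.remove(cb)
        try:
            yield self
        finally:
            self.cbs.extend(cbs)


class NaiveCallback:
    def after_validate(self):
        # get_preds fires after_validate on this same callback again: endless recursion.
        self.preds = self.learn.get_preds()


class ThreshCallback:
    def after_validate(self):
        # Detach this callback first, so the inner get_preds cannot re-enter it.
        with self.learn.removed_cbs([self]) as learn:
            self.preds = learn.get_preds()
```

With ThreshCallback attached, firing after_validate once stores the dummy predictions and leaves the callback attached afterwards; doing the same with NaiveCallback ends in RecursionError.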

The problem is that FetchPreds runs in after_validate, and I need the value earlier than that.

I’m trying to get the optimal threshold for a multi-label model, a value to be used by several metrics associated with my learner. I need to do this before validation runs, in the begin_validate event. This is what I tried in my Callback:

def begin_validate(self, **kwargs):
    with self.learn.removed_cbs(self) as learn:
        self.preds_targs = learn.get_preds(inner=True)
    probs, targs = self.preds_targs

… and here’s the fun stack trace:

I should mention that this Callback is set to run before the Recorder so I can update the threshold in the metrics that have a thresh attribute.
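A sketch of the threshold search itself, written in plain Python so it runs without fastai. It assumes probs and targs are nested lists of per-label probabilities and 0/1 targets (a tensor from get_preds could be converted with .tolist()), scores each candidate threshold with micro-averaged F-beta (true/false positives counted over every sample and label), and keeps the best one. The function names, the 0.01 grid and beta=2 are illustrative assumptions, not what the OptimizeFBetaThreshCallback in this thread actually does.

```python
def fbeta_at(probs, targs, thresh, beta=2.0):
    """Micro-averaged F-beta when every probability >= thresh counts as a positive."""
    tp = fp = fn = 0
    for p_row, t_row in zip(probs, targs):
        for p, t in zip(p_row, t_row):
            predicted = p >= thresh
            if predicted and t:
                tp += 1
            elif predicted:
                fp += 1
            elif t:
                fn += 1
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def best_threshold(probs, targs, beta=2.0, steps=100):
    """Grid-search thresholds 1/steps ... (steps-1)/steps; ties keep the lowest."""
    candidates = [i / steps for i in range(1, steps)]
    # max() returns the first maximal candidate, i.e. the lowest tied threshold
    return max(candidates, key=lambda th: fbeta_at(probs, targs, th, beta))
```

For example, with probs = [[0.9, 0.2, 0.6], [0.3, 0.8, 0.4]] and targs = [[1, 0, 1], [0, 1, 0]], best_threshold returns 0.41, the lowest grid value that separates the three positives from the three negatives. In a callback, the result would then be copied into each metric that exposes a thresh attribute.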

Also, there’s a weird behavior (at least to me). If I change the code to:

def begin_validate(self, **kwargs):
    with self.learn.removed_cbs( as learn:
        self.preds_targs = learn.get_preds(inner=True)
    probs, targs = self.preds_targs

so that it removes all of the learner’s callbacks, I get this back before the with block executes:
# returns [TrainEvalCallback,Recorder,ProgressCallback,SaveModelCallback,FetchPreds,OptimizeFBetaThreshCallback,ModelReseter,RNNRegularizer,ParamScheduler]

and then, inside the with block, I expect learn to have no callbacks … BUT it does:
# returns [Recorder,SaveModelCallback,OptimizeFBetaThreshCallback,RNNRegularizer]

I’m obviously missing something.
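One plausible explanation, assuming the call above passed the learner’s own callback list (self.learn.cbs) to removed_cbs: the survivors are exactly every other callback of the original list, which is the classic result of removing items from a Python list while iterating over that same list. Each removal shifts the remaining items one slot to the left, so the loop skips the item right after each one it removes. The plain-Python sketch below reproduces the exact list; remove_all is a hypothetical helper, not fastai code.

```python
def remove_all(cbs, to_remove):
    """Remove every item of `to_remove` from the list `cbs`, in place."""
    for cb in to_remove:
        cbs.remove(cb)
    return cbs


names = ["TrainEvalCallback", "Recorder", "ProgressCallback", "SaveModelCallback",
         "FetchPreds", "OptimizeFBetaThreshCallback", "ModelReseter",
         "RNNRegularizer", "ParamScheduler"]

same = list(names)
remove_all(same, same)        # iterating over the very list being modified
print(same)   # ['Recorder', 'SaveModelCallback', 'OptimizeFBetaThreshCallback', 'RNNRegularizer']

copied = list(names)
remove_all(copied, list(copied))   # iterate over a snapshot instead
print(copied)  # []
```

In fastai terms the corresponding fix would presumably be to pass a copy, e.g. list(self.learn.cbs), although removing every callback brings problems of its own.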


Well, this was a fun one.

The problem was that I was using SaveModelCallback, which derives from TrackerCallback … and TrackerCallback accesses self.recorder in both of the lifecycle events it implements.

So, long story short: if you’re going to remove the Recorder callback, you probably want to remove all TrackerCallback instances as well.
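A small sketch of that advice: choose the callbacks to remove with isinstance, which also matches subclasses, so a SaveModelCallback is caught by the TrackerCallback check. The classes here are empty stand-ins so the snippet runs on its own; with fastai you would import the real Recorder (from fastai.learner) and TrackerCallback (from fastai.callback.tracker) instead. cbs_to_remove is a hypothetical helper name.

```python
class Callback: ...
class Recorder(Callback): ...
class ProgressCallback(Callback): ...
class TrackerCallback(Callback): ...          # stand-in; the real one reads self.recorder
class SaveModelCallback(TrackerCallback): ...


def cbs_to_remove(cbs, types=(Recorder, TrackerCallback)):
    """Return a new list of the callbacks that are instances of any of `types`."""
    return [cb for cb in cbs if isinstance(cb, types)]


cbs = [ProgressCallback(), Recorder(), SaveModelCallback()]
print([type(cb).__name__ for cb in cbs_to_remove(cbs)])  # ['Recorder', 'SaveModelCallback']
```

With fastai this would become something like with learn.removed_cbs(cbs_to_remove(learn.cbs)) as learn: …; because the helper builds a new list, the removal never mutates the list it is iterating over.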


Thanks for sharing. I’ll probably need to do that in FetchPreds for wandb as well.

If you don’t mind, I’ll also submit some PRs today or tomorrow with a new Metric subclass, an improved Callback class, and my optimized-threshold callback. I think they are framework-worthy.

hi @wgpubs
Could you post your working code to solve it? I’m also struggling with a custom callback to get predictions during training.

Instead of adding your callback to the Learner, if it is only used for training, just include it in your call(s) to fit or fit_one_cycle. Because the callback is then no longer attached to your Learner, it won’t interfere with your calls to get_preds().
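Why this works: in fastai 2.x, Learner.fit wraps its cbs argument in the Learner.added_cbs context manager, so those callbacks are attached for the duration of training and detached again when fit returns (fit_one_cycle passes its cbs through to fit). The toy below mirrors that scoping in plain Python; ToyFitLearner, Recorder and MyCallback are hypothetical stand-ins, and fit just records which callbacks are attached instead of training anything.

```python
from contextlib import contextmanager


class ToyFitLearner:
    """Hypothetical stand-in that scopes fit-time callbacks the way fit does."""

    def __init__(self, cbs=()):
        self.cbs = list(cbs)          # permanent callbacks (passed to the Learner)

    @contextmanager
    def added_cbs(self, cbs):
        self.cbs.extend(cbs)
        try:
            yield self
        finally:
            for cb in cbs:            # detach only what this block added
                self.cbs.remove(cb)

    def fit(self, n_epoch, cbs=()):
        with self.added_cbs(cbs):
            # a real fit would train here; we just record what is attached
            self.attached_during_fit = [type(cb).__name__ for cb in self.cbs]


class Recorder: ...
class MyCallback: ...


learn = ToyFitLearner(cbs=[Recorder()])
learn.fit(1, cbs=[MyCallback()])
print(learn.attached_during_fit)                # ['Recorder', 'MyCallback']
print([type(cb).__name__ for cb in learn.cbs])  # ['Recorder']
```

In fastai terms, anything that runs after training (get_preds, predict, export) then only sees the callbacks that were passed to the Learner itself.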

For examples, you can check out one of the custom callbacks in my blurr library (see here). I’m not sure what you’re doing specifically, but I think this example callback gives a taste of what you can do without getting in the way of the fastai bits.

Lmk how it goes.


Now it’s working :slight_smile: But I didn’t use get_preds(). To make it work, I had to go with:
with torch.no_grad(): preds = learn.model(xb)

The training was ok but now I have another problem ahahah. I can’t export the model. It keeps giving me:

TypeError: cannot serialize '_io.TextIOWrapper' object

Maybe it’s because of CSVLogger?
But not only that. When I try to make a prediction (after training), my custom callback kicks in! I’ve never seen this happen.

Are you passing your callback to the Learner or to fit? If you don’t want it on there permanently, pass the callback only while fitting.

Oh let me try that!
Should I pass all callbacks to fit?

Edit: it seems to be working.
No errors during export and my custom callback is silent :slight_smile: I can also load the model back for inference. Thank you @muellerzr!
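A likely reason for the export error: learn.export pickles the Learner together with the callbacks attached to it, and a callback that keeps a file object as an attribute (CSVLogger keeps the log file it writes to) cannot be pickled, because Python file objects refuse to be serialized. That would also explain why passing the callbacks to fit fixed it. The plain-Python sketch below reproduces the same kind of TypeError with a hypothetical ToyLogger and shows one way a class can leave the handle out of its pickled state (__getstate__); it is not how fastai’s CSVLogger is implemented.

```python
import os
import pickle


class ToyLogger:
    """Hypothetical callback-like object that keeps its log file open."""

    def __init__(self, path=os.devnull):
        self.path = path
        self.file = open(path, "w")


class PicklableToyLogger(ToyLogger):
    """Same object, but its pickled state leaves out the file handle."""

    def __getstate__(self):
        state = self.__dict__.copy()
        state["file"] = None          # file objects cannot be pickled
        return state


def can_pickle(obj):
    try:
        pickle.dumps(obj)
        return True
    except TypeError:                 # raised for the open file object (wording varies by Python version)
        return False
```

Here can_pickle(ToyLogger()) is False while can_pickle(PicklableToyLogger()) is True; remember to close the files when you are done with them.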

What is the “xb” in your code? Would you mind posting a link to the code that finally worked?

xb = x batch. Inside callbacks you have both self.x and self.xb. You really want to modify self.xb; self.x is there for convenience.


I am stuck with a custom callback that involves get_preds.

This is what happens (the table with loss/accuracy, etc. doesn’t show):

It should look like this (here I removed f1_cb from cbs=[])

This is the callback
get_best_micro_f1() is a function that includes learn.get_preds()

If I uncomment the get_preds() part in get_best_micro_f1() everything works fine

Also, the CSVLogger output looks a bit strange: it somehow shows up twice and doesn’t match the columns.



Can someone give a simple code example of how FetchPredsCallback is used to access the predictions during training? This is the only thread in the forums where it’s mentioned.


I figured it out!!!

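It worked perfectly once I moved my code from after_train to after_epoch (after_train fires before the validation pass, and FetchPredsCallback only fills self.preds in after_validate):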

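from fastai.callback.core import FetchPredsCallback

class ExampleFetchCallback(FetchPredsCallback):
    def __init__(self):
        # call the parent so it stores all of its settings, not just ds_idx and dl
        super().__init__(ds_idx=1, dl=None)

    def after_epoch(self):
        # self.preds was filled by FetchPredsCallback.after_validate,
        # which runs before after_epoch
        preds, _ = self.preds
        ...  # do something with preds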

cb = ExampleFetchCallback()

Then pass it to the learner with cbs=[cb].