Getting preds within a Callback: not get_preds()?

I’m trying to write a callback to visualize the predictions on the validation set during training (in a way not provided by existing options such as wandb). From this thread and the Learner docs I gather that get_preds is the way to do this, and earlier threads have a lot of “that doesn’t apply to v2 anymore” replies, so…

Right now my little callback only does this much…

class VizPreds(Callback):
    def before_train(self, **kwargs): print(f"{len(self.dls.valid.items)} items in validation dataset")
    def after_epoch(self, **kwargs):
        print("Epoch ended, starting prediction")
        preds,targs = self.learn.get_preds()
        print(f"{len(preds)} items predicted")

…but it never completes. Even on a run with learn.fine_tune(1), I just see my text “Epoch ended, starting prediction” printed over and over and over. It seems to be caught in an endless loop.

How do we look at our predictions during training? I don’t see this in the Predictions callbacks (there’s only one routine listed) or in the Data callbacks. I wasn’t sure whether this counted as a “progress” thing, so I checked the Progress callbacks too…

Stuck. Any help? Thanks.
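A likely explanation for the endless loop, based on my reading of the fastai v2 source (worth double-checking against your installed version): when get_preds() is called without inner=True, it fires the epoch-level events itself, including after_epoch. An after_epoch that calls get_preds() therefore re-enters itself. The toy model below is plain Python, not fastai; ToyLearner, UnguardedViz and GuardedViz are hypothetical names. It reproduces that re-entry and shows a simple guard flag that breaks it. In real fastai code, the remove_on_fetch=True plus FetchPredsCallback combination used in the next attempt serves the same purpose, since FetchPredsCallback temporarily removes callbacks carrying that flag while it runs its own get_preds(). Passing inner=True may also avoid the extra events.

```python
# Toy model of the re-entry problem; plain Python, NOT fastai's Learner.

class ToyLearner:
    def __init__(self, cbs):
        self.cbs = list(cbs)
        for cb in self.cbs:
            cb.learn = self

    def _fire(self, event):
        # Call the method named `event` on every callback that defines it.
        for cb in self.cbs:
            handler = getattr(cb, event, None)
            if handler is not None:
                handler()

    def get_preds(self):
        # Assumption being modelled: a non-inner prediction pass fires the
        # epoch-level events itself, including after_epoch.
        self._fire("before_epoch")
        preds = [0.1, 0.9]  # placeholder "predictions"
        self._fire("after_epoch")
        return preds

    def fit_one_epoch(self):
        self._fire("before_epoch")
        self._fire("after_epoch")


class UnguardedViz:
    def after_epoch(self):
        # get_preds() fires after_epoch again, which calls get_preds() again...
        self.preds = self.learn.get_preds()


class GuardedViz:
    def __init__(self):
        self.busy = False   # True while our own get_preds() call is running
        self.n_calls = 0

    def after_epoch(self):
        if self.busy:       # re-entered from inside our own get_preds(): ignore
            return
        self.busy = True
        try:
            self.preds = self.learn.get_preds()
            self.n_calls += 1
        finally:
            self.busy = False


def hits_recursion_limit(learner):
    """Return True if one toy epoch blows Python's recursion limit."""
    try:
        learner.fit_one_epoch()
    except RecursionError:
        return True
    return False


print(hits_recursion_limit(ToyLearner([UnguardedViz()])))  # True
guarded = GuardedViz()
print(hits_recursion_limit(ToyLearner([guarded])))         # False
print(guarded.n_calls)                                     # 1
```

The guard only protects against the callback calling back into itself; it does not stop other callbacks from reacting to the extra events, which is one reason the fastai-side remove_on_fetch mechanism is the tidier option.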

Stealing some more code from the WandB callback, I ended up with something that doesn’t crash and finishes executing, but it only generates 4 preds instead of the requested 53. Any ideas how to improve this?

import random
from fastai.vision.all import *

class VizPreds(Callback):
    # remove_on_fetch=True: FetchPredsCallback takes this callback out while it runs its own get_preds()
    remove_on_fetch,order = True,Recorder.order+1
    def __init__(self, n_preds=10000, seed=12345, reorder=True): store_attr()
    def before_fit(self, **kwargs):
        # Never ask for more items than the validation set holds
        self.n_preds = min(self.n_preds, len(self.dls.valid_ds))
        print(f"\n{self.n_preds} preds to show")
        myRandom = random.Random(self.seed)  # For repeatability
        idxs = myRandom.sample(range(len(self.dls.valid_ds)), self.n_preds)
        # .iloc if the items are a DataFrame, plain indexing if they are a list
        test_items = [getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[i] for i in idxs]
        self.valid_dl = self.dls.test_dl(test_items, with_labels=True)
        # Run predictions on the sampled subset during training
        self.learn.add_cb(FetchPredsCallback(dl=self.valid_dl, with_input=True, with_decoded=True, reorder=self.reorder))

    def after_epoch(self, **kwargs):
        print(f"Epoch ended, starting prediction, predicting {self.n_preds}")
        #preds,targs = self.learn.get_preds()
        preds = self.learn.fetch_preds.preds  # whatever FetchPredsCallback stored from its get_preds() call
        print(f"{len(preds)} items predicted")
learn = cnn_learner(dls, resnet18, metrics=error_rate, cbs=VizPreds)
learn.fine_tune(1)
54 preds to show
epoch	train_loss	valid_loss	error_rate	time
0	2.114636	1.949375	0.481481	00:39
Epoch ended, starting prediction, predicting 54
4 items predicted

54 preds to show
epoch	train_loss	valid_loss	error_rate	time
0	0.391255	0.107325	0.037037	00:50
Epoch ended, starting prediction, predicting 54
4 items predicted
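A likely reason for the "4 items predicted" line: with with_input=True and with_decoded=True, FetchPredsCallback stores the whole get_preds() result, which (as far as I can tell) is a tuple of (inputs, predictions, targets, decoded). len() then counts the tuple's elements, not the items. The sketch below uses NumPy stand-ins (random fake arrays, not real model output) sized to the 54 sampled items to show the difference; inside the callback one would unpack instead, e.g. inp, preds, targs, decoded = self.learn.fetch_preds.preds.

```python
import numpy as np

# Fake stand-ins for what get_preds(with_input=True, with_decoded=True) returns:
# (inputs, predictions, targets, decoded), each covering all 54 sampled items.
n_items, n_classes = 54, 3
rng = np.random.default_rng(0)
inputs = rng.normal(size=(n_items, 3, 8, 8))       # fake tiny images
preds = rng.random(size=(n_items, n_classes))      # fake per-class scores
targs = rng.integers(0, n_classes, size=n_items)   # fake integer labels
decoded = preds.argmax(axis=1)                     # fake decoded class indices
res = (inputs, preds, targs, decoded)

print(len(res))      # 4: number of arrays in the tuple, not number of items
inp, preds_, targs_, dec = res
print(len(preds_))   # 54: one row of class scores per sampled validation item
```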

Answer, thanks to @zachmueller: preds = self.learn.pred (the model output for the current batch).

Yup! Although this seems to give values as logits rather than probabilities, or at least the triplets I’m getting (for 3 classes) don’t add up to 1. Will look into that.

Yea, those are logits, so they need to be run through F.softmax(preds, dim=1).
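To see what F.softmax(preds, dim=1) does to those triplets, here is a NumPy analogue (an illustration, not the PyTorch implementation) applied to made-up logits for 2 items and 3 classes. Each row is exponentiated and divided by its row total, so every row sums to 1 while the ranking of classes, and hence the predicted class, is unchanged.

```python
import numpy as np

def softmax(logits, axis=1):
    """Row-wise softmax, the NumPy analogue of F.softmax(preds, dim=1)."""
    # Subtracting the row max leaves the result unchanged but avoids overflow in exp().
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical logits for a batch of 2 items and 3 classes (made-up numbers).
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
probs = softmax(logits)
print(probs.sum(axis=1))   # each row now sums to 1, up to floating-point rounding
```

dim=1 (axis=1 here) matters: the class scores live along the second axis of a (batch, n_classes) array, so normalizing along dim=0 would instead make each class column sum to 1 across the batch.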

After a bit more work, here’s a basic schema I think I can keep working with from here on out… :slight_smile:

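from fastai.vision.all import *

class VizPreds(Callback):
    "Visualize predictions"
    order = ProgressCallback.order+1
    # MyPlotter is a user-defined plotting class (not part of fastai, not shown here)
    def before_fit(self, **kwargs): self.plot = MyPlotter(labels=self.dls.vocab)
    def after_batch(self, **kwargs):
        if not self.learn.training:  # only plot during the validation pass
            with torch.no_grad():
                # self.learn.pred holds the logits for the current batch; softmax turns them into probabilities
                preds, targs = F.softmax(self.learn.pred, dim=1), self.learn.y
                preds, targs = [x.detach().cpu().numpy().copy() for x in [preds,targs]]
                self.plot.do_plot(preds, targs)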
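Because after_batch fires once per validation batch, do_plot in the callback above sees one batch at a time. If the goal is a single plot over the whole validation set, one option is to clear a buffer in before_validate, append each batch's arrays in after_batch, and plot once in after_validate (all three are fastai callback events). The buffer logic itself is plain NumPy; PredAccumulator is a hypothetical helper, not a fastai class, and the batch sizes below are made up.

```python
import numpy as np

class PredAccumulator:
    """Hypothetical helper: collects per-batch (preds, targs) arrays during validation."""
    def __init__(self):
        self.reset()

    def reset(self):
        # Call at the start of each validation pass (e.g. from before_validate).
        self._preds, self._targs = [], []

    def add(self, preds, targs):
        # Copy so later changes to the batch buffers cannot alter stored values.
        self._preds.append(np.array(preds, copy=True))
        self._targs.append(np.array(targs, copy=True))

    def result(self):
        # Stack the batches into one (n_items, n_classes) array and one (n_items,) array.
        return np.concatenate(self._preds), np.concatenate(self._targs)


acc = PredAccumulator()
rng = np.random.default_rng(0)
for bs in (16, 16, 6):                   # e.g. the last validation batch is smaller
    acc.add(rng.random((bs, 3)), rng.integers(0, 3, size=bs))
all_preds, all_targs = acc.result()
print(all_preds.shape, all_targs.shape)  # (38, 3) (38,)
```

In the callback, add() would take the NumPy arrays already produced in after_batch, and after_validate would pass acc.result() to the plotter, so the figure is drawn once per validation pass instead of once per batch.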