I’m trying to write a callback to visualize the predictions of the validation set during training (in a way not provided by existing options such as wandb). I find from this thread and the Learner docs that get_preds is the way to do this, and earlier threads seem to have a lot of instances of “that doesn’t apply to v2 anymore”, so…
Right now my little callback only does this much…
class VizPreds(Callback):
    def before_train(self, **kwargs): print(f"{len(self.dls.valid.items)} items in validation dataset")
    def after_epoch(self, **kwargs):
        print("Epoch ended, starting prediction")
        preds, targs = self.learn.get_preds()
        print(f"{len(preds)} items predicted")
…but it never completes. Even on a run with learn.fine_tune(1), I just see my text “Epoch ended, starting prediction” print over and over and over. It seems to be caught in an endless loop.
How do we check out our predictions during training? I don’t see this in the Predictions callbacks (there’s only one routine listed), or in the Data callbacks… I wasn’t sure if this was a “progress” thing, so I checked the Progress callbacks too…
Stuck. Any help? Thanks.
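Edit: to at least test whether the loop is get_preds re-entering my own after_epoch, I’m thinking of guarding the call with a flag along these lines (the _fetching name is just mine, nothing fastai-specific), though I haven’t confirmed that this is really what’s happening:

class VizPreds(Callback):
    def after_epoch(self, **kwargs):
        # guard against re-entry, in case get_preds ends up firing this event again
        if getattr(self, '_fetching', False): return
        self._fetching = True
        try:
            preds, targs = self.learn.get_preds()
            print(f"{len(preds)} items predicted")
        finally:
            self._fetching = False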
Stealing some more code from the WandB callback, I ended up with something that doesn’t crash and finishes executing, but it only generates 4 preds instead of the requested 54. Any ideas how to improve this?
import random  # needed for random.Random below

class VizPreds(Callback):
    remove_on_fetch, order = True, Recorder.order+1
    # Record if watch has been called previously (even in another instance)
    _viz_preds_watch_called = False
    def __init__(self, n_preds=10000, seed=12345, reorder=True): store_attr()
    def before_fit(self, **kwargs):
        self.n_preds = min(self.n_preds, len(self.dls.valid_ds))
        print(f"\n{self.n_preds} preds to show")
        myRandom = random.Random(self.seed)  # For repeatability
        idxs = myRandom.sample(range(len(self.dls.valid_ds)), self.n_preds)
        test_items = [getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[i] for i in idxs]
        self.valid_dl = self.dls.test_dl(test_items, with_labels=True)
        self.learn.add_cb(FetchPredsCallback(dl=self.valid_dl, with_input=True, with_decoded=True, reorder=self.reorder))
    def after_epoch(self, **kwargs):
        print(f"Epoch ended, starting prediction, predicting {self.n_preds}")
        #preds,targs = self.learn.get_preds()
        preds = self.learn.fetch_preds.preds
        print(f"{len(preds)} items predicted")
learn = cnn_learner(dls, resnet18, metrics=error_rate, cbs=VizPreds)
learn.fine_tune(1)
54 preds to show
epoch train_loss valid_loss error_rate time
0 2.114636 1.949375 0.481481 00:39
Epoch ended, starting prediction, predicting 54
4 items predicted
54 preds to show
epoch train_loss valid_loss error_rate time
0 0.391255 0.107325 0.037037 00:50
Epoch ended, starting prediction, predicting 54
4 items predicted
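One guess about the 4 (not yet confirmed): with with_input=True and with_decoded=True, I believe get_preds hands back a tuple of (inputs, preds, targets, decoded), so fetch_preds.preds may be that whole 4-tuple rather than a tensor of predictions, in which case it would need unpacking first:

# assuming fetch_preds.preds really is the (inputs, preds, targets, decoded) tuple
inputs, preds, targs, decoded = self.learn.fetch_preds.preds
print(f"{len(preds)} items predicted")  # should then count predictions rather than tuple elements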
Answer thanks to @zachmueller: preds = self.learn.pred
Yup! Although this seems to give values as logits rather than probabilities. Or at least the triplets I’m getting (for 3 classes) don’t add up to 1. Will look into that.
Yea, those are logits, so they need to be run through F.softmax(preds, dim=1).
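For example, with made-up logits for one item across 3 classes:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw model output: one item, 3 classes
probs  = F.softmax(logits, dim=1)          # ≈ [[0.786, 0.175, 0.039]]
print(probs.sum(dim=1))                    # sums to 1, so these are probabilities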
After a bit more work, here’s a basic schema I think I can keep working with from here on out… 
class VizPreds(Callback):
    "Visualize predictions"
    order = ProgressCallback.order+1
    def before_fit(self, **kwargs): self.plot = MyPlotter(labels=self.dls.vocab)
    def after_batch(self, **kwargs):
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = F.softmax(self.learn.pred, dim=1), self.learn.y
                preds, targs = [x.detach().cpu().numpy().copy() for x in [preds, targs]]
                self.plot.do_plot(preds, targs)
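(MyPlotter here is just my own helper, not anything from fastai; a minimal stand-in that makes the schema above runnable could be as simple as the following, with do_plot swapped for whatever visualization you actually want.)

class MyPlotter:
    "Placeholder plotting helper used by VizPreds above"
    def __init__(self, labels): self.labels = list(labels)
    def do_plot(self, preds, targs):
        # preds: (bs, n_classes) softmax probabilities as numpy arrays; targs: (bs,) class indices
        for p, t in zip(preds, targs):
            print(f"true={self.labels[int(t)]}  pred={self.labels[int(p.argmax())]}  probs={p.round(3)}")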
Thank you for sharing! It looks really useful and professional.
Why did you change after_epoch to after_batch?
Why did you use softmax()?