Feature suggestion: Add `with_decoded` option to `Learner.tta()` function

Hi! I was playing around with test time augmentation after watching lesson 06, and I realized the API for `learner.tta(...)` is a little different from what I expected after working with `learner.get_preds(...)`, which is fantastic.

Specifically, I found myself missing the `with_decoded` kwarg.

Is there a specific reason for this, or is it something that fastai devs would be open to adding?

The idea is to spare users of the method from having to do these lower-level label calculations:

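```python
import numpy as np
import pandas as pd

preds, _ = learn.tta(dl=tst_dl)
idxs = preds.argmax(dim=1)           # index of the most likely class for each item
vocab = np.array(learn.dls.vocab)    # class labels, indexable by position
results = pd.Series(vocab[idxs], name="idxs")
# etc...
```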

Something like this, maybe:

```python
# Proposed API: `tta` does not currently accept `with_decoded`
preds, _, decoded = learn.tta(dl=tst_dl, with_decoded=True)
results = decoded
# etc...
```
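Until something like this exists, a thin wrapper could approximate it. The sketch below uses a hypothetical helper, `tta_with_decoded`, which is not part of fastai. It assumes that `learn.tta(...)` returns `(preds, targs)` (true with the default `beta`; with `beta=None` the predictions come back as a tuple and this sketch would not handle them). It also assumes that decoding works the way `get_preds(with_decoded=True)` appears to do it: call `learn.loss_func.decodes` when the loss defines it, and otherwise return the predictions unchanged.

```python
def tta_with_decoded(learn, **tta_kwargs):
    """Hypothetical helper (not a fastai API): run TTA and also return decoded preds.

    Assumes `learn.tta(**tta_kwargs)` returns `(preds, targs)`, i.e. the default
    `beta`, and decodes via `learn.loss_func.decodes` when that method exists.
    """
    preds, targs = learn.tta(**tta_kwargs)
    # Fall back to the identity when the loss function has no `decodes` method
    decodes = getattr(learn.loss_func, "decodes", lambda x: x)
    return preds, targs, decodes(preds)
```

Usage would look like `preds, _, decoded = tta_with_decoded(learn, dl=tst_dl)`. With a classification loss such as `CrossEntropyLossFlat`, `decodes` is an argmax. In that case `decoded` holds class indices rather than label strings, so the `vocab` lookup above would still be needed to get names.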

If so, I’d be tempted to try and add the feature myself, as my first contribution to the codebase. Let me know what you think.

I apologize in advance if this request is nonsensical; I’m still very new to Deep Learning and fastai :smiley:
Manu
