How to use Test Time Augmentation (TTA) with `DataBlock`?

Chapter 7 of the book explains the concept of TTA, and provides the following code:

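```python
# `learn` is an already-trained Learner; tta() returns (predictions, targets)
preds, targs = learn.tta()
accuracy(preds, targs).item()  # accuracy() returns a tensor; .item() gives a Python float
```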
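For reference, here is a framework-free sketch of what the second line computes. It assumes fastai's `accuracy` takes the highest-scoring class of each prediction row and compares it with the target. In fastai it returns a one-element tensor, which is why the book calls `.item()`. Plain Python lists stand in for tensors here, and the sample numbers are made up.

```python
from typing import Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in a row (first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(preds: Sequence[Sequence[float]], targs: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the target."""
    if len(preds) != len(targs):
        raise ValueError("preds and targs must have the same length")
    correct = sum(argmax(row) == t for row, t in zip(preds, targs))
    return correct / len(targs)


# One row of class scores per item, one true class index per item.
preds = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
targs = [1, 0, 0]
print(accuracy(preds, targs))  # two of the three rows are correct
```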

While I understand the intuition behind TTA, it’s not explained how to use TTA to train a better model.

How do I use TTA with DataBlock?

The documentation doesn’t provide an example. An example would really help.

Thank you.

You don’t use TTA to train a better model. You use TTA to get additional performance improvements from an already-trained model.

Do you know how to use learn.get_preds()? The learn.tta() function is used in the same way…
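A framework-free sketch of the point above, under stated assumptions. The "model" is a fixed function, so nothing is trained. `jitter` is a hypothetical stand-in for the augmentations your `DataBlock` defines. `get_preds` and `tta` here are toy functions that only mimic the `(preds, targs)` return shape of the fastai methods. The defaults `n=4` and `beta=0.25`, and the blend `(1 - beta) * aug_mean + beta * plain`, are my reading of fastai's `Learner.tta`, so check the source for your version. As far as I can tell, `learn.tta()` reuses the training-time transforms from the `DataBlock` (`item_tfms`/`batch_tfms`, e.g. `aug_transforms()`), so the `DataBlock` itself needs nothing extra.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Model = Callable[[Vector], Vector]


def toy_model(x: Vector) -> Vector:
    """Hypothetical already-trained classifier returning [p(class 0), p(class 1)].

    It is a plain function, so nothing below can update its 'weights'.
    """
    mean = sum(x) / len(x)
    p1 = min(max(0.5 + mean, 0.0), 1.0)  # clamp to a valid probability
    return [1.0 - p1, p1]


def get_preds(model: Model, xs: Sequence[Vector],
              ys: Sequence[int]) -> Tuple[List[Vector], List[int]]:
    """Plain inference: one forward pass per un-augmented input."""
    return [model(x) for x in xs], list(ys)


def tta(model: Model, xs: Sequence[Vector], ys: Sequence[int],
        augment: Callable[[Vector], Vector], n: int = 4,
        beta: float = 0.25) -> Tuple[List[Vector], List[int]]:
    """Same (preds, targs) shape as get_preds, but averaged over augmentations.

    Averages n augmented passes per input, then blends the result with the
    plain prediction: (1 - beta) * aug_mean + beta * plain.
    """
    plain, targs = get_preds(model, xs, ys)
    blended = []
    for x, p_row in zip(xs, plain):
        runs = [model(augment(x)) for _ in range(n)]
        aug_mean = [sum(r[c] for r in runs) / n for c in range(len(p_row))]
        blended.append([(1 - beta) * a + beta * p for a, p in zip(aug_mean, p_row)])
    return blended, targs


def accuracy(preds: Sequence[Vector], targs: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the target."""
    hits = sum(max(range(len(row)), key=row.__getitem__) == t
               for row, t in zip(preds, targs))
    return hits / len(targs)


rng = random.Random(0)  # seeded so the sketch is reproducible


def jitter(x: Vector) -> Vector:
    """Hypothetical augmentation: add small Gaussian noise to every value."""
    return [v + rng.gauss(0.0, 0.1) for v in x]


xs = [[0.2, 0.4], [-0.3, -0.1], [0.05, -0.02]]
ys = [1, 0, 1]

preds, targs = get_preds(toy_model, xs, ys)        # like learn.get_preds()
tta_preds, targs = tta(toy_model, xs, ys, jitter)  # like learn.tta()
print(accuracy(preds, targs), accuracy(tta_preds, targs))
```

With a real `Learner` the swap is the same: replace `learn.get_preds()` with `learn.tta()` and pass the result to `accuracy` exactly as before. Whether TTA actually helps depends on the model and data, so compare both numbers on your validation set.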


Oh I see! All this while, I was thinking it’s used in the validation step in the training loop.

Thanks for clarifying, Tanishq. 🙂
