learn.TTA() error

Hi @jeremy, I get this error while running learn.TTA():

When I printed the shape of the output (see figure above), I noticed that axis 0 has size 5: the original data plus the n_aug=4 augmented copies (the default). In other words, the predictions are never averaged, even though the comment in the code says they should be:

Additional to the original test/validation images, apply image augmentation to them
(just like for training images) and calculate the mean of predictions. The intent
is to increase the accuracy of predictions by examining the images using multiple
perspectives.

I added the averaging step myself and now it works fine.
Should I send a PR for this quick fix?
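A minimal sketch of that averaging step in plain NumPy. It assumes the TTA output is an array of log-probabilities with shape (n_aug + 1, n_images, n_classes); the random array below is a hypothetical stand-in for what learn.TTA() returns, not data from this run, and averaging probabilities rather than log-probabilities is one common choice, not necessarily what the library intends:

```python
import numpy as np

# Hypothetical stand-in for the TTA output: axis 0 holds the original images
# plus n_aug=4 augmented copies, so it has size 5.
# Assumption: values are log-probabilities of shape (n_aug + 1, n_images, n_classes).
n_aug, n_images, n_classes = 4, 3, 2
rng = np.random.default_rng(0)
logits = rng.normal(size=(n_aug + 1, n_images, n_classes))
log_preds = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # log-softmax

# Average over the 5 "views" of each image (axis 0).
# Converting to probabilities first means each averaged row still sums to 1.
probs = np.exp(log_preds).mean(axis=0)
preds = probs.argmax(axis=1)  # predicted class per image

print(log_preds.shape)  # (5, 3, 2)
print(probs.shape)      # (3, 2)
```

Taking the mean over axis 0 collapses the five predictions per image into one, which is what the comment in the code describes.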

The comment in the code is now wrong, since we no longer take the mean. If you search the forum, you’ll find the reasoning behind this.

Thank you so much, and sorry for not searching the forum first.