But it results in an array of the following shape:

(2, 2044, 120)

There’s not enough space in this post to go over everything I’ve done in detail, so I’ve attached a copy of my IPython notebook (dogbreeds.pdf) in case anyone can point out where I’ve screwed up.

Btw, I ended up with an extra third dimension in my resulting array: basically 5 different prediction arrays stacked on top of each other. I tried taking the first one in the stack as my answer and submitting it to Kaggle:

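import os
import pandas as pd
from IPython.display import FileLink

ds = pd.DataFrame(probs[0])  # take the 1st of the 5 stacked predictions
ds.columns = data.classes  # one column per breed
# drop the first 5 chars (folder prefix, e.g. 'test/') and last 4 (e.g. '.jpg') to get the id
ds.insert(0, 'id', [o[5:-4] for o in data.test_ds.fnames])
SUBM = f'{PATH}subm/'
os.makedirs(SUBM, exist_ok=True)
ds.to_csv(f'{SUBM}subm.gz', compression='gzip', index=False)
FileLink(f'{SUBM}subm.gz')  # download link shown in the notebook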

That scored about 0.25 and landed me in roughly 300th place, so I have a feeling it’s not the right answer.

That result is expected because you passed is_test=True. That tells the model to make 5 predictions for each test image: one on the original image and one on each of 4 augmented versions. So you need to average the probabilities over those 5 predictions for each test image and build your submission from that average.
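Here is a minimal sketch of that averaging step, assuming the TTA output is a stack of log-probabilities shaped (5, n_images, n_classes). The random data, class names, and ids below are made up for illustration; in the notebook they would come from learn.TTA, data.classes and data.test_ds.fnames.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for TTA output on the test set:
# 5 predictions (original + 4 augmentations) x n_images x n_classes,
# stored as log-probabilities like log_preds in this thread.
rng = np.random.default_rng(0)
n_preds, n_images, n_classes = 5, 3, 4
logits = rng.normal(size=(n_preds, n_images, n_classes))
log_preds = logits - np.log(np.exp(logits).sum(axis=2, keepdims=True))  # log-softmax

# Average the probabilities over the prediction axis (axis 0),
# leaving one (n_images, n_classes) matrix for the submission.
probs = np.mean(np.exp(log_preds), axis=0)

classes = [f"breed_{i}" for i in range(n_classes)]  # hypothetical class names
ids = [f"img{i}" for i in range(n_images)]          # hypothetical image ids

ds = pd.DataFrame(probs, columns=classes)
ds.insert(0, "id", ids)
print(ds.shape)  # (3, 5): 3 images, id column + 4 class columns
```

Each row of probs still sums to 1, since it is the mean of 5 rows that each sum to 1, so it can go straight into the submission DataFrame in place of probs[0].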

As for the (2, 2044, 120) output, I think that one is for the validation set.
You can use it to check how accurate your model is on the validation set, after reducing it to a 2-dimensional matrix the same way I described above.
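A small sketch of that accuracy check, assuming log_preds has shape (n_predictions, n_images, n_classes) and y holds integer class labels. The tiny hand-made array below (2 predictions x 3 images x 2 classes) just mirrors the layout of the (2, 2044, 120) output; tta_accuracy is a hypothetical helper, not a fastai function.

```python
import numpy as np

def tta_accuracy(log_preds, y):
    """Accuracy of TTA predictions: average probs over axis 0, then argmax."""
    probs = np.mean(np.exp(log_preds), axis=0)  # (n_images, n_classes)
    return float(np.mean(np.argmax(probs, axis=1) == y))

# Hand-made example: 2 predictions x 3 images x 2 classes.
p = np.array([
    [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]],  # prediction on the original images
    [[0.7, 0.3], [0.8, 0.2], [0.3, 0.7]],  # prediction on the augmented images
])
log_preds = np.log(p)
y = np.array([0, 0, 1])

print(tta_accuracy(log_preds, y))                      # 1.0 (averaged)
print(float(np.mean(np.argmax(p[0], axis=1) == y)))    # only the first prediction
```

Using only the first prediction gets the second image wrong (2 out of 3), while the average gets all three right, which is why averaging is preferable to picking one slice of the stack.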

Like @KyunghwanKim said, TTA applies transforms to every picture in your dataset and then gives you results for each transformed image.

This takes the mean over the 5 predictions and gives you a single final result:

log_preds, y = learn.TTA(is_test=True)  # shape: (5, n_test_images, n_classes)
probs = np.mean(np.exp(log_preds), axis=0)  # average probabilities over the 5 predictions

The True in learn.TTA(True) isn’t doing what you would expect. If you look in the learner.py source file, you’ll see TTA is defined as follows:

def TTA(self, n_aug=4, is_test=False):

In this case, n_aug (the number of augmentations) is being set to True, which is treated as 1 when it’s used as an int. So you get 2 results: the original image and 1 augmented image, which is where the 2 in (2, 2044, 120) comes from. And since is_test stays False (because you aren’t setting it), it uses the validation set.
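To see why True ends up meaning 1 here: in Python, bool is a subclass of int, so True behaves like 1 anywhere an integer is expected. The helper below is hypothetical and only mirrors the description above (one prediction on the original images plus one per augmentation); it is not fastai’s actual TTA code.

```python
# bool is a subclass of int, so True acts like 1 in integer contexts.
print(isinstance(True, int))  # True
print(True + 1)               # 2

def count_tta_results(n_aug=4):
    """Hypothetical helper: 1 prediction on the original images,
    plus 1 per augmentation, looping over range(n_aug)."""
    return 1 + len(range(n_aug))

print(count_tta_results())      # 5 (default n_aug=4)
print(count_tta_results(True))  # 2 (n_aug=True behaves like 1)
```

Passing the flag by keyword, as in learn.TTA(is_test=True), avoids the positional argument landing in n_aug.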