I’m looking at the following code, which appears in the Dogs vs. Cats notebook:
%time log_predictions,y = learn.TTA()
When I look at log_predictions, its shape is (5, 2000, 2).
I get that the 2,000 is the number of samples, and the 2 is the prediction of what this sample is (a cat or a dog), but what is the 5 dimension? I assumed it was the different test time augmentations, but we only create 4 augmentations. So I guess it’s the prediction for the original image + 4 augmentations?
Hmm. I think that is indeed what’s happening, but now that I’m looking at the source, I don’t understand what it’s doing.
There’s a default arg, n_aug=4, so I was going to say that yes, you’re getting 1+4=5 predictions. But that’s not actually how it works! For example, if you do learn.TTA(n_aug=10), you’d think you’d get a shape of (11, 2000, 2), but you get (13, 2000, 2).
That happens because of the line preds1 = [preds1]*math.ceil(n_aug/4): 10/4 rounds up to 3, so you get 3+10=13 predictions. I wonder if that’s just a bug?
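The shape arithmetic above can be checked without running a model. This is only a sketch: it assumes the two lists are concatenated and stacked into a single array, which matches the shapes reported in this thread but is not copied from the library source. The helper name tta_stack_shape and the zero-filled arrays are made up for illustration.

```python
import math
import numpy as np

def tta_stack_shape(n_aug, n_samples=2000, n_classes=2):
    """Return the shape of the stacked predictions for a given n_aug (hypothetical helper)."""
    # Stand-in for the predictions on the un-augmented images.
    preds1 = np.zeros((n_samples, n_classes))
    # Mirrors the quoted line: repeat the original predictions ceil(n_aug/4) times.
    preds1 = [preds1] * math.ceil(n_aug / 4)
    # One stand-in prediction array per augmentation pass.
    preds2 = [np.zeros((n_samples, n_classes)) for _ in range(n_aug)]
    # Assumption: both lists are joined and stacked along a new first axis.
    return np.stack(preds1 + preds2).shape

print(tta_stack_shape(4))   # (5, 2000, 2)
print(tta_stack_shape(10))  # (13, 2000, 2)
```

So the first dimension is ceil(n_aug/4) + n_aug, not 1 + n_aug.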
Why are we doing this, then?
preds2 = [predict_with_targs(self.model, dl2) for i in tqdm(range(n_aug), leave=False)]
Also, how does the recursive function end in this case, given that predict_with_targs keeps on calling itself?
I wonder if that’s needed to satisfy shapes.
Also, the code isn’t calculating the mean anywhere, but the docs say that it should…
That call to predict_with_targs isn’t actually recursive: it’s delegating to a different, top-level function in model.py (since it’s not
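To illustrate why a call like this is not recursion, here is a minimal sketch of Python's name lookup. ToyLearner and the toy model are hypothetical stand-ins, not the library's classes. The point is only that a bare function name inside a method resolves to the module-level function, even when the class has a method with the same name.

```python
def predict_with_targs(model, dl):
    # Module-level function (toy version): apply the model to every item.
    return [model(x) for x in dl]

class ToyLearner:
    # Hypothetical class, used only to demonstrate name resolution.
    def __init__(self, model, dl):
        self.model = model
        self.dl = dl

    def predict_with_targs(self):
        # The bare name below is looked up in module scope, so this calls the
        # function defined above. A recursive call would have to be written
        # as self.predict_with_targs().
        return predict_with_targs(self.model, self.dl)

learner = ToyLearner(model=lambda x: x * 2, dl=[1, 2, 3])
print(learner.predict_with_targs())  # [2, 4, 6]
```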
As for not calculating the mean, I’ve been looking through learner.py's git history, and I think the function used to calculate the mean directly, but not anymore; it’s more flexible to just return all of the predictions.
Thanks for the replies. Yes, I noticed yesterday that the function no longer returns the mean and that to get the log loss you now have to call it as follows:
# Get the mean first
probabilities = np.mean(np.exp(log_predictions), 0)
# then get the log loss
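For anyone who wants to try the averaging step on toy data, here is a rough, self-contained sketch. It assumes integer class labels in y and computes the log loss by hand with NumPy, which may differ from whatever call was used in the notebook. The helper name mean_probs_and_log_loss and the random data are made up for illustration.

```python
import numpy as np

def mean_probs_and_log_loss(log_predictions, y, eps=1e-15):
    """Average TTA probabilities, then compute a hand-rolled log loss (hypothetical helper)."""
    # Convert log-probabilities to probabilities and average over the first axis.
    probabilities = np.mean(np.exp(log_predictions), 0)
    # Clip so that log(0) can never occur.
    clipped = np.clip(probabilities, eps, 1 - eps)
    # Negative mean log-probability assigned to the true class of each sample.
    loss = -np.mean(np.log(clipped[np.arange(len(y)), y]))
    return probabilities, loss

# Toy stand-in for log_predictions: 5 prediction sets, 6 samples, 2 classes.
rng = np.random.default_rng(0)
raw = rng.random((5, 6, 2))
log_predictions = np.log(raw / raw.sum(axis=2, keepdims=True))
y = np.array([0, 1, 0, 1, 1, 0])

probabilities, loss = mean_probs_and_log_loss(log_predictions, y)
print(probabilities.shape)  # (6, 2)
```

Averaging after np.exp means the mean is taken over probabilities rather than log-probabilities, which is what the snippet in the post does.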
I’m still not sure I can untangle why the shape is getting set the way it is, though. I wish there was an easier way to debug the notebooks.
Perhaps @jeremy can explain where the 4 is coming from in this line?
preds1 = [preds1]*math.ceil(n_aug/4)
I make sure that around 20% of the images are not augmented. That’s why I concatenate a few copies of the original image to the list.
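As a quick check of the "around 20%" figure, the share of un-augmented predictions can be worked out from the formula quoted earlier in the thread. This is a sketch; unaugmented_fraction is a hypothetical helper name.

```python
import math

def unaugmented_fraction(n_aug):
    # Copies of the original-image predictions, per the quoted ceil(n_aug/4).
    copies = math.ceil(n_aug / 4)
    # Share of the stacked predictions that come from un-augmented images.
    return copies / (copies + n_aug)

for n_aug in (4, 8, 10):
    print(n_aug, round(unaugmented_fraction(n_aug), 3))
# 4 0.2
# 8 0.2
# 10 0.231
```

The share is exactly 20% when n_aug is a multiple of 4 and slightly higher otherwise, because of the rounding up.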