learn.TTA() returns a 3-d array for log_preds

(Bahar) #1

When I run log_preds, y = learn.TTA(), I get a 3-d array for log_preds, but from the lecture videos I expected a 2-d array. What am I missing?

Thank you in advance.

(Jeremy Howard (Admin)) #2

Do a forum search; we’ve covered that a number of times.


(Bahar) #3

Thank you @jeremy for the quick response.


(Andrea de Luca) #4

Sorry @jeremy, but I searched the forum with a ton of different keywords and was unable to find an answer to that question.

If I had to guess blindly, I’d say that every element of the extra axis corresponds to a single transformed version of the same image.

If I’m wrong (and almost surely I am), could you provide a link to related posts?
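A minimal NumPy sketch of where such an extra axis could come from, assuming TTA runs the model once per transformed copy of the inputs and stacks the per-pass predictions. The names fake_log_predict and fake_tta are hypothetical stand-ins, not fastai code, and the "model" just emits random log-probabilities. With 5 passes and 120 classes, the result has the same layout as the (5, 10357, 120) shape reported further down this thread, with the pass index on the leading axis.

```python
import numpy as np

N_CLASSES = 120  # dog breeds, as in the thread

def fake_log_predict(images, rng):
    """Hypothetical model stand-in: random log-probabilities, shape (n_images, N_CLASSES).
    It ignores the pixel values entirely."""
    logits = rng.normal(size=(images.shape[0], N_CLASSES))
    # Log-softmax over the class axis (subtract the max for numerical stability).
    logits = logits - logits.max(axis=1, keepdims=True)
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

def fake_tta(images, n_passes=5, seed=0):
    """Hypothetical TTA: one prediction pass per transformed copy, stacked on a new leading axis."""
    rng = np.random.default_rng(seed)
    per_pass = []
    for i in range(n_passes):
        # Assumed transform: pass 0 uses the originals, later passes a horizontal flip.
        batch = images if i == 0 else images[..., ::-1]
        per_pass.append(fake_log_predict(batch, rng))
    return np.stack(per_pass)  # (n_passes, n_images, N_CLASSES)

images = np.zeros((8, 3, 32, 32))  # 8 dummy images
log_preds = fake_tta(images)
print(log_preds.shape)  # (5, 8, 120)
```

Stacking rather than concatenating is what adds the third dimension: each pass keeps its own complete 2-d prediction array.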


(Eduardo Poleo) #5

I’ve also been looking for a while. My specific issue is with the dog breeds problem: I’m getting a log_preds.shape of (5, 10357, 120) instead of the expected (10357, 120). What I think is happening is that we’re given 5 different (possible) sets of predictions. What I’ve seen other people do is simply take the mean across all of them, e.g. log_preds, y = learn.TTA(); probs = np.mean(np.exp(log_preds), axis=0), as you can see in the Wiki: Lesson 2 thread.
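The averaging fix can be tried on dummy data. This sketch assumes the (n_passes, n_images, n_classes) layout of log-probabilities described in the post; the random values and labels are made up, and tta_average is a hypothetical helper name wrapping the same np.mean(np.exp(log_preds), axis=0) expression.

```python
import numpy as np

def tta_average(log_preds):
    """Collapse (n_passes, n_images, n_classes) log-probabilities
    into (n_images, n_classes) probabilities by averaging over passes."""
    # Exponentiate first so we average probabilities, not log-probabilities.
    return np.mean(np.exp(log_preds), axis=0)

# Dummy log-probabilities: 5 passes, 4 images, 3 classes.
rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 4, 3))
log_preds = logits - np.log(np.exp(logits).sum(axis=2, keepdims=True))
y = np.array([0, 2, 1, 1])  # dummy labels

probs = tta_average(log_preds)
preds = np.argmax(probs, axis=1)  # one predicted class per image
print(probs.shape)               # (4, 3)
print((preds == y).mean())       # accuracy on the dummy data
```

Exponentiating before the mean matters: averaging the log-probabilities directly would give an (unnormalized) geometric mean instead, whose rows no longer sum to 1.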


(Andrea de Luca) #6


It’s as I thought, then.


(Jeremy Howard (Admin)) #7

No need to apologize! :slight_smile: I just searched for ‘TTA’ and the first thing in the list was the announcement of this change: Change to how TTA() works. Does that answer your question? (Many of the other search results cover the same issue, so have a look at them too if you have a chance.)

Having said that, it looks like @eduardopoleo has given you a good solution there. (If you find a notebook where this fix hasn’t been applied, please either let me know so I can fix it, or submit a PR, so that future students don’t have to deal with this.)


(Andrea de Luca) #8

Thanks @jeremy. I think it’s clear to me now: while playing with my monkeys, I sliced a TTA() output and looked at its actual entries. Empirical observation confirms what @eduardopoleo said above :slight_smile:
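The slicing check described above can be reproduced on dummy data. This sketch assumes the (n_passes, n_images, n_classes) layout implied by the shapes in this thread and uses random log-probabilities in place of real TTA() output.

```python
import numpy as np

# Dummy stand-in for TTA() output: 5 passes, 6 images, 10 classes (random, not real predictions).
rng = np.random.default_rng(1)
logits = rng.normal(size=(5, 6, 10))
log_preds = logits - np.log(np.exp(logits).sum(axis=2, keepdims=True))

# Iterating over the array walks the leading (pass) axis.
for i, pass_log_preds in enumerate(log_preds):
    probs_i = np.exp(pass_log_preds)
    # Each slice is a complete 2-d prediction set whose rows are probability distributions.
    assert pass_log_preds.shape == (6, 10)
    assert np.allclose(probs_i.sum(axis=1), 1.0)
    print(f"pass {i}: predicted classes {probs_i.argmax(axis=1)}")

# The same image (index 0) across all passes: one distribution per pass.
first_image = np.exp(log_preds[:, 0, :])
print(first_image.shape)  # (5, 10)
```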

1 Like