How do I figure out which category each value in the probability array returned by Learner.predict corresponds to?

Learner.predict returns a tuple whose third element is an array of prediction probabilities, one for each category in the data block. However, I can't tell which element in the array corresponds to which category.

Is there a way to do this programmatically?

Figured it out. The second element of the returned tuple is the index of the predicted category, so I changed the example from:

player,_,probs = learn.predict(PILImage.create('newyorkjetsplayer.jpg'))
print(f"This is a: {player}.")
print(f"Probability: {probs[0]:.4f}")

to:

player,index,probs = learn.predict(PILImage.create('newyorkjetsplayer.jpg'))
print(f"This is a: {player}.")
print(f"Probability: {probs[index]:.4f}")
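To see every category next to its probability, not just the predicted one, you can pair the probabilities with the category names. The sketch below assumes the names come from learn.dls.vocab, which in fastai v2 is ordered the same way as the probability tensor. The helper functions probs_by_category and top_category are hypothetical names for illustration, and the vocab and probabilities in the demo are made up rather than real model output.

```python
def probs_by_category(vocab, probs):
    """Pair each category name with its probability.

    Assumes `vocab[i]` is the category for `probs[i]` (e.g. learn.dls.vocab).
    Works with plain lists or a 1-D torch tensor, since float() accepts a
    0-dim tensor element.
    """
    return {cat: float(p) for cat, p in zip(vocab, probs)}


def top_category(vocab, probs):
    """Return (category, probability) for the highest-probability class."""
    scores = probs_by_category(vocab, probs)
    best = max(scores, key=scores.get)
    return best, scores[best]


if __name__ == "__main__":
    # Hypothetical vocab and probabilities standing in for real predictions.
    vocab = ["giants", "jets", "patriots"]
    probs = [0.10, 0.85, 0.05]
    for cat, p in probs_by_category(vocab, probs).items():
        print(f"{cat}: {p:.4f}")
    print(top_category(vocab, probs))
```

With a trained learner you would call probs_by_category(learn.dls.vocab, probs) on the probs returned by learn.predict. For a single-label classifier, learn.dls.vocab[index] should match the decoded prediction (player above).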