is_bird,_,probs = learn.predict(PILImage.create('bird.jpg'))
print(f"This is a: {is_bird}.")
print(f"Probability it's a bird: {probs[0]:.4f}")
#my added code
is_forest,_,probs = learn.predict(PILImage.create('forest.jpg'))
print(f"This is a: {is_forest}.")
print(f"Probability it's a forest: {probs[0]:.4f}")
The bird is predicted successfully at almost 1.0. The forest is predicted terribly at 0.0000. How do I get this model to work for both the forest and the bird? And if I wanted to scale it to four categories, how would I do that?
You have your probs variable. It contains two rows (if it's not rows, it's columns then): the first row contains the probability that the image contains a bird, and the second row contains the probability that the image contains a forest.
In your code snippet above, you're accessing the first row in both cases.
print(f"Probability it's a bird: {probs[0]:.4f}")
Here, you've correctly accessed the first row, which tells us the probability of a bird being in the image.
print(f"Probability it's a forest: {probs[0]:.4f}")
Here, you've accessed the first row again. The 0.0000 prediction you're speaking of is the model telling us that there is a 0% chance the image is a bird!
To fix this, simply change the index to access the second row, i.e. probs[1] instead of probs[0].
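If you'd rather not hard-code the index at all, one option is to use the middle value that predict returns (the one the snippets above throw away with `_`), which is the index of the predicted category. Here is a minimal sketch of that idea. The helper name `describe_prediction` is made up for illustration, and the numbers are stand-in values rather than real model output.

```python
def describe_prediction(label, idx, probs):
    """Format a prediction, picking the probability at the predicted index."""
    # int() and float() also work on the 0-dim tensors a real learner returns.
    return f"This is a: {label}. Probability: {float(probs[int(idx)]):.4f}"

# Stand-in values shaped like what learn.predict gives back:
# (label, index of that label, one probability per category).
print(describe_prediction("bird", 0, [0.9999, 0.0001]))
print(describe_prediction("forest", 1, [0.0003, 0.9997]))
```

With a real learner you would call it as `describe_prediction(*learn.predict(PILImage.create('forest.jpg')))`, assuming `learn` is the trained learner from the notebook.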
As for scaling it to four categories, I think it's as simple as increasing the number of search terms, searches, in this specific notebook.
Where can I read the documentation for .predict? I thought it would be under Vision Learner, because of "learn = vision_learner(dls, resnet18, metrics=error_rate)", but it isn't.
Also, let's say I have 20 image categories. How would I keep track of which index is which category?
I figured out how the indices work. They follow the alphabetical order of the folder names. Not documented anywhere AFAIK.
So if I have categories = "banana", "cherry", "apple", then download the images and run the rest of the ML code:
probs[0] is apple
probs[1] is banana
probs[2] is cherry
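To keep track of this with 20 categories, a small lookup table helps. The sketch below assumes the alphabetical ordering described above. The helper `label_probs` is hypothetical and the probabilities are made-up numbers. In fastai itself, the learner's own category list (`learn.dls.vocab`) should be the safer source for the order, so treat `sorted(categories)` as a stand-in for it.

```python
categories = ["banana", "cherry", "apple"]

# Assumption: the model's index order matches the sorted folder names.
vocab = sorted(categories)
index_of = {name: i for i, name in enumerate(vocab)}

def label_probs(probs, vocab):
    """Pair each probability with its category name."""
    return dict(zip(vocab, probs))

print(index_of)
print(label_probs([0.10, 0.85, 0.05], vocab))
```

Looking up `index_of["cherry"]` then gives the position to use in probs, and `label_probs` avoids counting positions by hand.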
Code like this is how you get your sorted directories.
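import os

def get_subdirectories(directory):
    # Names of the immediate subfolders of `directory` (os.listdir order is arbitrary)
    return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]

directories = get_subdirectories('lesson1/training_images')
# non-destructive: returns a new sorted list
sorted_dirs = sorted(directories)
# destructive: sorts the existing list in place
directories.sort()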
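If you want to try this without the lesson folder, here is a self-contained sketch that builds a throwaway directory tree first. The function name `sorted_subdirectories` and the folder names are only for illustration.

```python
import os
import tempfile

def sorted_subdirectories(directory):
    """Return the names of the immediate subfolders of `directory`, sorted alphabetically."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )

with tempfile.TemporaryDirectory() as root:
    # Create the category folders in non-alphabetical order.
    for name in ["banana", "cherry", "apple"]:
        os.mkdir(os.path.join(root, name))
    # A plain file should be ignored by the directory filter.
    with open(os.path.join(root, "notes.txt"), "w") as f:
        f.write("not a category")
    found = sorted_subdirectories(root)

print(found)
```

The sort is done on the names only, so the result does not depend on the order in which the folders were created.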