How do I not predict confidently on an Image which is not from the Classes that I trained the model On?

In the pets classification from Lesson 1, when I try to predict with the trained model on a random object, say a “building”, it still confidently predicts one of the dog breeds from the classes of the pet dataset (http://www.robots.ox.ac.uk/~vgg/data/pets/).

Do I have to create a catch-all class with images other than these dog breeds, call it “other”, and retrain the model?

Any advice will be much appreciated.

I was about to ask the same thing. I built a classifier with two classes. It’s pretty accurate, but if I give it any image not related to these classes, it still tries to predict that it’s one of the two classes.

Somebody please help!

Out of curiosity, what are the tensor values per class it spits out on a non-class image? (The array from learn.predict()?)

@muellerzr

I have 5 classes. When doing learn.predict() on an example image not from any of these classes, I get:
[0.7390, 0.3234, 0.0736, 0.0071, 0.0786]

This is how classification works. The classifier learns to approximate the boundaries between the classes it was trained on. However, it doesn’t know where the boundary lies between the set of trained classes and all the other (unknown, unseen) classes that might exist in the world.

Determining whether an item belongs to a known class versus being “something else” is a separate problem in machine learning. You can Google “one-class classification” to learn more about that problem in general.

As a practical matter, I would suggest thinking about the real-world use cases of your application. Is anyone really going to upload a photo of a building to a site that detects dog breeds? Probably not, at least not with serious intent. But if there are specific types of things you think users might actually be confused about, you can factor that into your design and add an “other” category with examples of those things.

Finally, about your result from learn.predict(): this is normal. For a standard classification task, the neural network has a final softmax layer. Softmax will output a set of numbers where: a) the numbers add up to 1.0, and b) the numbers tend to be close to 0 or close to 1. Even if you input a photo of a blank piece of paper, you can be fairly sure the neural net will predict one of the dog breeds as a clear winner. Again, that’s because it’s designed to discriminate between examples of the types of things it was trained on. So when you give it something really different from the training data, you’re going to get a meaningless result.
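To see this concretely, here is a minimal sketch in plain PyTorch, with made-up logits standing in for a 5-class model’s raw outputs on an unrelated image, showing why softmax always produces a confident-looking winner:

```python
import torch
import torch.nn.functional as F

# Made-up raw outputs (logits) for a 5-class model given an image
# that belongs to none of its classes.
logits = torch.tensor([2.1, 0.3, -0.5, -1.2, 0.1])

probs = F.softmax(logits, dim=0)
print(probs)        # the largest logit dominates the others
print(probs.sum())  # always 1.0, so one class "wins" even on junk input
```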

Not much you can do about that. The model takes the input and maps it to an output based on the function it learned during training. The model was trained to deal with specific classes, and it assumes every input you give it maps to one of those classes.

One thing to note is that you’re likely getting an inflated sense of the model’s confidence. The prediction vector you showed is what you get after the model logits are passed through a softmax function. Softmax in general forces some value in the array to be close to 1.

You might get a different sense of the model’s prediction if you look at the raw logits and pass them through a sigmoid function instead.
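Something like this rough sketch shows the idea (it assumes a trained fastai v2 `learn` object, and the exact calls may differ by fastai version): bypass learn.predict(), run a batch through the model yourself, and compare the softmax and sigmoid views of the same raw logits.

```python
import torch

# Assumes `learn` is a trained fastai Learner (an assumption for this sketch).
# Grab one validation batch and run it through the model to get raw logits.
xb, yb = learn.dls.valid.one_batch()
learn.model.eval()
with torch.no_grad():
    logits = learn.model(xb)            # raw scores, one row per image

print(torch.softmax(logits, dim=1))     # forced to sum to 1 per image
print(torch.sigmoid(logits))            # each class scored independently in [0, 1]
```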

@KarlH that is a great idea. How do I access the raw logits from the learn.predict()?

@shawn thanks so much for the explanation. I wonder: if the final layer were not softmax but some other function which doesn’t force the outputs to sum to 1, wouldn’t that be a better choice?

Not really. If you think about what the input labels look like in a standard classification task, each label is actually a list of numbers, one for each possible class. In this list of numbers, the true class is exactly 1.0 and all the others are exactly 0. (This is known as “one hot encoding”.) So, softmax is actually a very good function to approximate the labels because it comes close to the same result. You could choose a different function that might give you less confident probabilities, but it would probably suffer a loss of classification accuracy. And either way, it’s not going to provide meaningful results if you try to predict something very different from the training data.
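A tiny illustration of that point, with made-up numbers: the training target is effectively a one-hot vector, and cross-entropy loss rewards softmax outputs that look just like it, so confident near-0/near-1 predictions are exactly what training encourages.

```python
import torch
import torch.nn.functional as F

# One-hot target for class 2 out of 5 (what the label effectively looks like).
target_one_hot = torch.tensor([0., 0., 1., 0., 0.])
target_index = target_one_hot.argmax()            # cross_entropy wants the class index

confident = torch.tensor([[-3.0, -3.0, 6.0, -3.0, -3.0]])  # softmax close to the one-hot target
hesitant  = torch.tensor([[ 0.2,  0.1, 0.9,  0.0,  0.3]])  # softmax spread across classes

print(F.cross_entropy(confident, target_index.unsqueeze(0)))  # small loss
print(F.cross_entropy(hesitant,  target_index.unsqueeze(0)))  # larger loss
```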

Did you get an answer for how to get raw logits from learn.predict()?

Hello everyone, I hope you’re all having more fun than you can handle today!

I have been following the A walk with fastai2 - Study Group and Online Lectures Megathread being run by muellerzr, and we covered a little bit about dealing with classes that your classifier hasn’t been trained on.

If you look at this notebook you will see a way of using multilabel classification to recognize images that your classifier has not been trained on.

It looked good to me.
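The core idea, as I understand it (this is just a sketch, not the notebook’s actual code, and the class names are placeholders), is to score each class independently with a sigmoid and apply a threshold, so an image matching none of the trained classes can come back with no label at all:

```python
import torch

THRESH = 0.5
classes = ["class_a", "class_b", "class_c"]   # placeholder class names

def predict_with_unknown(logits, thresh=THRESH):
    """Return the classes whose sigmoid score clears the threshold,
    or 'unknown' if none do."""
    probs = torch.sigmoid(logits)
    picked = [c for c, p in zip(classes, probs) if p > thresh]
    return picked if picked else ["unknown"]

# Made-up logits for an image that resembles none of the trained classes.
print(predict_with_unknown(torch.tensor([-2.3, -1.1, -0.7])))   # -> ['unknown']
# Made-up logits for an image the model recognises as class_b.
print(predict_with_unknown(torch.tensor([-2.3,  3.1, -0.7])))   # -> ['class_b']
```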

Cheers mrfabulous1 :smiley: :smiley: