Low real-world accuracy

Hi everybody,

I’ve been experimenting with a project to help a friend. Right now I’m only doing multi-label classification (as a kind of warmup), but I want to move on to instance segmentation.

I’ve run into a problem: the model’s real-world predictions are inaccurate, despite a high accuracy_multi score.

The real-world images are in almost exactly the same format as the training images (taken with a special rig/setup), so I don’t think the problem is that the test images differ from the training images.

I have a hunch that the problem might be that the labels are heavily imbalanced: some labels have far more instances than others.

Do you think that’s the problem too?
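A quick way to check how unbalanced the labels actually are is to count how many images carry each label. The sketch below assumes the labels are stored as one space-separated string per image (a common layout for a fastai multi-label CSV); the rows are made-up examples using label names from this thread. With pandas, something like `df["labels"].str.split().explode().value_counts()` should give similar counts, assuming a hypothetical column named `labels`.

```python
from collections import Counter

# Hypothetical rows: one space-separated label string per image, the way a
# multi-label CSV is often laid out for fastai (label_delim=' ').
rows = [
    "biofouling lice-males",
    "biofouling",
    "biofouling lice-females lice-males",
    "lice-males",
    "biofouling lice-preadult",
]

counts = Counter()
for s in rows:
    # A set, so a label repeated within one row is only counted once.
    counts.update({lbl for lbl in s.split(" ") if lbl})

for label, n in counts.most_common():
    print(f"{label:15s} {n}")

# Ratio between the most and least common label: a rough imbalance measure.
ratio = max(counts.values()) / min(counts.values())
print(f"imbalance ratio: {ratio:.1f}")
```

A ratio in the tens or hundreds on the real data would support the imbalance hunch; a ratio near 1 would point elsewhere.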

I’ve been blogging about this work, and here’s the post I wrote, which includes my code and images:

I wanted to ask the community’s opinion before balancing the data, because doing so is quite a big task (unless there’s a fast way to do it that I’m unaware of), and if the real issue is unrelated, it’s not worth doing.

My second question: is there a quick way to balance the dataset, in either Google Sheets or Python? I’d love to know if so.
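For a quick Python option, one naive approach is oversampling: repeating training rows that contain rare labels. The sketch below is a simple heuristic, not a fastai feature; `count_labels` and `oversample_indices` are hypothetical helpers and the rows are made up. It should only be applied to the training rows, leaving the validation set untouched so the metric still reflects the real label distribution.

```python
from collections import Counter

def count_labels(label_strings, delim=" "):
    """Count how many rows carry each label (each label counted once per row)."""
    counts = Counter()
    for s in label_strings:
        counts.update({lbl for lbl in s.split(delim) if lbl})
    return counts

def oversample_indices(label_strings, delim=" ", max_factor=10):
    """Naive multi-label oversampling heuristic (hypothetical helper).

    Returns row indices in which each row is repeated roughly
    (count of the most common label) / (count of the row's rarest label)
    times, capped at max_factor. Labels co-occur, so this shrinks the
    imbalance but cannot remove it exactly.
    """
    counts = count_labels(label_strings, delim)
    top = max(counts.values(), default=1)
    indices = []
    for i, s in enumerate(label_strings):
        labels = {lbl for lbl in s.split(delim) if lbl}
        if not labels:
            # Rows with no labels are kept once, unchanged.
            indices.append(i)
            continue
        rarest = min(counts[lbl] for lbl in labels)
        factor = min(max_factor, max(1, round(top / rarest)))
        indices.extend([i] * factor)
    return indices

# Hypothetical training rows (space-separated labels per image).
train_labels = [
    "biofouling lice-males",
    "biofouling",
    "biofouling lice-females lice-males",
    "lice-males",
    "biofouling lice-preadult",
]

idx = oversample_indices(train_labels)
balanced = [train_labels[i] for i in idx]
print("before:", count_labels(train_labels))
print("after: ", count_labels(balanced))
```

If the labels live in a pandas DataFrame, `train_df.iloc[idx]` returns the repeated rows (duplicates are allowed), which could then be combined with the unchanged validation rows before building the DataLoaders. On this toy data the most/least ratio drops from 4:1 to 10:4, which shows the co-occurrence limitation.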

The first thing I notice is that your accuracy_multi threshold is 0.6 (i.e. anything above 60% probability counts as a positive). Try playing with this, perhaps raising the threshold to 80% or higher, so that a label is only predicted when the model is really confident about it.
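To compare thresholds, one option is to compute the metric on the validation predictions for a range of cut-offs. In fastai that would be `preds, targs = learn.get_preds()` followed by `accuracy_multi(preds, targs, thresh=t, sigmoid=False)` for each `t` (get_preds should already apply the sigmoid when the loss is BCEWithLogitsLossFlat). The sketch below reimplements the same element-wise comparison in plain Python so it runs on its own; `accuracy_multi_py` is a hypothetical name and the numbers are toy values. Worth keeping in mind: this metric averages over every (image, label) pair, so correctly predicted absent labels count too, and with several mostly-absent labels it can look high even when some present labels are missed.

```python
def accuracy_multi_py(probs, targets, thresh=0.5):
    """Element-wise multi-label accuracy: the fraction of (sample, label)
    pairs where (prob > thresh) matches the 0/1 target. Mirrors the logic
    of fastai's accuracy_multi with sigmoid=False, in plain Python."""
    hits = [(p > thresh) == bool(t)
            for p_row, t_row in zip(probs, targets)
            for p, t in zip(p_row, t_row)]
    return sum(hits) / len(hits)

# Toy validation set: 3 images x 4 labels (made-up numbers).
probs = [
    [0.95, 0.70, 0.10, 0.05],
    [0.85, 0.20, 0.65, 0.02],
    [0.40, 0.92, 0.08, 0.55],
]
targets = [
    [1, 0, 0, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 0],
]

for thresh in (0.5, 0.6, 0.7, 0.8, 0.9):
    acc = accuracy_multi_py(probs, targets, thresh)
    print(f"thresh={thresh:.1f}  accuracy={acc:.3f}")
```

On these toy numbers, 0.6 to 0.8 tie for best, and 0.9 starts dropping a label that is actually present (0.85 on the second image), which is the trade-off to watch for on real validation data.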

One other thing: can you post the raw output from learn.predict(), not just the labels it generates? We can use that to work out a threshold to try 🙂

To add on top of @muellerzr’s suggestion: I also don’t see presizing in your code, which Jeremy discusses along with how it affects image quality. I see that your batch_tfms use a size of 512. From what Jeremy says, you should start with a larger size in item_tfms, so add item_tfms=Resize(something larger than 512) and see if this helps.


And one last suggestion: since you mention imbalanced classes, look into oversampling, and read this post by Rachel on why it’s so important to think about what you want your validation set to be:

Thanks, everybody, for responding. I’ve been a little under the weather for the past week, especially the past 12 hours, so sorry for the slow response.

Excellent idea – I’ll do this!

Sure thing! This is the output for the same image I wrote about at the end of the blog post:

((#5) ['biofouling','biofouling-plus-pink','lice-females','lice-males','lice-preadult'],
 tensor([ True,  True, False,  True,  True,  True, False, False, False]),
 tensor([9.4052e-01, 8.6763e-01, 2.2237e-01, 8.7788e-01, 9.8197e-01, 7.2497e-01,
         2.5964e-04, 3.3470e-03, 2.1405e-02]))

Interesting! Thanks for this tip. I’m a little confused, though: are you saying I should use one size in item_tfms (say Resize(1024)) and keep batch_tfms at 512?

I thought this was only needed when dealing with source images of varying sizes? In my case, all of the source images are the same size (4000px by 2672px).

Perhaps there’s an issue with the cropping, though? Here is a raw source image, which isn’t square; presumably fastai crops it to a square?
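To put a number on the cropping question: as I understand fastai, Resize defaults to ResizeMethod.Crop (a random crop on the training set, a center crop on validation), while ResizeMethod.Squish and ResizeMethod.Pad keep the whole frame. The plain-Python sketch below is not fastai’s implementation; `center_crop_box` is a hypothetical helper that computes the largest centered box with the target aspect ratio, applied to the 4000px by 2672px images described above.

```python
def center_crop_box(width, height, target_w, target_h):
    """Largest centered box with the target aspect ratio, as
    (left, top, right, bottom). A crop-style resize crops to this
    aspect ratio first and then scales down to the target size."""
    target_ratio = target_w / target_h
    if width / height > target_ratio:
        # Image is too wide: keep the full height, trim the sides.
        new_w = round(height * target_ratio)
        left = (width - new_w) // 2
        return (left, 0, left + new_w, height)
    # Image is too tall (or already matches): keep the full width, trim top and bottom.
    new_h = round(width / target_ratio)
    top = (height - new_h) // 2
    return (0, top, width, top + new_h)

box = center_crop_box(4000, 2672, 1024, 1024)
kept = (box[2] - box[0]) * (box[3] - box[1]) / (4000 * 2672)
print(box, f"{kept:.1%} of the pixels kept")
```

For these images a square crop keeps the full 2672px height but trims 664px from each side, so roughly a third of the frame is discarded. If lice or biofouling can appear near the left or right edges, comparing against a squish or pad resize might be worthwhile.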

Will do, thank you!

So if we read this output, biofouling has a prediction ‘strength’ of 94%, and lice-males one of 98%. Since those are the only two above 90%, you could probably just use that .1 as your threshold and be okay. You should then use that same threshold in your accuracy metric too, so that it reflects this while training.
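Applying a few cut-offs to the probabilities from the learn.predict() output above shows what each threshold would keep. One assumption: the five named labels are taken to sit at the True positions of the boolean tensor, in vocab order; the other four names aren’t shown, so they appear as hypothetical placeholders (`label_2` and so on). The boolean tensor in the output matches a 0.5 cut-off (0.6 gives the same result here); as far as I know, learn.predict decodes with the loss function’s threshold (BCEWithLogitsLossFlat’s `thresh`, which defaults to 0.5) rather than the `thresh` passed to accuracy_multi.

```python
# Probabilities copied from the learn.predict() output above.
probs = [9.4052e-01, 8.6763e-01, 2.2237e-01, 8.7788e-01, 9.8197e-01,
         7.2497e-01, 2.5964e-04, 3.3470e-03, 2.1405e-02]

# Assumed vocab order: the five named labels sit at the True positions of the
# boolean tensor; "label_2", "label_6", ... are hypothetical placeholders for
# the four labels the output doesn't name.
vocab = ["biofouling", "biofouling-plus-pink", "label_2", "lice-females",
         "lice-males", "lice-preadult", "label_6", "label_7", "label_8"]

def decode(probs, vocab, thresh):
    """Return the labels whose probability is strictly above thresh."""
    return [lbl for lbl, p in zip(vocab, probs) if p > thresh]

for thresh in (0.5, 0.6, 0.9):
    print(thresh, decode(probs, vocab, thresh))
```

A cut-off of 0.9 keeps only biofouling and lice-males, the two labels above 90%.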

Also, rest up 🙂

Thanks for this help!

Ah-ha! I thought that 0.9 would be the appropriate threshold here?

I may have misunderstood, but I thought that 0.1 = 10% and 0.9 = 90%?

Would that mean changing this:
learn = cnn_learner(dls, resnet34, pretrained=True, metrics=[partial(accuracy_multi, thresh=0.6)], loss_func = BCEWithLogitsLossFlat())

To this:
learn = cnn_learner(dls, resnet34, pretrained=True, metrics=[partial(accuracy_multi, thresh=0.9)], loss_func = BCEWithLogitsLossFlat())


Or do you also mean changing BCEWithLogitsLossFlat so the threshold is different there too?