How to deal with unbalanced data classes in image classification?

Hi everyone,
I have a dataset of 500,000 product images divided into 5,000 category labels, and I used the fastai library to train a deep learning model to classify these images into categories. After training I got a very high accuracy (99%) and a very low loss (0.03). That seemed odd to me, since I got these results with minimal configuration. After looking at the data I found that some classes (categories) have just 1 image, so I suspect my model isn't actually good and the results are biased by the unbalanced dataset. What should I do in this case, and what techniques can I use to improve my model?

Hi there,

I'm probably late… did you find a solution? I'm struggling with a similar problem. Early in the course (2018 version) someone asked Jeremy about this, and he said that a common approach is to oversample the less frequent classes, i.e., duplicate some training examples (he also said we'd cover this more in future lessons). I'm trying that now but am still having problems: the validation loss goes up while the training loss goes down, i.e., overfitting.

Cheers.

Oversampling under-represented classes and undersampling over-represented classes are both options, the latter only if you have enough data.
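A minimal sketch of both options on a plain list of per-image labels, returning indices into the training set. The helper names (`oversample_indices`, `undersample_indices`) are hypothetical, and random duplication / random dropping are just the simplest versions of each idea:

```python
import random
from collections import defaultdict


def group_by_class(labels):
    """Map each class label to the list of sample indices with that label."""
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    return groups


def oversample_indices(labels, seed=0):
    """Keep every sample and add random duplicates (sampling with replacement)
    until every class has as many samples as the largest class."""
    rng = random.Random(seed)
    groups = group_by_class(labels)
    target = max(len(idxs) for idxs in groups.values())
    out = []
    for idxs in groups.values():
        out.extend(idxs)                                       # originals
        out.extend(rng.choices(idxs, k=target - len(idxs)))    # duplicates
    rng.shuffle(out)
    return out


def undersample_indices(labels, seed=0):
    """Randomly drop samples (without replacement) until every class
    has as many samples as the smallest class."""
    rng = random.Random(seed)
    groups = group_by_class(labels)
    target = min(len(idxs) for idxs in groups.values())
    out = []
    for idxs in groups.values():
        out.extend(rng.sample(idxs, target))
    rng.shuffle(out)
    return out


labels = ["cat"] * 8 + ["dog"] * 2
over = oversample_indices(labels)    # 16 indices: all 10 originals + 6 extra "dog" copies
under = undersample_indices(labels)  # 4 indices: 2 "cat" + 2 "dog"
```

With classes that have a single image, as in the original question, undersampling down to the smallest class would leave one image per class, which is why it only makes sense with enough data. Resampling should also be applied to the training split only, so that duplicated images don't end up in validation.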

You can also try adding class weights to the loss function, so that poor predictions on under-represented classes are penalised more heavily. Something along the lines of
import torch
import torch.nn as nn
weights = [w1, w2, w3, ...]
class_weights = torch.FloatTensor(weights)
learn.crit = nn.CrossEntropyLoss(weight=class_weights)
where w1, w2, w3, … can be chosen however you wish. I've used total sample size / (number of classes * class frequency), but I have to admit I've not had much luck getting class weighting to work well.
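A sketch of the sample size / (num classes * class frequency) weighting in plain PyTorch, assuming integer class labels in a 1-D tensor. The helper `inverse_frequency_weights` and the toy labels are made up for illustration; `criterion` plays the role of the loss assigned to `learn.crit` above:

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(labels, num_classes):
    """w_c = total_samples / (num_classes * count_c).
    Counts are clamped to 1 so a class absent from `labels` doesn't divide by zero."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    return labels.numel() / (num_classes * counts.clamp(min=1))


# Hypothetical toy labels: class 0 is four times as common as class 1 or 2.
train_labels = torch.tensor([0] * 8 + [1] * 2 + [2] * 2)
class_weights = inverse_frequency_weights(train_labels, num_classes=3)
# -> 0.5 for class 0, 2.0 for classes 1 and 2

criterion = nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
logits = torch.randn(4, 3)             # stand-in for model outputs
targets = torch.tensor([0, 1, 2, 0])
loss = criterion(logits, targets)
```

The rarer a class, the larger its weight, so a mistake on a rare-class image contributes more to the loss than a mistake on a common one.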

Another approach is to train first on the unbalanced dataset and then fine-tune on a balanced one, though I doubt that always works.
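One way to build the balanced set for the second stage, sketched under the assumption that "balanced" means capping every class at the same number of images. `balanced_subset` and `per_class` are hypothetical names, and how the two training stages are run is left to your own training loop:

```python
import random
from collections import defaultdict


def balanced_subset(labels, per_class, seed=0):
    """Return at most `per_class` randomly chosen indices from every class.
    Classes with fewer samples than `per_class` contribute all of them."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    subset = []
    for idxs in groups.values():
        subset.extend(rng.sample(idxs, min(per_class, len(idxs))))
    rng.shuffle(subset)
    return subset


# Stage 1: train on all indices (imbalanced).
# Stage 2: fine-tune on balanced_subset(labels, per_class=k) for some chosen k.
labels = ["a"] * 10 + ["b"] * 3 + ["c"] * 1
stage2_indices = balanced_subset(labels, per_class=3)  # 3 "a", 3 "b", 1 "c"
```

Rare classes with fewer than `per_class` images are kept whole, so the subset is only approximately balanced; combining it with oversampling would make the counts exactly equal.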

It all depends on what you deem 'success'. The OP seems to feel 99% accuracy isn't 'success' if the 99% are all in a few classes out of thousands. In such a case I would definitely undersample the very heavily represented classes or make the loss function care less about them. In other use cases, that result may be considered a roaring success.

Personally, I have found that networks are quite resilient to imbalances of 10:1 or even more, and that any fiddling about with weights does more harm than good.

I doubt your overfitting has much to do with class imbalance.

I have an imbalance of about 200:1, and I guess I'm in the same boat as the OP: I care most about pinpointing samples from the rare class (I fix the specificity at, say, 0.9, and then try to maximize the sensitivity).
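Picking that operating point can be sketched as follows, assuming `scores` are predicted probabilities for the rare class on a held-out validation set, label 1 marks the rare class, and there is at least one sample of each class (the function name is hypothetical). scikit-learn's `roc_curve` gives the same trade-off for every threshold at once, with specificity = 1 - FPR:

```python
import numpy as np


def sensitivity_at_specificity(scores, labels, target_spec=0.9):
    """Predict 'rare' when score > threshold. Choose the lowest threshold whose
    specificity (fraction of negatives at or below it) is at least target_spec,
    then report the sensitivity (recall of the rare class) at that threshold."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    neg = np.sort(scores[labels == 0])
    pos = scores[labels == 1]
    # Number of negatives that must fall at or below the threshold
    # (small epsilon guards against float error in target_spec * n).
    k = int(np.ceil(target_spec * len(neg) - 1e-9))
    threshold = neg[k - 1] if k > 0 else -np.inf
    sensitivity = float(np.mean(pos > threshold))
    specificity = float(np.mean(neg <= threshold))
    return float(threshold), sensitivity, specificity


neg_scores = [i / 10 for i in range(10)]   # 0.0, 0.1, ..., 0.9
pos_scores = [0.85, 0.95, 0.5]
t, sens, spec = sensitivity_at_specificity(
    neg_scores + pos_scores, [0] * 10 + [1] * 3, target_spec=0.9
)
# threshold 0.8, specificity 0.9, sensitivity 2/3
```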

I didn't know about this one… interesting (although you don't seem so confident)

I'm not sure about real-world applications, but it has worked in Kaggle competitions where the test set has a much more balanced class distribution than the training set. If I'm reading our very own Radek's account of his winning entry in the Kaggle fashion competition correctly, this is the approach he used. It was also the approach used by the winners of the 2017 iNaturalist competition: they trained on an unbalanced set and fine-tuned on a balanced set.

Over in the furniture competition, the winning entry addressed it by adjusting its predictions with Bayes' theorem to account for the imbalance.
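The exact method of that entry isn't described here. A standard way to use Bayes' theorem for a shift in class frequencies is to multiply each predicted probability by the ratio of the expected test prior to the training prior and renormalise; this assumes the appearance of each class, p(x|y), is the same in training and test data. A minimal sketch (`adjust_for_class_priors` is a hypothetical name):

```python
import numpy as np


def adjust_for_class_priors(probs, train_priors, target_priors):
    """p_new(y|x) is proportional to p_train(y|x) * target_prior(y) / train_prior(y);
    each row is then renormalised to sum to 1."""
    probs = np.asarray(probs, dtype=float)            # shape (n_samples, n_classes)
    ratio = np.asarray(target_priors, dtype=float) / np.asarray(train_priors, dtype=float)
    adjusted = probs * ratio                          # broadcasts over the class axis
    return adjusted / adjusted.sum(axis=1, keepdims=True)


# Trained with a 90/10 split, evaluated on a balanced 50/50 test set.
adjusted = adjust_for_class_priors([[0.6, 0.4]], [0.9, 0.1], [0.5, 0.5])
# -> [[1/7, 6/7]], i.e. roughly [[0.143, 0.857]]
```

A prediction that was slightly in favour of the common class flips to the rare one once the model's learned bias toward the common class is divided out.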

Thank you so much for all the pointers!!

Have a read of this post. I found this to be a good introduction to solving the problem.

https://machinelearningmastery.com/tactics-to-combat-imbalanced-classes-in-your-machine-learning-dataset/

@digitalspecialists do the weights have to add up to 1 when you use this approach?
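For PyTorch's `nn.CrossEntropyLoss` with the default `reduction="mean"`, they don't have to: the weighted per-sample losses are divided by the sum of the weights of the targets in the batch, so only the relative sizes of the weights matter. A quick check with toy values (nothing from a real model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 3)
targets = torch.tensor([0, 1, 2, 0, 0, 1])
w = torch.tensor([0.5, 2.0, 2.0])

# Same relative weights, three different overall scales.
loss_raw = nn.CrossEntropyLoss(weight=w)(logits, targets)
loss_normalised = nn.CrossEntropyLoss(weight=w / w.sum())(logits, targets)  # sums to 1
loss_scaled = nn.CrossEntropyLoss(weight=10 * w)(logits, targets)
# All three losses are equal (up to float rounding).
```

With `reduction="sum"` the scale does matter, since it multiplies the whole loss and therefore the gradients.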

To evaluate the classification performance of your model, I would also recommend using metrics such as Kappa and AUC, which are not directly tied to class prevalence.

Accuracy, precision, recall and F1 are all directly related to class prevalence. Consequently, they are hard to interpret with highly unbalanced data.
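A small illustration of the point about accuracy, using a pure-Python Cohen's kappa (scikit-learn's `cohen_kappa_score` computes the same unweighted statistic). On this toy data, a model that always predicts the majority class scores 95% accuracy but a kappa of 0:

```python
from collections import Counter


def cohen_kappa(y_true, y_pred):
    """Unweighted Cohen's kappa: (p_o - p_e) / (1 - p_e), where p_o is the observed
    agreement (accuracy) and p_e is the agreement expected by chance from the
    class frequencies of the labels and of the predictions."""
    n = len(y_true)
    p_o = sum(t == p for t, p in zip(y_true, y_pred)) / n
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    p_e = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if p_e == 1:
        return float("nan")  # kappa is undefined when chance agreement is perfect
    return (p_o - p_e) / (1 - p_e)


y_true = [0] * 95 + [1] * 5
y_pred = [0] * 100  # always predict the majority class
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)  # 0.95
kappa = cohen_kappa(y_true, y_pred)                                   # 0.0
```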

Good suggestion. Do you know if there is a way to build Kappa and AUC into the loss function rather than just using them as metrics?

Something I tried with unbalanced classes that seemed to work, at least for this use case: https://towardsdatascience.com/deep-learning-unbalanced-training-data-solve-it-like-this-6c528e9efea6?source=linkShare-8db6081dee6c-1540530648

AUC is not differentiable, so there is no direct way to optimize it with backpropagation.
I found this for kappa, which looks interesting, but I don't have access to the complete paper:

Usually, optimizing cross-entropy with class weights or rebalanced data is enough to indirectly optimize kappa or AUC.
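The paper isn't available here, so the following is not its method. As an illustration of how kappa can be made differentiable, one approach is to plug softmax probabilities into the kappa formula instead of hard predictions. A minimal sketch of an unweighted "soft" Cohen's kappa loss (`soft_kappa_loss` is a hypothetical name):

```python
import torch
import torch.nn.functional as F


def soft_kappa_loss(logits, targets, eps=1e-7):
    """1 - kappa, with kappa computed from softmax probabilities rather than
    argmax predictions so that gradients can flow through it."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                    # (N, C)
    onehot = F.one_hot(targets, num_classes).float()    # (N, C)
    p_o = (probs * onehot).sum(dim=1).mean()            # soft observed agreement
    true_marg = onehot.mean(dim=0)                      # label frequencies in the batch
    pred_marg = probs.mean(dim=0)                       # soft prediction frequencies
    p_e = (true_marg * pred_marg).sum()                 # agreement expected by chance
    kappa = (p_o - p_e) / (1 - p_e + eps)
    return 1 - kappa


targets = torch.tensor([0, 1, 2, 0])
confident_logits = (20.0 * F.one_hot(targets, 3).float()).requires_grad_()
loss = soft_kappa_loss(confident_logits, targets)  # close to 0 for near-perfect predictions
loss.backward()                                    # gradients reach the logits
```

Because the class frequencies are estimated per batch, this is likely to be noisy with small batches or thousands of classes, so it would usually sit alongside (class-weighted) cross-entropy rather than replace it.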

This is an implementation of a loss closely related to auc_roc_metrics
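As a hedged illustration of that kind of loss (not necessarily the implementation referred to), a common differentiable surrogate for binary ROC AUC starts from the pairwise definition: AUC is the fraction of (positive, negative) pairs in which the positive scores higher. Replacing that step function with a logistic loss on the score difference gives something backpropagation can handle. `pairwise_auc_loss` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def pairwise_auc_loss(scores, labels):
    """Smooth surrogate for 1 - AUC on a batch (binary labels, 1 = positive).
    Needs at least one positive and one negative in the batch."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]   # score differences for all pos/neg pairs
    return F.softplus(-diff).mean()      # log(1 + exp(-(s_pos - s_neg)))


labels = torch.tensor([1, 1, 0, 0, 0])
good_scores = torch.tensor([3.0, 2.0, -1.0, -2.0, 0.0], requires_grad=True)
loss_good = pairwise_auc_loss(good_scores, labels)          # positives ranked on top: small loss
loss_bad = pairwise_auc_loss(-good_scores.detach(), labels)  # ranking reversed: larger loss
loss_good.backward()
```

The number of pairs grows as positives × negatives per batch, which matters for large batches with a heavy imbalance.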