I have to create an unbalanced dataset (A=100 samples vs B=10000 samples).
I can’t augment or obtain more of A, so I was thinking of 2 possible ways to go:
- Cutting down B to 100
- Copying A 100 times to match B size.
I think I should go with option 2, but I fear that because of the repetition the model will only correctly classify new A samples that are near-copies of a training A sample, i.e. it won't generalize well.
I also think option 1 is actually worse, because throwing away 99% of B discards most of the valuable information about what is not an A.
I don’t care about B though, I only care about classifying correctly A. B is almost infinite and easy to obtain.
What’s the way to go?
One technique I’ve used is to sample each batch 50-50, so if your batch size is 128 you’d pick 64 samples of class A and 64 samples of class B (you can use another ratio here too, of course). This has the same effect as oversampling A but you don’t need to copy any data. You do have to write your own generator (Keras) or batch sampler (PyTorch) to make this work.
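A minimal, framework-agnostic sketch of that 50-50 idea (the function and variable names are mine, not from any library): the minority class A is drawn with replacement, while the majority class B is walked through once per "epoch". You'd plug the yielded index batches into your own Keras generator or PyTorch batch sampler.

```python
import random

def balanced_batches(idx_a, idx_b, batch_size=128, seed=0):
    """Yield index batches that are half class A, half class B.

    Minority indices (idx_a) are drawn with replacement, which has the
    same effect as oversampling A without copying any data.
    """
    rng = random.Random(seed)
    half = batch_size // 2
    # One pass = enough batches to see every B sample once.
    n_batches = len(idx_b) // half
    b_shuffled = idx_b[:]
    rng.shuffle(b_shuffled)
    for i in range(n_batches):
        a_part = rng.choices(idx_a, k=half)           # sample A with replacement
        b_part = b_shuffled[i * half:(i + 1) * half]  # walk through B once
        yield a_part + b_part

# Example: 100 A samples, 10000 B samples, batches of 128
batches = list(balanced_batches(list(range(100)), list(range(100, 10100))))
```

In PyTorch you could get a similar effect more directly with `torch.utils.data.WeightedRandomSampler`, giving A samples 100x the weight of B samples, though that only balances the classes in expectation rather than exactly per batch.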
The paper “A systematic study of the class imbalance problem in convolutional neural networks” by Buda et al. gives a great overview of the different ways to handle unbalanced datasets with deep learning.
Sampling would be a better solution for the problem you mentioned.
But there are other options too, which you can refer to in this blog post.
According to Buda’s paper: “As opposed to some classical machine learning models, oversampling does not necessarily cause overfitting of convolutional neural networks.”
I’m familiar with Keras’s ImageDataGenerator, but that won’t work in my case, and I can’t treat this as an anomaly detection problem either. So I guess I will trust Buda and go with oversampling. If that gives me overfitting, I will try undersampling carefully, like picking the most representative samples of B, and see how it goes.
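If you do go with plain oversampling, a small NumPy sketch of the idea (the function name and signature are my own invention, not an existing API): duplicate minority indices with replacement until both classes have the same count.

```python
import numpy as np

def oversample_minority(X, y, minority_label, seed=0):
    """Naively oversample the minority class to match the majority count.

    Only index arrays are resampled; minority rows get duplicated when
    you finally index into X, so no data is copied up front.
    """
    rng = np.random.default_rng(seed)
    minority_idx = np.flatnonzero(y == minority_label)
    majority_idx = np.flatnonzero(y != minority_label)
    # Draw minority indices with replacement up to the majority count.
    extra = rng.choice(minority_idx, size=len(majority_idx), replace=True)
    idx = np.concatenate([majority_idx, extra])
    rng.shuffle(idx)
    return X[idx], y[idx]
```

With 100 A and 10000 B samples this yields a 20000-sample balanced set; each A sample appears about 100 times on average, which is exactly the repetition Buda et al. found is not necessarily harmful for CNNs.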
It’s actually a fairly common approach.
I’d try it and see whether it improves your model, and then report back here so we know how it goes.
One extra aspect to remember: don’t rely on accuracy when using an unbalanced dataset, since always predicting B would already be ~99% accurate. Look at per-class metrics like precision and recall on A instead.
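To make that concrete, here is the degenerate always-predict-B classifier evaluated on a 100 A / 10000 B test set, with accuracy and recall on A computed by hand:

```python
# Always-predict-B baseline on a 100 A / 10000 B test set:
# accuracy looks great, but recall on A (the class we care about) is zero.
y_true = ["A"] * 100 + ["B"] * 10000
y_pred = ["B"] * len(y_true)  # degenerate classifier

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
tp = sum(1 for t, p in zip(y_true, y_pred) if t == "A" and p == "A")
fn = sum(1 for t, p in zip(y_true, y_pred) if t == "A" and p == "B")
recall_a = tp / (tp + fn)  # fraction of true A samples found

print(f"accuracy:    {accuracy:.3f}")  # prints 0.990
print(f"recall on A: {recall_a:.3f}")  # prints 0.000
```

Recall on A (or the F1 score for class A) is the number to watch here, since the question only cares about catching A correctly.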