Handling class imbalance in Deep Learning models

We are currently working on an image recognition problem in which we try to detect solar panels on roofs. Our dataset consists of aerial images of roofs, each with a label that indicates whether the roof contains a solar panel. This dataset was manually validated, so label noise should be minimal. The major difficulty we are facing, however, is class imbalance: roughly 1 out of 20 roofs contains a solar panel.

We are using the following datasets:

  • (balanced) training set: 3150 positive, 3150 negative examples
  • (balanced) validation set: 784 positive, 784 negative examples
  • (unbalanced) test set: 138 positive, 2862 negative examples

These are our results on the test set:

                  Pred No-Panel   Pred Panel   Recall
Actual No-Panel   2615            247          91.37%
Actual Panel      15              122          89.05%
Precision         99.43%          33.06%       Total acc: 91.26%
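All of the percentages above can be recomputed from the four counts. A minimal sketch in plain Python (the function name `confusion_metrics` is hypothetical) that makes explicit which ratio is which: dividing along a row (actual class) gives recall, dividing along a column (predicted class) gives precision.

```python
def confusion_metrics(tn, fp, fn, tp):
    """Per-class precision/recall for a binary confusion matrix.

    Rows are actual classes, columns are predicted classes:
        actual No-Panel: tn (pred No-Panel), fp (pred Panel)
        actual Panel:    fn (pred No-Panel), tp (pred Panel)
    """
    total = tn + fp + fn + tp
    return {
        # Row-wise: fraction of each actual class that is predicted correctly.
        "recall_no_panel": tn / (tn + fp),
        "recall_panel": tp / (tp + fn),
        # Column-wise: fraction of each predicted class that is correct.
        "precision_no_panel": tn / (tn + fn),
        "precision_panel": tp / (tp + fp),
        "accuracy": (tn + tp) / total,
    }


m = confusion_metrics(tn=2615, fp=247, fn=15, tp=122)
for name, value in m.items():
    print(f"{name:>18}: {value:.2%}")
```

This reproduces 91.37%, 89.05%, 99.43%, 33.06% and 91.26%, so the 33.06% figure is the precision of the Panel class (122 of the 369 roofs predicted as Panel), while its recall is 89.05%.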

Although these results are certainly not bad, the precision for the Panel class drops from 86% on the validation set to 33% on the test set, due to the class imbalance. Is there any way to improve on that?
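Why the imbalance hits precision so hard can be checked with a little arithmetic. If the model's recall r and specificity s (the No-Panel recall) stay the same across datasets, which is an assumption, then the expected Panel precision depends only on the prevalence π of panels: precision = rπ / (rπ + (1 − s)(1 − π)). A small sketch that plugs in r and s from the test-set table (the function name is hypothetical):

```python
def precision_at_prevalence(recall, specificity, prevalence):
    """Expected positive-class precision for a fixed recall and specificity.

    True positives scale with the share of positives (recall * prevalence);
    false positives scale with the share of negatives
    ((1 - specificity) * (1 - prevalence)).
    """
    tp_rate = recall * prevalence
    fp_rate = (1 - specificity) * (1 - prevalence)
    return tp_rate / (tp_rate + fp_rate)


# Recall and specificity (= No-Panel recall) taken from the test-set table.
recall = 122 / 137
specificity = 2615 / 2862

for label, prevalence in [("balanced (validation)", 0.5),
                          ("test set, 138 / 3000", 138 / 3000)]:
    p = precision_at_prevalence(recall, specificity, prevalence)
    print(f"{label:>22}: precision = {p:.1%}")
```

The same classifier gives roughly 91% precision at a 50/50 split and roughly 33% at the test-set prevalence: an 8.6% false-positive rate on 2862 negatives produces more false positives (247) than there are true positives (122).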

We have produced these results using a PyTorch model based on resnet18 with the following specification:

  • Kaiming initialisation of the weights
  • Freezing of all feature layers of resnet18, with the exception of batchnorm layers
  • Custom classifier with the following specification:
    • AdaptiveConcatPool2d(),
    • Flatten(),
    • nn.BatchNorm1d(n_feat),
    • nn.Dropout(dropout/2),
    • nn.Linear(n_feat, n_filter),
    • nn.ReLU(inplace=True),
    • nn.BatchNorm1d(n_filter),
    • nn.Dropout(dropout),
    • nn.Linear(n_filter, n_class)
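For readers who want to reproduce this setup outside fast.ai, here is a minimal plain-PyTorch sketch of the head and the freezing rule listed above. `AdaptiveConcatPool2d` is re-implemented here to mirror the fast.ai layer (max and average pooling concatenated along channels); `make_head`, `freeze_except_batchnorm` and the values `n_filter=512`, `dropout=0.5` are hypothetical illustrations, not the authors' code.

```python
import torch
import torch.nn as nn


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max and average pooling along the channel axis."""

    def __init__(self, output_size=1):
        super().__init__()
        self.mp = nn.AdaptiveMaxPool2d(output_size)
        self.ap = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], dim=1)


def make_head(n_feat, n_filter, n_class, dropout):
    """Classifier head following the layer list above."""
    head = nn.Sequential(
        AdaptiveConcatPool2d(),
        nn.Flatten(),
        nn.BatchNorm1d(n_feat),
        nn.Dropout(dropout / 2),
        nn.Linear(n_feat, n_filter),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(n_filter),
        nn.Dropout(dropout),
        nn.Linear(n_filter, n_class),
    )
    # Kaiming initialisation of the new linear layers (biases set to zero).
    for m in head.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            nn.init.zeros_(m.bias)
    return head


BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_except_batchnorm(body):
    """Freeze every parameter of the feature extractor except BatchNorm ones."""
    for m in body.modules():
        trainable = isinstance(m, BN_TYPES)
        for p in m.parameters(recurse=False):
            p.requires_grad = trainable


# A resnet18 body outputs 512 channels (4x4 for 128x128 inputs), so
# n_feat = 2 * 512 after concat pooling. n_filter and dropout are illustrative.
head = make_head(n_feat=2 * 512, n_filter=512, n_class=2, dropout=0.5)
logits = head(torch.randn(2, 512, 4, 4))
print(logits.shape)  # torch.Size([2, 2])

body = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))  # tiny stand-in body
freeze_except_batchnorm(body)
print([p.requires_grad for p in body.parameters()])  # [False, False, True, True]
```

With torchvision installed, the real body could be taken as `nn.Sequential(*list(resnet18().children())[:-2])`, which drops the original average pool and fully connected layer.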

Initially, we trained this model with AdamW and cross-entropy loss, which produced the results shown above. To improve on this, we tried a number of other strategies, none of which changed the results meaningfully:

  • Learning-rate decay: lowering the learning rate every X epochs
  • Increasing and decreasing the dropout rates
  • Using a proxy for AUC as the loss function, to directly optimize the AUC of the model, which should be more robust to class imbalance (we added a batchnorm and a softmax layer to the classifier because this metric requires probabilities). We used the Wilcoxon-Mann-Whitney U statistic as described here: https://blog.revolutionanalytics.com/2017/03/auc-meets-u-stat.html
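For reference, the AUC equals the Mann-Whitney U statistic divided by n_pos · n_neg, i.e. the fraction of (positive, negative) pairs in which the positive gets the higher score. That count uses a step function, which has zero gradient, so training needs a smooth substitute. The sketch below shows one common differentiable proxy, a polynomial hinge on pairwise margins; it is not necessarily the formulation used above, and `gamma=0.2`, `p=3` are illustrative assumptions. Because it is computed per batch, each batch needs to contain both classes.

```python
import torch


def auc_u_statistic(scores, labels):
    """Exact AUC: fraction of (positive, negative) pairs ranked correctly.

    This is the Mann-Whitney U statistic divided by n_pos * n_neg; ties count 1/2.
    """
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]  # all positive/negative pairs
    return ((diff > 0).float() + 0.5 * (diff == 0).float()).mean()


def soft_auc_loss(scores, labels, gamma=0.2, p=3):
    """Differentiable AUC proxy: polynomial hinge on pairwise score margins.

    Each pair whose margin (pos - neg) is below gamma is penalised by
    (gamma - margin)^p; pairs already separated by gamma cost nothing.
    """
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if pos.numel() == 0 or neg.numel() == 0:
        # No pairs in this batch: return a zero that stays in the graph.
        return scores.sum() * 0.0
    diff = pos[:, None] - neg[None, :]
    return torch.clamp(gamma - diff, min=0).pow(p).mean()


# Scores are P(panel), e.g. torch.softmax(logits, dim=1)[:, 1].
scores = torch.tensor([0.9, 0.4, 0.35, 0.8, 0.1], requires_grad=True)
labels = torch.tensor([1, 1, 0, 0, 0])
print(auc_u_statistic(scores.detach(), labels))  # tensor(0.8333)
loss = soft_auc_loss(scores, labels)
loss.backward()  # gradients flow back to the scores
```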

We could also:

  • Use ensembles
  • Implement test-time augmentation
  • Label more data

But we want to focus on optimizing the performance of a single model first. So the question remains: are there other known methods to deal with class imbalance that could help us achieve better results?

You could try oversampling the underrepresented class, i.e. duplicating its examples.
Combined with stronger data augmentation, this should help avoid overfitting to the duplicated examples.
(I guess the easiest approach to this is with the sampler below.)
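One way to oversample in PyTorch is the built-in `torch.utils.data.WeightedRandomSampler`; whether this is the sampler referred to above is not known. A minimal sketch on a hypothetical toy dataset with about 1 positive per 20 examples, weighting each sample by the inverse frequency of its class:

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical toy dataset: 95 negatives, 5 positives.
labels = torch.tensor([0] * 95 + [1] * 5)
features = torch.randn(len(labels), 3)
dataset = TensorDataset(features, labels)

# Weight every sample by the inverse frequency of its class, so each class
# is drawn equally often in expectation.
class_counts = torch.bincount(labels)  # tensor([95, 5])
sample_weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),  # draws per epoch
    replacement=True,         # required so minority samples can be repeated
)
# The sampler replaces shuffle=True.
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn = Counter()
for _, y in loader:
    drawn.update(y.tolist())
print(drawn)  # roughly 50/50 between class 0 and class 1
```

Since the repeated positives are exact copies, random augmentation in the dataset's transform is what keeps them from being seen identically each time.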

Other things I found in a Google search (which I have not tried myself yet):

If you manage to improve your model with these approaches, I would be happy to hear about your experience! :)

Kind regards
Michael

Jeremy has commented that he doesn’t worry about class imbalance in computer vision, because the necessary features will be learned anyway. Consider a multiclass model that classifies images into one of 30 dog breeds: any single breed is relatively rare, yet the model still learns to recognise it. So you might want to train on real-world conditions, i.e. the actual class distribution.

You might also try splitting the panel class into multiple classes, in case panels fall into several distinct visual types.
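If the panel class is split into subtypes, the final yes/no answer can still be recovered by summing the softmax probabilities of all panel subtypes. A minimal sketch in plain Python; the subtype names in `CLASSES` are purely hypothetical placeholders for whatever visual groups turn up in the data:

```python
import math

# Hypothetical label scheme: the single "panel" class split into visual subtypes.
CLASSES = ["no_panel", "panel_dark_mono", "panel_blue_poly", "panel_flat_roof"]
PANEL_CLASSES = [i for i, name in enumerate(CLASSES) if name.startswith("panel_")]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def panel_probability(logits):
    """Collapse the multi-class output back to P(panel) by summing subtypes."""
    probs = softmax(logits)
    return sum(probs[i] for i in PANEL_CLASSES)


print(round(panel_probability([2.0, 0.5, 0.1, -1.0]), 3))
```

Training then uses ordinary cross-entropy over the four classes, while evaluation thresholds the collapsed P(panel).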

How does it perform with a fast.ai implementation? Reading (between?) the lines, it seems this is raw PyTorch. The #1 thing fast.ai gives me is the power to experiment quickly: swap different models/architectures in and out (why resnet18?), try discriminative learning rate schedules and different image sizes, play with dropout and weight decay, try various loss functions, visualise results rapidly, etc. You’re not keen on TTA yet, but it only takes seconds to see what effect it may have.
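In plain PyTorch, a basic test-time augmentation can be sketched as averaging the softmax output over a few label-preserving transforms. The sketch below assumes that flips and 90-degree rotations preserve the label, which seems reasonable for top-down roof images; the toy model is a hypothetical stand-in, and this is not fast.ai's own TTA implementation.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_tta(model, images):
    """Average softmax probabilities over flips and a 90-degree rotation.

    images: float tensor of shape (batch, channels, height, width).
    """
    model.eval()
    views = [
        images,
        torch.flip(images, dims=[3]),           # horizontal flip
        torch.flip(images, dims=[2]),           # vertical flip
        torch.rot90(images, k=1, dims=[2, 3]),  # 90-degree rotation
    ]
    probs = [torch.softmax(model(v), dim=1) for v in views]
    return torch.stack(probs).mean(dim=0)


# Toy stand-in model for a batch of 128x128 RGB images (hypothetical).
model = nn.Sequential(
    nn.Conv2d(3, 4, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2)
)
images = torch.randn(2, 3, 128, 128)
probs = predict_tta(model, images)
print(probs.shape)  # torch.Size([2, 2])
```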

I wouldn’t jump so readily to the conclusion that this is a class imbalance issue.

How many pixels do the PV arrays span in your images, i.e. what resolution are your aerial images?

Thank you for your answers. I am becoming more and more convinced that the problem itself is not necessarily caused by the class imbalance. However, once the class imbalance is taken into account when analyzing the results (i.e. on the imbalanced test set), it does have a negative impact on the final metrics.

To answer a couple of questions:

  • Resolution: 128x128 pixels; 1 pixel represents 10 cm at ground level.
  • We started out with the fast.ai implementation. It’s a fantastic library for fast prototyping, for sure. However, to gain deeper insight into what exactly happens inside the fast.ai library, we did a deep dive into the code and slimmed it down to what (we think) is relevant to us.
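To relate this resolution to the question of how many pixels a PV array covers: at 10 cm per pixel, a 128x128 tile spans 12.8 m x 12.8 m of ground. The module size used below (about 1.7 m x 1.0 m) is an assumed typical value, not something stated in the thread; real modules vary.

```python
# Ground footprint of one 128x128 tile at 10 cm per pixel.
tile_px = 128
gsd_m = 0.10  # ground sampling distance: metres per pixel
print(f"tile covers {tile_px * gsd_m:.1f} m x {tile_px * gsd_m:.1f} m")

# Assumed size of one typical PV module (~1.7 m x 1.0 m); real modules vary.
module_m = (1.7, 1.0)
module_px = tuple(round(d / gsd_m) for d in module_m)
print(f"one module spans about {module_px[0]} x {module_px[1]} pixels")
```

Under that assumption a single module is roughly 17 x 10 pixels, so a small installation may occupy only a small fraction of the tile.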

Now we are going to take a second look at the problem with the insights we have gained :)