Oversampling Callback

Hello all,

I wrote my first callback, meant for performing oversampling in fastai when training on imbalanced datasets. I started using oversampling on a severely imbalanced medical dataset, and preliminary results seem to show that oversampling significantly helps training.

Here is the Callback:

from torch.utils.data.sampler import WeightedRandomSampler
class OverSamplingCallback(Callback):
    def __init__(self):
        self.labels = learn.data.train_dl.dataset.y.items
        _, counts = np.unique(self.labels,return_counts=True)
        self.weights = torch.DoubleTensor((1/counts)[self.labels])
        
    def on_train_begin(self, **kwargs):
        learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(weights,len(data.train_dl.dataset)), data.train_dl.batch_size,False)        

I have a kernel that highlights the use of this callback for the MNIST dataset.

Please let me know if this is something worth having in the fastai library and if you have any feedback regarding this callback. I will try to create a pull request.

EDIT:

For those visiting this post, the above version is incorrect. Here is the updated version (with an addition from @lewfish):

from torch.utils.data.sampler import BatchSampler, WeightedRandomSampler

class OverSamplingCallback(LearnerCallback):
    def __init__(self,learn:Learner,weights:torch.Tensor=None):
        super().__init__(learn)
        self.labels = self.learn.data.train_dl.dataset.y.items
        _, counts = np.unique(self.labels,return_counts=True)
        # one weight per training sample: 1/count of that sample's class, unless weights are passed in
        self.weights = (weights if weights is not None else
                        torch.DoubleTensor((1/counts)[self.labels]))
        self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])
        # samples drawn per epoch: number of classes * size of the largest class
        self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))

    def on_train_begin(self, **kwargs):
        # swap the training batch sampler for one that draws samples according to self.weights
        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(self.weights,self.total_len_oversample), self.learn.data.train_dl.batch_size,False)

Note that this is already implemented in the fastai library and is part of fastai v1.0.57.
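
To make the updated version above a bit more concrete, here is a minimal NumPy-only sketch of what the weights and total_len_oversample work out to on a toy label array. The labels are made up, and rng.choice is only a stand-in for WeightedRandomSampler (an assumption for illustration, not what the callback itself calls).

```python
import numpy as np

# Hypothetical toy labels: class 0 has 90 items, class 1 has 10.
labels = np.array([0] * 90 + [1] * 10)

# Per-class counts, then one weight per sample equal to 1 / count of its class.
_, counts = np.unique(labels, return_counts=True)
weights = (1 / counts)[labels]

# Same epoch length as in the callback: number of classes * size of the largest class.
n_classes = len(counts)
total_len_oversample = int(n_classes * counts.max())

# Stand-in for WeightedRandomSampler: draw indices with replacement,
# with probability proportional to the per-sample weight.
rng = np.random.default_rng(0)
idx = rng.choice(len(labels), size=total_len_oversample, replace=True,
                 p=weights / weights.sum())

# Class counts among the drawn samples; both classes should be roughly equal.
sampled_counts = np.bincount(labels[idx], minlength=n_classes)
```

Each class ends up with the same total weight, so the two classes are drawn about equally often even though the raw data is 90/10.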

A pull request (my first!) has been made for the callback over here:

Thoughts? @sgugger

I replied on GitHub, thanks a lot!

@ilovescience Thanks for this callback. Could it be extended for the multi-label case?

Hi @ilovescience, thanks for the code. Of all the approaches I’ve tested, it has done the most to improve balanced accuracy on highly imbalanced datasets.

Now the bad news: at inference time, your callback needs the training data to be present in the learner's DataBunch, because the weights are calculated in the constructor.

Here is the gist with the error shown:

I guess that removing the callback at inference time might solve the issue,
but may I suggest moving the weights calculation to on_train_begin, something like this:
(spoiler alert: I have zero experience hacking Python code :smirk:)

from ..torch_core import *
from ..basic_data import DataBunch
from ..callback import *
from ..basic_train import Learner,LearnerCallback
from torch.utils.data.sampler import WeightedRandomSampler

__all__ = ['OverSamplingCallback']

class OverSamplingCallback(LearnerCallback):
    def __init__(self,learn:Learner,weights:torch.Tensor=None):
        super().__init__(learn)
        self.weights = weights

    def on_train_begin(self, **kwargs):
        self.labels = self.learn.data.train_dl.dataset.y.items
        _, counts = np.unique(self.labels,return_counts=True)
        if self.weights is None: self.weights = torch.DoubleTensor((1/counts)[self.labels])
        self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])
        self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))
        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(self.weights,self.total_len_oversample), self.learn.data.train_dl.batch_size,False)

I have submitted the PR with the fixes and test cases (before & after).

Hoping it will be useful.

@ilovescience Thanks so much for making this.

One problem I’ve found is that in the most extreme case, when there is a class which appears only once in the entire dataset, the single example may end up in the validation set, meaning the length of counts is less than the number of classes, so it cannot be indexed by self.labels.

I think this could be resolved by changing counts to be the same as self.label_counts (np.bincount instead of np.unique), and then setting each item in self.weights to 0 if the corresponding count is 0, otherwise 1/count, like this:

self.weights = np.array([0 if i == 0 else 1 / i for i in counts])
self.weights = torch.DoubleTensor((self.weights)[self.labels])

or alternatively

self.weights = np.divide(1, counts, out=np.zeros_like(counts, dtype=np.float64), where=(counts!=0))
self.weights = torch.DoubleTensor((self.weights)[self.labels])
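
For anyone who wants to see the failure and the fix side by side, here is a small NumPy-only sketch. The labels and n_classes are made up for illustration (n_classes plays the role of self.learn.data.c), and passing minlength to np.bincount is my assumption about how to keep one slot per class.

```python
import numpy as np

n_classes = 3                        # hypothetical: the DataBunch knows 3 classes
labels = np.array([0, 0, 2, 2, 2])   # class 1 ended up entirely in the validation set

# np.unique only reports classes that are present, so counts has length 2
# and indexing it with label 2 goes out of bounds.
_, unique_counts = np.unique(labels, return_counts=True)
try:
    (1 / unique_counts)[labels]
    unique_ok = True
except IndexError:
    unique_ok = False

# np.bincount with minlength keeps a (zero) slot for the missing class.
counts = np.bincount(labels, minlength=n_classes)

# Divide only where the count is non-zero; missing classes keep weight 0.
class_weights = np.divide(1, counts,
                          out=np.zeros(n_classes, dtype=np.float64),
                          where=counts != 0)
weights = class_weights[labels]
```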

Hi, I just wanted to make a suggestion.

This seems to work well if you are dealing with a classification problem. Would it be helpful to add some sort of way to specify if you are working on a regression problem where self.labels are floats rather than ints?

For example:

self.weights = (weights if weights is not None else
                        torch.DoubleTensor((1/counts)[self.labels.astype(int)]))
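
One caveat with astype(int) is that it truncates the targets and only works if the resulting integers line up with the indices of counts. A possible alternative, sketched below purely as an assumption on my part (it is not in the callback), is to bin the float targets first and weight by inverse bin frequency. The targets and n_bins are made up.

```python
import numpy as np

# Hypothetical float regression targets, heavily skewed towards small values.
targets = np.array([0.1, 0.2, 0.3, 0.4, 4.5, 5.0, 9.8])
n_bins = 3  # assumption: number of bins chosen by hand

# Equal-width bin edges; dropping the outer edges makes np.digitize
# return bin ids in the range 0..n_bins-1.
edges = np.linspace(targets.min(), targets.max(), n_bins + 1)[1:-1]
bin_ids = np.digitize(targets, edges)

# Inverse-frequency weight per sample, based on how full its bin is.
# Every bin is populated here; empty bins would need a zero-count guard.
counts = np.bincount(bin_ids, minlength=n_bins)
weights = (1 / counts)[bin_ids]
```

The resulting weights could then be handed to the callback through its weights argument after wrapping them in torch.DoubleTensor.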

I would recommend replacing the line

    _, counts = np.unique(self.labels,return_counts=True)

with

    counts = np.bincount(self.labels, None, self.learn.data.c)

and make sure there are no zeroes in the divisor

    self.weights = torch.DoubleTensor((1/(counts + 1e-8))[self.labels])

because at the moment the OverSamplingCallback fails if your DataBunch has classes that do not appear in the training dataset.

Could you show a beginner like me how to use this callback? I could not find any information about it in the docs.

Hi, usage could look like this:

# define the learner
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

# define the callbacks; every callback listed here is used during training
callbacks = [
    OverSamplingCallback(learn),
    ShowGraph(learn),  # any other callbacks
]
learn.callbacks = callbacks

# run training with all callbacks included
learn.fit_one_cycle(14)

From its description, it creates “a weighted random sampler with weights corresponding to 1/counts of the classes”.
It can lead to overfitting if you upsample too much, since the same minority-class examples get drawn many times per epoch.

Thanks @tomassa for the example! Could you please explain what you mean by upsampling too much? How can we control the amount of upsampling in this callback?
Do you advise using this only when the dataset is not very highly imbalanced?
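
On controlling the amount of upsampling: the callback takes an optional weights argument, so one possible approach (a sketch of my own, not something the callback provides) is to temper the inverse-frequency weights with an exponent. The alpha parameter and both helper functions below are hypothetical names; alpha = 0 means no oversampling and alpha = 1 reproduces the default 1/counts weights.

```python
import numpy as np

labels = np.array([0] * 90 + [1] * 10)   # hypothetical imbalanced labels
counts = np.bincount(labels)

def tempered_weights(labels, counts, alpha):
    """Per-sample weights proportional to (1 / class count) ** alpha."""
    return ((1 / counts) ** alpha)[labels]

def expected_minority_share(weights, labels, cls=1):
    """Expected fraction of draws that land on class `cls`."""
    return weights[labels == cls].sum() / weights.sum()

# Expected share of the minority class for a few values of alpha.
shares = {a: expected_minority_share(tempered_weights(labels, counts, a), labels)
          for a in (0.0, 0.5, 1.0)}
```

With these toy labels the minority share goes from 10% (alpha = 0) to 25% (alpha = 0.5) to 50% (alpha = 1). The chosen weights could then be passed as OverSamplingCallback(learn, weights=torch.DoubleTensor(w)).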

Hey, can you please let me know how to test whether the minority class is actually getting oversampled? I wrote this callback, but I think there is something wrong with it, as it keeps returning the same values and also returns unbalanced data.

class class_samples(Callback):
    def on_epoch_begin(self, **kwargs):
        self.targs, self.preds,self.disribution = LongTensor([]), LongTensor([]),Tensor([])
        
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = torch.argmax(F.softmax(last_output, dim=1),dim=1)
        #print(last_output)
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu()))
    
    def on_epoch_end(self, last_metrics, **kwargs):
        _,self.distribution = torch.unique(self.targs,return_counts=True)
        self.disribution = int(self.distribution[1].item())
        
        return add_metrics(last_metrics,self.distribution)
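
One way to check this outside of the metrics machinery is to count the targets seen over an epoch, as in the NumPy-only sketch below. The class_distribution helper and the batches are hypothetical. Using np.bincount with minlength keeps one slot per class, so a class missing from a batch does not shift the indices the way a unique-based count can.

```python
import numpy as np

def class_distribution(batches, n_classes):
    """Total number of targets seen per class across an iterable of target batches."""
    totals = np.zeros(n_classes, dtype=np.int64)
    for targets in batches:
        # minlength keeps one slot per class even if a batch misses some classes
        totals += np.bincount(np.asarray(targets), minlength=n_classes)
    return totals

# Hypothetical target batches collected during one training epoch
# (class 1 is the minority class in the raw data).
batches = [[0, 1, 0, 1], [1, 0, 0, 1], [0, 1]]
dist = class_distribution(batches, n_classes=2)
share = dist / dist.sum()   # fraction of the epoch taken up by each class
```

In a fastai v1 callback, the batches would be the last_target.cpu() tensors gathered in on_batch_end. I believe the train keyword argument can be used to skip validation batches, which are not oversampled, but treat that as an assumption to verify.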