Hello all,
I wrote my first callback, which performs oversampling in fastai when training on imbalanced datasets. I started using it on a severely imbalanced medical dataset, and preliminary results suggest that oversampling significantly helps training.
Here is the Callback:
from torch.utils.data.sampler import WeightedRandomSampler

class OverSamplingCallback(Callback):
    def __init__(self):
        self.labels = learn.data.train_dl.dataset.y.items
        _, counts = np.unique(self.labels, return_counts=True)
        self.weights = torch.DoubleTensor((1/counts)[self.labels])

    def on_train_begin(self, **kwargs):
        learn.data.train_dl.dl.batch_sampler = BatchSampler(
            WeightedRandomSampler(weights, len(data.train_dl.dataset)),
            data.train_dl.batch_size, False)
I have a kernel that demonstrates the use of this callback on the MNIST dataset.
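For reference, here is a minimal sketch of how the callback would be attached during training. The data bunch, model, and metric are placeholders rather than the exact setup from the kernel:

from fastai.vision import *  # fastai v1 imports (cnn_learner, models, accuracy, ...)

# assumes `data` is an (imbalanced) ImageDataBunch and OverSamplingCallback is defined as above;
# this first version reads the global `learn`, so the learner must exist before the callback is created
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1, callbacks=[OverSamplingCallback()])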
Please let me know if this is something worth having in the fastai library and if you have any feedback regarding this callback. I will try to create a pull request.
EDIT:
For those visiting this post: the version above is incorrect. Here is the updated version (with an addition from @lewfish):
from torch.utils.data.sampler import WeightedRandomSampler, BatchSampler

class OverSamplingCallback(LearnerCallback):
    def __init__(self, learn:Learner, weights:torch.Tensor=None):
        super().__init__(learn)
        self.labels = self.learn.data.train_dl.dataset.y.items
        _, counts = np.unique(self.labels, return_counts=True)
        # default to inverse class-frequency weights for each sample
        self.weights = (weights if weights is not None else
                        torch.DoubleTensor((1/counts)[self.labels]))
        self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data
                                         for i in range(len(self.learn.data.train_dl.dataset))])
        # oversample every class up to the size of the largest class
        self.total_len_oversample = int(self.learn.data.c * np.max(self.label_counts))

    def on_train_begin(self, **kwargs):
        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(
            WeightedRandomSampler(self.weights, self.total_len_oversample),
            self.learn.data.train_dl.batch_size, False)
Note that this is now implemented in the fastai library, as of fastai v1.0.57.