Oversampling Callback

@ilovescience Thanks so much for making this.

One problem I’ve found is in the most extreme case: when a class appears only once in the entire dataset, that single example may end up in the validation set. The training labels then contain no example of that class, so np.unique returns fewer counts than there are classes, and counts can no longer be safely indexed by self.labels.
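A small illustration of the failure, using hypothetical training labels for 3 classes where class 1's only example went to the validation set. np.unique reports counts only for classes that are present, so a label of 2 points past the end of a length-2 array, while np.bincount keeps one slot per class index (minlength=3 is an assumption here, added so every class gets a slot):

```python
import numpy as np

# Hypothetical training labels for classes 0, 1, 2. Class 1 had a single
# example in the whole dataset, and it ended up in the validation set.
train_labels = np.array([0, 0, 0, 2, 2])

# np.unique only counts classes that are present: [3, 2] for classes 0 and 2.
_, counts = np.unique(train_labels, return_counts=True)

try:
    # Label 2 indexes a length-2 array, so this raises IndexError.
    weights = (1 / counts)[train_labels]
except IndexError as err:
    print("IndexError:", err)

# np.bincount keeps one slot per class index, so the empty class gets a 0.
label_counts = np.bincount(train_labels, minlength=3)
print(label_counts)  # [3 0 2]
```

If only the highest-numbered class were missing, the np.unique version would happen to work, which is why this can go unnoticed.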

I think this could be resolved by computing counts the same way as self.label_counts (np.bincount instead of np.unique), and then setting each item in self.weights to 0 if the corresponding count is 0 and to 1/count otherwise, like this:

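```python
# 1/count for each class, 0 for classes with no training examples
self.weights = np.array([0 if i == 0 else 1 / i for i in counts])
self.weights = torch.DoubleTensor(self.weights[self.labels])
```

Or, alternatively:

```python
# Vectorised version: divide only where counts != 0, leave 0.0 elsewhere
self.weights = np.divide(1, counts, out=np.zeros_like(counts, dtype=np.float64), where=(counts != 0))
self.weights = torch.DoubleTensor(self.weights[self.labels])
```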
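Put together outside the callback, the fix might look like the sketch below. inverse_frequency_weights is a hypothetical helper, not part of fastai, and passing n_classes as minlength is an assumption so the class-weight array always has one entry per class:

```python
import numpy as np

def inverse_frequency_weights(labels, n_classes):
    """Hypothetical helper: per-sample weight of 1/count of the sample's class.

    Classes with zero training examples get weight 0.0 instead of
    causing a division by zero or an indexing error.
    """
    labels = np.asarray(labels)
    # One count per class index, including classes absent from `labels`.
    counts = np.bincount(labels, minlength=n_classes)
    # 1/count where count > 0; entries where count == 0 keep the pre-filled 0.0.
    class_weights = np.divide(
        1.0, counts,
        out=np.zeros(len(counts), dtype=np.float64),
        where=counts != 0,
    )
    # Index by label to get one weight per training sample.
    return class_weights[labels]

train_labels = np.array([0, 0, 0, 2, 2])
w = inverse_frequency_weights(train_labels, n_classes=3)
# Each present class contributes a total weight of 1.0.
```

Because every present class sums to the same total weight, a sampler drawing in proportion to these weights picks each class about equally often. In the callback, the array would presumably still be wrapped in torch.DoubleTensor and handed to torch.utils.data.WeightedRandomSampler with replacement=True.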