Adding weighted sampler

Hi
Is it possible to use PyTorch's weighted sampler in the dataloaders…

Hi,
I was trying to do the same. Have you had any success implementing this?

Thanks

Oversampling is implemented in the fastai library as a callback (OverSamplingCallback).
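A common way to build the weights for a weighted sampler is to weight each training sample inversely to the size of its class, so that every class is drawn about equally often. The sketch below shows that idea in plain Python. It is an illustration only, not the callback's actual code, and `oversampling_weights` is a hypothetical helper. A list like this is what you would pass as the `weights` argument of PyTorch's `WeightedRandomSampler`.

```python
from collections import Counter

def oversampling_weights(labels):
    """Hypothetical helper: weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

labels = [1] * 8 + [0] * 2          # imbalanced: 80% class 1, 20% class 0
weights = oversampling_weights(labels)

# Sum the weights per class: every class ends up with the same total (1.0),
# so a sampler drawing indices in proportion to these weights picks each
# class with equal probability.
per_class = Counter()
for y, w in zip(labels, weights):
    per_class[y] += w
print(dict(per_class))
```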

Thanks, that should work. I'm not familiar with callbacks, so some example code showing how to use it would be very helpful.

I figured it out. You can do it as follows:

(Make sure you have the latest fastai from the repo first.)

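```python
from fastai.callbacks import *  # fastai v1; provides OverSamplingCallback

# `learn` is an existing Learner whose training set is imbalanced
cb = OverSamplingCallback(learn)
learn.fit_one_cycle(4, callbacks=[cb])
```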

This seems to work, but I still have to test it on an unbalanced dataset to check accuracy. I also noticed that it takes more epochs to reach the same accuracy as training without it.

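Another way is to pass the callback class to the learner through callback_fns, so it is applied every time the learner is fit:

```python
learn = cnn_learner(data, models.resnet50, metrics=[accuracy], callback_fns=[OverSamplingCallback])
```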

I tested it on MNIST over here and found that oversampling improved results on an imbalanced version of the dataset, although the results were, of course, still worse than training on the original MNIST.

I think this has more to do with the fact that accuracy is a bad metric for unbalanced datasets. For example, if a dataset is 80% class 1 and 20% class 0, a model that always predicts class 1 already reaches 80% accuracy; on the oversampled (balanced) data, the same strategy yields only 50%. It might be necessary to use a different metric, such as the F1 score.
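The arithmetic in the 80/20 example can be checked with a small plain-Python sketch (the helper names are just for illustration, not a fastai API). A model that always predicts class 1 scores 80% accuracy on the imbalanced data and 50% on a balanced version, while its F1 score for the minority class is 0 in both cases, which is why F1 exposes the problem that accuracy hides.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_score(y_true, y_pred, positive):
    """F1 for a single class, treating `positive` as the positive label."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

imbalanced = [1] * 80 + [0] * 20   # 80% class 1, 20% class 0
balanced = [1] * 80 + [0] * 80     # e.g. after oversampling class 0

for name, y in [("imbalanced", imbalanced), ("balanced", balanced)]:
    preds = [1] * len(y)           # a model that always predicts class 1
    print(name, accuracy(y, preds), f1_score(y, preds, positive=0))
```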

Thanks. Oversampling did help in training, but the problem of improving prediction accuracy for the minority classes is not fully solved. I'll keep working on it and maybe try different metrics as you suggest.

Also, just a note: I used your method of adding OverSamplingCallback, but when running learn.lr_find() and learn.recorder.plot() I get the following warning:

> UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
>   self.weights = torch.tensor(weights, dtype=torch.double)

Hi,
Do you know how it selects the values based on the weights? Somehow I'm not able to understand it.
For example:

```python
from torch.utils.data import WeightedRandomSampler

list(WeightedRandomSampler([10, 5, 10, 5, 10], 10, replacement=True))
# [0, 0, 0, 4, 4, 1, 2, 0, 4, 1]
```

Why are there more zeros and fours here?
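WeightedRandomSampler draws each index with probability proportional to its weight (internally it calls torch.multinomial). With weights [10, 5, 10, 5, 10] the total is 40, so indices 0, 2 and 4 each have probability 10/40 = 0.25 and indices 1 and 3 each have 5/40 = 0.125. Over only 10 draws, random variation can easily give more 0s and 4s than 2s. The sketch below mimics the sampler with the standard library's random.choices instead of PyTorch (an assumption made to keep it dependency-free) and shows the frequencies settling near those probabilities once the number of draws is large.

```python
import random
from collections import Counter

weights = [10, 5, 10, 5, 10]

# Each index i is drawn with probability weights[i] / sum(weights)
total = sum(weights)
probs = [w / total for w in weights]
print(probs)  # [0.25, 0.125, 0.25, 0.125, 0.25]

rng = random.Random(0)  # fixed seed so the run is reproducible

# Sampling with replacement, analogous to WeightedRandomSampler(weights, 10, replacement=True)
small = rng.choices(range(len(weights)), weights=weights, k=10)
print(small)  # a short run: counts can be far from the probabilities

# With many draws, the observed frequencies approach the probabilities
big = rng.choices(range(len(weights)), weights=weights, k=100_000)
counts = Counter(big)
freqs = [counts[i] / len(big) for i in range(len(weights))]
print([round(f, 3) for f in freqs])
```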

I agree! Even a model that just guesses the majority class every time scores 80%, so that is effectively the "worst" it can do.