Imbalanced Dataset + OverSamplingCallback

I have a tumor dataset with two folders (one named benign and the other malignant). It is an imbalanced dataset (roughly a 9:1 ratio). I need to do binary classification.
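To confirm the ratio before training, you can count the images in each class folder. This is a minimal stdlib sketch. It assumes one sub-folder per class and the image extensions listed below. Both are assumptions, so adjust them to your data:

```python
from collections import Counter
from pathlib import Path

# Extensions treated as images (assumption: adjust for your dataset).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

def count_per_class(root):
    """Count image files per class, using the parent folder name as the label."""
    counts = Counter()
    for path in Path(root).rglob("*"):
        if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
            counts[path.parent.name] += 1  # same idea as fastai's parent_label
    return counts
```

For example, count_per_class('/Images/Train/') returns a Counter keyed by folder name. Dividing the largest count by the smallest gives the imbalance ratio.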

  1. When I try dataloader.train.show_batch(…), it shows images of only one class (benign, because it is the majority class). Is that because the dataloader picks images at random, so most of them come from the majority class?

  2. For my cnn_learner, I copied the GitHub code for the ‘OverSamplingCallback’ callback into a function called Over_Sampling_Callback and added print statements to it. I used this function in the callback_fns of the cnn_learner, as follows:

learn = cnn_learner(dls, resnet34, metrics=error_rate, callback_fns=[partial(OverSamplingCallback)]). My print statements are never executed. Could it be that this function is not getting called?

  3. In the cnn_learner, if I use a metric such as sklearn’s roc_auc_score, it tells me there is only one kind of label value (benign), so ROC AUC cannot be computed. If point 2 above works, shouldn’t I get a balanced dataset on which my error metric is calculated?
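On point 3: ROC AUC is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one. It therefore needs at least one example of each class among the labels it is evaluated on. The small stdlib sketch below uses the rank-based (Mann-Whitney) formulation to make this visible. It is for illustration only, not a drop-in fastai metric, and the function name is hypothetical:

```python
def roc_auc(y_true, scores):
    """Rank-based (Mann-Whitney) ROC AUC for binary labels 0/1.

    Returns None when only one class is present: there are no
    positive/negative pairs to compare, so AUC is undefined.
    """
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        return None  # sklearn's roc_auc_score raises a ValueError here instead
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # pair ordered correctly
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(roc_auc([0, 0, 0], [0.2, 0.3, 0.9]))           # None: only one class
```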

The dataloader picks up a batch of images and displays them (up to a limit, which you can also set). If you’re seeing only one kind of label, it’s because of the skew in your data. If you run it again and again, you may see the other label as well. But on average you’ll see only about 1 image of the other label for every 9 images of the majority class.
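Here is a rough number for “run it again and again”. Suppose each displayed image is an independent draw and 1 in 10 images belongs to the minority class. Then the chance that a batch of 9 contains no minority image is 0.9^9. Both the batch size of 9 and the independence assumption (sampling with replacement) are simplifications:

```python
def prob_no_minority(minority_frac, batch_size):
    """P(a batch of independent draws contains no minority-class image)."""
    return (1 - minority_frac) ** batch_size

def expected_minority(minority_frac, batch_size):
    """Expected number of minority-class images in such a batch."""
    return minority_frac * batch_size

# 9:1 ratio -> minority fraction 1/10; hypothetical batch of 9 images
print(f"P(no minority image) = {prob_no_minority(0.1, 9):.3f}")        # 0.387
print(f"expected minority images = {expected_minority(0.1, 9):.1f}")   # 0.9
```

So roughly 4 in 10 displayed batches would show only benign images, even though the minority class is in the data.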

Regarding your callback, can you please attach a link to the code? I can’t help you much without knowing what it is in the first place.

Finally, what do you mean by a ‘balanced’ dataset?

Hi Palaash,

Clear on the first point, thanks.

For point 2, I added the code at the link below to my Jupyter notebook, plus a ‘print’ statement in the __init__ function. If this function were called, the print statement should execute (nothing special inside it, just random text):

After adding the above code, I created a learner and fine-tuned it:

learn = cnn_learner(dls, resnet34, metrics=error_rate, callback_fns=[partial(OverSamplingCallback)])

learn.fine_tune runs, but the print statement is never executed. Why is my callback not firing?

By ‘balanced dataset’ I mean: will the ‘OverSamplingCallback’ reduce the imbalance when the cnn_learner object is fine-tuned?

from fastai.vision.all import *  # fastai v2 API: DataBlock, cnn_learner, aug_transforms, ...

DB = DataBlock(blocks=(ImageBlock, CategoryBlock),
               get_items=get_files,
               splitter=RandomSplitter(seed=42, valid_pct=0.2),
               get_y=parent_label,
               batch_tfms=aug_transforms(size=224, min_scale=0.5))

dls = DB.dataloaders('/Images/Train/')
dls.vocab  # prints the two values 'B', 'M'

# note: callback_fns is a fastai v1 Learner argument; the v2 Learner takes cbs
learn = cnn_learner(dls, resnet34, metrics=[error_rate], callback_fns=OverSamplingCallback)
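On whether oversampling reduces the imbalance, the usual approach is to give every training example a weight of 1/(size of its class) and draw indices with replacement. As far as I recall, this is also what the v1 OverSamplingCallback does, via PyTorch's WeightedRandomSampler. The stdlib sketch below shows only that idea. The function name and the choice of n_classes × (largest class count) draws per epoch are assumptions for illustration:

```python
import random
from collections import Counter

def oversample_indices(labels, seed=None):
    """Draw training indices with inverse-class-frequency weights (with replacement).

    Every class gets the same total weight, so each draw is equally likely
    to come from any class; minority examples are simply repeated.
    """
    counts = Counter(labels)
    weights = [1.0 / counts[y] for y in labels]
    num_samples = len(counts) * max(counts.values())  # one "balanced epoch"
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


labels = ["B"] * 900 + ["M"] * 100   # hypothetical 9:1 training set
idx = oversample_indices(labels, seed=0)
print(len(idx), Counter(labels[i] for i in idx))  # 1800 draws, roughly 900 per class
```

The training batches become roughly balanced, but no new data is created. Only the training indices are resampled here, so a validation set used for metrics would keep its original 9:1 ratio.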

Maybe try putting your print statement inside def on_train_begin(self, **kwargs). That way it should be executed at the beginning of training.

You should pass them in as cbs=[OverSamplingCallback], not callback_fns.
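A callback can silently never fire if it is passed under a keyword that the learner accepts but never uses. The toy below is hypothetical and heavily simplified; it is not fastai's Learner. It only shows the hook mechanism, using the v1-style hook name on_train_begin from the suggestion above:

```python
class ToyLearner:
    """Hypothetical, simplified learner; NOT fastai's real Learner."""

    def __init__(self, cbs=None, **kwargs):
        self.cbs = list(cbs or [])
        # Any other keyword (e.g. callback_fns) is accepted but never used,
        # so a callback passed that way can never fire.
        self.unused_kwargs = kwargs

    def fit(self, n_epoch):
        for cb in self.cbs:
            cb.on_train_begin()  # hook run once at the start of training
        # ... the training loop over n_epoch epochs would go here ...


class PrintCallback:
    def __init__(self):
        self.fired = False

    def on_train_begin(self):
        self.fired = True
        print("PrintCallback: on_train_begin")


ignored, used = PrintCallback(), PrintCallback()
ToyLearner(callback_fns=[ignored]).fit(1)  # prints nothing
ToyLearner(cbs=[used]).fit(1)              # prints "PrintCallback: on_train_begin"
print(ignored.fired, used.fired)           # False True
```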

I think @karthikr wrote callback_fns=[partial(OverSamplingCallback)] in his code (please see the top post). That should have worked anyway, right?


Sorry, I thought it was v2 :sweat_smile: Yes that should have worked.

One other thing to try is to pass it during your fit (i.e., callbacks=partial(OverSamplingCallback)).


Yes, I think so too.
cnn_learner returns a Learner object, and Learner does not take callbacks or callback functions as keyword arguments. So passing callbacks during fit might work!


It does. However, IIRC it was mentioned that callbacks that only relate to fitting should be passed in the call to fit; otherwise you’ll run into issues. One example is EarlyStoppingCallback: it shouldn’t be added to the Learner, otherwise it will always be present.
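The difference between learner-level and fit-level callbacks can be shown with another toy. It is again hypothetical and not fastai's code. Callbacks given to the constructor stay attached for every later fit. Callbacks passed to fit are attached for that call only and removed afterwards, which is why something like an early-stopping callback is better passed to fit:

```python
class ToyLearner:
    """Hypothetical, simplified learner used only to illustrate callback scope."""

    def __init__(self, cbs=None):
        self.cbs = list(cbs or [])  # learner-level: present in every fit

    def fit(self, n_epoch, cbs=None):
        extra = list(cbs or [])     # fit-level: only for this call
        self.cbs.extend(extra)
        try:
            for cb in self.cbs:
                cb.on_train_begin()
            # ... the training loop over n_epoch epochs would go here ...
        finally:
            for cb in extra:
                self.cbs.remove(cb)  # detach fit-level callbacks afterwards


class CountingCallback:
    def __init__(self):
        self.n_calls = 0

    def on_train_begin(self):
        self.n_calls += 1


persistent, one_off = CountingCallback(), CountingCallback()
learn = ToyLearner(cbs=[persistent])
learn.fit(1, cbs=[one_off])  # both callbacks fire
learn.fit(1)                 # only the learner-level callback fires
print(persistent.n_calls, one_off.n_calls)  # 2 1
```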

Thanks @muellerzr. You mentioned you thought it was v2; could you share the GitHub repository for v2? I want to explore it.

Here is fastai v2 on GitHub:

and here are the docs:

Hey @muellerzr, is OverSamplingCallback still a thing in fastai v2? I too have an imbalanced dataset and am looking for the “fastai way” of oversampling.

I don’t see OverSamplingCallback anywhere in the v2 code, though maybe I’m looking in the wrong places?


I’m having the same problem. Did you figure out anything?