Handling Imbalanced Targets in Tabular Binary Classification

Hey, I just wanted to see if anyone has found a preferred way to handle imbalanced binary classification tasks for tabular data in fastai v1. After perusing the docs for a while, I don’t see anything built in to handle it. It’d be nice if I could pass SMOTE or something similar as a transform to the TabularDataBunch. Before I burn too much more time on this, I just wanted to check whether anyone else has figured this out yet.

The problem is that SMOTE can’t handle categorical data unless it’s converted to ints, but I don’t want to convert categorical data to ints just to run SMOTE, only to convert it back to strings so I can use fastai’s native categorical handling. So even if there’s no native way to handle this, if someone has a good workflow they’d suggest, that would be great.
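For reference, imbalanced-learn (the package that provides SMOTE) also has SMOTENC, which handles mixed continuous/categorical features: it still needs the categories encoded as ints, but it copies rather than interpolates them when synthesizing rows. A minimal sketch, assuming a hypothetical df with one already-encoded categorical column and no missing values:

    import pandas as pd
    from imblearn.over_sampling import SMOTENC

    # Hypothetical columns: 'cat_col' holds integer category codes,
    # 'cont_col' is continuous, 'target' is the imbalanced 0/1 label.
    X = df[['cat_col', 'cont_col']]
    y = df['target']

    # categorical_features gives the column indices SMOTENC should
    # treat as categorical (copied, not interpolated, in new samples).
    sm = SMOTENC(categorical_features=[0], random_state=42)
    X_res, y_res = sm.fit_resample(X, y)

    df_balanced = pd.DataFrame(X_res, columns=X.columns).assign(target=y_res)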

Thanks


On second thought, since SMOTE can’t handle missing values either, in addition to not being able to handle categoricals, it probably makes sense to run it after the fastai TabularDataBunch call.

So maybe now the question is figuring out a graceful way to access and transform the data within the databunch?


You can customize TabularDataset to create a TfmTabularDataset, a bit like we have in computer vision. That would be the most elegant way to handle this.
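Roughly, something like this (a sketch, not actual fastai API): a plain PyTorch-style dataset that applies a transform on top of an existing one.

    from torch.utils.data import Dataset

    class TfmTabularDataset(Dataset):
        """Hypothetical wrapper applying a transform over an existing
        tabular dataset (e.g. fastai's TabularDataset)."""
        def __init__(self, ds, tfm=None):
            self.ds, self.tfm = ds, tfm

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, idx):
            x, y = self.ds[idx]
            if self.tfm is not None:
                x, y = self.tfm(x, y)
            return x, y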


@sgugger thanks for the reply! (and all your hard work on the library)

So I spent last night trying to figure out how to do this and am a little confused. Do you have any tips on how to create a custom TfmTabularDataset? I’ll check out the computer vision example to see if I can transfer that understanding to tabular.

My confusion is that if I access the df from within the databunch via data.train_ds.df, the data appears unchanged. I’m calling tfms = [FillMissing, Categorify], so I’m assuming those transforms happen when a batch is drawn.
In that case I don’t think applying SMOTE to just a batch is appropriate, but maybe I’m wrong?

Hopefully we can figure out something that works, because I think native handling of imbalanced tabular datasets is important for the library, given how common they are in practice. I did notice imbalanced-learn has a BalancedBatchGenerator for Keras, so maybe balancing per batch is appropriate and we could model an API after that.

No, those transforms are called on the dataframe before creating the internal training set, once and for all, before you start doing anything. Since you want to apply something that goes after this, if I understood correctly, you should create your own dataset on top of TabularDataset.

Oh, I didn’t realize the changes were already applied, because the categoricals were still in string format when I accessed them through data.train_ds.df, and unfortunately SMOTE can’t handle categories stored as strings rather than codes.

I’m not entirely clear on how to create a dataset on top of TabularDataset, but I’ll give it a shot and report back if I figure anything out.

I just filtered out the rare examples and oversampled them into a new dataframe. That new dataframe then becomes the source data for the databunch.

There are prettier ways to do it, but I found this example first 😉

    import pandas as pd

    # Slice out the rare ranges of the target...
    dfa = df[df['TARGET'] < 18]
    dfb = df[(18 < df['TARGET']) & (df['TARGET'] < 48)]
    # ...and oversample them by repeating them in the concat list.
    frames = [df] + [dfa] * 28 + [dfb] * 3
    df = pd.concat(frames)


While using vanilla PyTorch, I was using the following code.

    import numpy as np
    import torch

    # Class frequencies in this batch of 0/1 labels: np.histogram puts
    # the zeros in the first bin and the ones in the last.
    freq = np.histogram(y_batch)[0]
    len_batch = len(y_batch)
    # Weight each class by the other class's frequency, so the rarer
    # class gets the larger weight.
    w = torch.FloatTensor([freq[-1] / len_batch, freq[0] / len_batch])
    # Per-sample weights: look up each sample's class weight.
    w_ = torch.Tensor([[w[int(j[0])]] for j in y_batch])
    if cuda:
        w = w.cuda()
        w_ = w_.cuda()
    loss_fn.weight = w_
    loss_fn.pos_weight = w

@sgugger what do you think about this approach? The code can be further optimized.

To avoid the use of cuda, you should use y_batch.new() (or equivalent) when creating new tensors. It will create them on the same device as y_batch automatically.
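For instance, a minimal sketch of those same tensors built with new_tensor (a close relative of .new(); the dtype argument is needed because y_batch may be an integer tensor):

    import torch

    # new_tensor allocates on the same device as y_batch, so the
    # explicit .cuda() calls are no longer needed.
    w = y_batch.new_tensor([freq[-1] / len_batch, freq[0] / len_batch],
                           dtype=torch.float)
    w_ = y_batch.new_tensor([[w[int(j[0])].item()] for j in y_batch],
                            dtype=torch.float)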

Thanks for pointing it out 🙂 I will take care of that.

I was wondering if this could be part of the fastai library itself, so it could handle imbalanced data automatically.

What do you think?

We could add it as an option. As Jeremy said in the lesson, he isn’t very fond of changing things for unbalanced classes, so this wouldn’t be the default.
Maybe you can write a function that takes a data object and a loss function and returns the loss function with those weights? I feel the weights would have to be computed over the whole training set to be relevant.
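A rough sketch of what I mean (the names and the data.train_ds.y access are assumptions, not actual fastai API; it computes balanced class weights over the full training labels and attaches them to the loss):

    import torch

    def weighted_loss_func(data, loss_func):
        """Hypothetical helper: compute class weights over the whole
        training set and attach them to the given loss function."""
        # Assumes data.train_ds.y exposes the full training labels.
        y = torch.as_tensor([int(y) for y in data.train_ds.y])
        counts = torch.bincount(y).float()
        # Balanced weighting: the rarer the class, the larger its weight.
        weights = counts.sum() / (len(counts) * counts)
        loss_func.weight = weights
        return loss_func

Usage would then be something like learn.loss_func = weighted_loss_func(data, nn.CrossEntropyLoss()).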


Okay.

I feel the weights would have to be computed over the whole training set to be relevant.

In my case, I was calculating weights per batch, since we calculate the loss per batch and average it. Let me check if I can compute weights over the whole training set instead.

Yeah, I recall him saying that in the most recent part 1 v3 lecture, but that hasn’t been my experience with tabular datasets where the binary target is 1% or 5% of the population. In many cases, without some balancing, I’ve found the models fail to learn regardless of learning rate or architecture.


My fraud models were vastly improved by bringing the target up from <1/1000. The question to Jeremy came in the context of vision, so perhaps those models are better able to deal with rare targets.


When it comes to vision, I wouldn’t even dare to disagree with Jeremy on that, because of data augmentation: the ability to flip the photos, zoom, and everything else. Augmenting the rare classes creates more examples of the sparse class in the dataset.

I haven’t found much for tabular datasets beyond things like SMOTE, and SMOTE has issues with extremely sparse classes. Oversampling/upsampling the rare classes tends to help a little, but we run into overfitting the training set because of the copies. If any of you have good articles on data augmentation for tabular datasets, or more on oversampling, and wouldn’t mind sharing them, I would greatly appreciate it.

This is a great article, but it shows how you can fix unbalanced classes for vision.


Hi guys, has anyone found a solution for imbalanced targets in tabular data? There are plenty of examples for vision.

Bonus points for a code sample/link to a notebook where it has been implemented 🙏🏿

@ikey001 see here: Oversampling Callback
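If you just want the gist without following the link, here’s a generic sketch (not the linked callback itself) using plain PyTorch’s WeightedRandomSampler, which draws rare-class rows more often; train_ds and train_labels are assumed to exist:

    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler

    labels = torch.as_tensor(train_labels)  # 0/1 targets
    counts = torch.bincount(labels).float()
    # Each row's draw probability is inversely proportional to its
    # class frequency, so batches come out roughly balanced.
    sample_weights = (1.0 / counts)[labels]
    sampler = WeightedRandomSampler(sample_weights,
                                    num_samples=len(sample_weights),
                                    replacement=True)
    train_dl = DataLoader(train_ds, batch_size=64, sampler=sampler)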


Thanks @muellerzr