Correcting Class imbalance for NLP

(Leon Dummer) #1

Hello everyone,

so I have a dataset with a decent class imbalance across 4 classes (18%, 18%, 14%, 49%), where the first 3 are concrete emotions and the fourth is a catch-all "other" class. Usually I would go ahead and put the appropriate weights in the loss function to fix the imbalance, but I don't see any possibility to do so with the RNN_Learner.
Am I overlooking something or does it even make sense in that case to use weights to fix the imbalance?

(Even Oldridge) #2

You can do this directly with a custom loss function that wraps cross entropy. It wouldn't be a bad feature to add, as it's a common task, but it's a little tricky: you generally want to train on the weighted loss but evaluate your validation/test sets on the unbalanced data, so you'd want to add plain CE as a metric. The other option is over- or undersampling your data, which would mean a class-aware dataloader.
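A minimal sketch of both options in plain PyTorch (not fastai-specific; the label counts and all variable names here are made up for illustration):

```python
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# hypothetical label distribution: class 3 ("other") dominates, like in the thread
labels = torch.tensor([0] * 18 + [1] * 18 + [2] * 14 + [3] * 49)

# option 1: a loss wrapping cross entropy, weighting each class
# inversely to its frequency so rare classes count more
class_counts = torch.bincount(labels).float()
class_weights = len(labels) / class_counts

def weighted_ce(logits, target):
    """Train on this; keep plain F.cross_entropy (or accuracy)
    as the metric on the unbalanced validation set."""
    return F.cross_entropy(logits, target, weight=class_weights)

# option 2: a class-aware dataloader that oversamples minority classes
sample_weights = class_weights[labels]  # one weight per example
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(labels),
                                replacement=True)
features = torch.randn(len(labels), 8)  # dummy features
loader = DataLoader(TensorDataset(features, labels),
                    batch_size=16, sampler=sampler)
```

With option 2 the loss stays plain cross entropy and the balancing happens in the sampling instead, so you should use one or the other, not both at once.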

(Kyle Nesgood) #3

@EinAeffchen - if you do this, would you mind posting your code on how you solve it? I’d be interested - my programming fu isn’t up to snuff yet and I’d love to see an example.

(Leon Dummer) #4

@knesgood if I get it done, I’ll post the code here. But for now I have no idea how to do it either.

(Leon Dummer) #5

So it turns out to be way easier than expected. After a little digging through the code I saw that RNN_Learner overrides its superclass Learner's _get_crit function to return PyTorch's F.cross_entropy. That function already accepts weights, so you can just pass your calculated weights like this:

import torch
import torch.nn.functional as F
from functools import partial

loss_weights = torch.FloatTensor(trn_weights).cuda()
learn.crit = partial(F.cross_entropy, weight=loss_weights)

Note that to correct the imbalance, the weight for each class should be inversely proportional to its frequency, so that the rare classes count more. I calculated my weights with this code:

trn_labelcounts = df_trn.groupby(["labels"]).size()
val_labelcounts = df_val.groupby(["labels"]).size()
trn_label_sum = len(df_trn["labels"])
val_label_sum = len(df_val["labels"])
# inverse class frequency: rare classes get larger weights
trn_weights = [trn_label_sum / count for count in trn_labelcounts]
val_weights = [val_label_sum / count for count in val_labelcounts]
trn_weights, val_weights

To check that your weights were parsed correctly you can simply print the criterion:

learn.crit

which should return something like:

functools.partial(<function cross_entropy at 0x00000282813B3268>, weight=tensor([5.5096, 5.5066, 7.0721, 2.0178], device='cuda:0'))
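If you want to convince yourself that the weights actually reach the loss, here is a standalone sanity check (the weight values are hypothetical and I drop .cuda() so it runs anywhere):

```python
import torch
import torch.nn.functional as F
from functools import partial

# hypothetical per-class weights for 4 classes
loss_weights = torch.FloatTensor([5.5, 5.5, 7.1, 2.0])
crit = partial(F.cross_entropy, weight=loss_weights)

# dummy batch: 3 examples, 4 classes
logits = torch.randn(3, 4)
targets = torch.tensor([0, 2, 3])
loss = crit(logits, targets)  # a single scalar, as the Learner expects

# with reduction='none' you can see each example's loss scaled by its class weight
weighted = F.cross_entropy(logits, targets, weight=loss_weights, reduction="none")
unweighted = F.cross_entropy(logits, targets, reduction="none")
```

One detail to be aware of: with the default reduction='mean', PyTorch divides by the sum of the weights of the targets in the batch, so the overall scale of the loss stays comparable to the unweighted one.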

If you have any trouble, don't hesitate to write me, @knesgood.