 # How to use class weights in loss function for imbalanced dataset

(Aykut Çayır) #1

Hello guys!
I have an imbalanced dataset and I need to use class weights in the loss function. What is the correct way to use class weights in the fastai library?


(urmas pitsi) #2

`learn.crit = your_loss_func()` (where `learn` is your `Learner` instance)

Some loss functions take class weights as input, e.g. PyTorch's `NLLLoss` and `CrossEntropyLoss`, via the `weight` parameter (a tensor of per-class weights):

`learn.crit = CrossEntropyLoss(weight=[…])`

For further details, see the PyTorch source:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
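To see what the `weight` argument actually does, here is a minimal pure-Python sketch of weighted cross-entropy, written to mirror PyTorch's documented behavior for `CrossEntropyLoss(weight=w)` with the default mean reduction: each sample's loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights (not by the batch size). The function name `weighted_cross_entropy` is hypothetical; in practice you would simply pass a weight tensor to `nn.CrossEntropyLoss`.

```python
import math


def weighted_cross_entropy(logits, targets, weight):
    """Mean weighted cross-entropy, mirroring nn.CrossEntropyLoss(weight=weight).

    logits:  list of per-sample score lists (one raw score per class)
    targets: list of true class indices
    weight:  list of per-class weights
    """
    total, norm = 0.0, 0.0
    for scores, y in zip(logits, targets):
        # Negative log-softmax of the true class, computed stably
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_sum_exp - scores[y]
        total += weight[y] * nll
        norm += weight[y]
    # Normalize by the summed weights of the targets, not by the batch size
    return total / norm


logits = [[2.0, 0.5], [0.3, 1.2], [1.5, 0.1]]
targets = [0, 1, 0]
print(weighted_cross_entropy(logits, targets, [1.0, 1.0]))  # plain mean cross-entropy
print(weighted_cross_entropy(logits, targets, [1.0, 9.0]))  # class 1 errors count 9x
```

With all weights equal to 1 this reduces to the ordinary mean cross-entropy; raising the weight of a class makes mistakes on that class dominate the average.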


(Aykut Çayır) #3

It works properly, thanks so much!


(Katharina) #4

Hey,
I am a beginner and I am getting an error when I do this and then call `learn.fit()`.
I tried it with the cats-and-dogs example code. Do I have to set anything else?

The error is the following:

```
in <module>()
----> 1 learn.fit(1e-3, 4, cycle_len=5,cycle_mult=4)

~/fastai/courses/dl1/fastai/learner.py in fit(self, lrs, n_cycle, wds, **kwargs)
225 self.sched = None
226 layer_opt = self.get_layer_opt(lrs, wds)
--> 227 return self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)
228
229 def warm_up(self, lr, wds=None):

~/fastai/courses/dl1/fastai/learner.py in fit_gen(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)
172 n_epoch = sum_geom(cycle_len if cycle_len else 1, cycle_mult, n_cycle)
173 return fit(model, data, n_epoch, layer_opt.opt, self.crit,
--> 174 metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)
175
176 def get_layer_groups(self): return self.models.get_layer_groups()

~/fastai/courses/dl1/fastai/model.py in fit(model, data, epochs, opt, crit, metrics, callbacks, stepper, **kwargs)
94 batch_num += 1
95 for cb in callbacks: cb.on_batch_begin()
---> 96 loss = stepper.step(V(x),V(y), epoch)
97 avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
98 debias_loss = avg_loss / (1 - avg_mom**batch_num)

~/fastai/courses/dl1/fastai/model.py in step(self, xs, y, epoch)
41 if isinstance(output,tuple): output,*xtra = output
42 self.opt.zero_grad()
---> 43 loss = raw_loss = self.crit(output, y)
44 if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
45 loss.backward()

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
323 for hook in self._forward_pre_hooks.values():
324 hook(self, input)
--> 325 result = self.forward(*input, **kwargs)
326 for hook in self._forward_hooks.values():
327 hook_result = hook(self, input, result)

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
599 _assert_no_grad(target)
600 return F.cross_entropy(input, target, self.weight, self.size_average,
--> 601 self.ignore_index, self.reduce)
602
603

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce)
1138 >>> loss.backward()
1139 """
-> 1140 return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
1141
1142

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce)
1047 weight = Variable(weight)
1048 if dim == 2:
-> 1049 return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
1050 elif dim == 4:
1051 return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)

RuntimeError: nll_loss(): argument 'weight' (position 3) must be Variable, not list
```


(Mario) #5

I am not sure whether this belongs in a separate thread, but

has anyone benchmarked balanced batch sampling (drawing the same number of images from each class in every batch) against class weights?

Say you have 100 images of label "0" and 10 images of label "1", and a batch size of 10.

Approach a) would draw 5 images of "0" and 5 images of "1" for each batch.
Approach b) would draw batches at random, giving roughly 9 images of "0" with weight 1 and 1 image of "1" with weight 10.
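Both approaches can be sketched in plain Python for the 100 vs 10 example, assuming a batch size of 10. `balanced_batch` is a hypothetical helper that draws the same number of indices from each class (approach a), sampling with replacement so the small class can fill its half. Approach b keeps ordinary random batches and instead gives each class a loss weight inversely proportional to its size; in PyTorch those weights would go into `nn.CrossEntropyLoss(weight=...)` as a tensor.

```python
import random
from collections import Counter

labels = [0] * 100 + [1] * 10  # 100 images of "0", 10 images of "1"
by_class = {c: [i for i, y in enumerate(labels) if y == c] for c in set(labels)}


def balanced_batch(by_class, per_class, rng=random):
    """Approach a): draw `per_class` indices from every class (with replacement)."""
    batch = []
    for idxs in by_class.values():
        batch += rng.choices(idxs, k=per_class)
    rng.shuffle(batch)
    return batch


# Approach a): 5 + 5 images per batch of 10
batch = balanced_batch(by_class, per_class=5, rng=random.Random(0))
print(Counter(labels[i] for i in batch))  # 5 of each class

# Approach b): keep random batches, weight the loss by inverse class frequency
counts = Counter(labels)
max_count = max(counts.values())
class_weights = {c: max_count / n for c, n in counts.items()}
print(class_weights)  # {0: 1.0, 1: 10.0}
```

Note that approach a) shows each of the 10 minority images many times per epoch, which is why it pairs naturally with data augmentation.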

My guess is that, combined with data augmentation, approach a) is more promising. Does anyone have experience with this?

Edit:
I should add that in my case the class "0" images all look very much alike, whereas class "1" images can be quite diverse.
Therefore my approach is to apply data augmentation only to class "1".
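Augmenting only class "1" can be done with a small dataset wrapper. This is a hypothetical sketch, not a fastai API: `ClassConditionalAugment` wraps any indexable dataset of `(image, label)` pairs and applies a caller-supplied `augment` function only when the label is in `augment_labels`. Because it implements `__len__` and `__getitem__`, a PyTorch `DataLoader` should be able to consume it like any map-style dataset.

```python
class ClassConditionalAugment:
    """Wrap a dataset of (x, y) pairs; augment x only for selected labels."""

    def __init__(self, base, augment, augment_labels=(1,)):
        self.base = base  # any indexable of (x, y) pairs
        self.augment = augment  # callable: x -> augmented x
        self.augment_labels = set(augment_labels)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        if y in self.augment_labels:
            x = self.augment(x)  # e.g. random flips/rotations for real images
        return x, y


# Toy usage: "images" are strings and the augmentation is a stand-in function
base = [("img_a", 0), ("img_b", 1), ("img_c", 0)]
ds = ClassConditionalAugment(base, augment=lambda x: x + "_aug")
print([ds[i] for i in range(len(ds))])
# [('img_a', 0), ('img_b_aug', 1), ('img_c', 0)]
```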


#6

Hi,

I'm trying to do the same thing. Would you mind sharing some code? Thank you so much in advance!

Happy holidays!

(Katharina) #7

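```python
import torch
from torch import nn

# One weight per class, in the same order as the class indices
weights = [0.4, 1]
# The loss expects a tensor (not a plain list) on the same device as the model
class_weights = torch.FloatTensor(weights).cuda()
learn.crit = nn.CrossEntropyLoss(weight=class_weights)
```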


(Bikash Gyawali) #8

How do you decide how much weight to assign to each of your classes? Is `weight_classX = 1/(number of examples in classX)` a good approach?


(urmas pitsi) #9

I don't think there is a perfect answer. Say you have 2 classes split 10% / 90%. Assuming similar per-sample losses, roughly 90% of the loss then comes from class 2. To balance that out, class 1 should get weight 9 versus 1 for class 2, since
10% × 9 = 90% × 1.
That is my understanding of it. You will have to experiment to find what works best in your case.
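The 9 : 1 rule above is the same as the `1/(number of examples)` idea from the previous question, just rescaled: 1/10 : 1/90 is 9 : 1. A minimal sketch (the helper name `inverse_frequency_weights` is hypothetical) that computes such weights and rescales them so the largest class gets weight 1, using 100 and 900 examples to stand for the 10% / 90% split:

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Per-class weights proportional to 1/count, scaled so the largest class gets 1.0."""
    counts = Counter(labels)
    max_count = max(counts.values())
    # A class absent from `labels` has count 0 and raises ZeroDivisionError
    return [max_count / counts[c] for c in range(num_classes)]


labels = [0] * 100 + [1] * 900  # "class 1" (10%) is index 0, "class 2" (90%) is index 1
w = inverse_frequency_weights(labels, num_classes=2)
print(w)  # [9.0, 1.0]

# Each class now carries the same total weight: 10% * 9 == 90% * 1
fractions = [100 / 1000, 900 / 1000]
print(f"{fractions[0] * w[0]:.2f} vs {fractions[1] * w[1]:.2f}")  # 0.90 vs 0.90
```

Only the ratio between the weights matters for the mean-reduced loss, so plain `1/count` and this rescaled version give the same result.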

In my experience, and from what I've read, a better way to handle class imbalance is "oversampling".
Instead of applying class weights to the loss, you adjust the data loader so that it samples according to class weights. In this case I believe you would want the classes drawn 50% / 50%, i.e. sampled with equal probability.
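A minimal sketch of that oversampling idea: give every sample a weight of 1/(size of its class), so each class has the same total probability mass, then draw indices with replacement. Here `random.choices` stands in for the data loader; in PyTorch, the same per-sample weights are what you would pass to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)` and hand to a `DataLoader` through its `sampler` argument. How to plug a custom sampler into fastai's data objects depends on the fastai version and is not shown here.

```python
import random
from collections import Counter

labels = [0] * 100 + [1] * 900  # 10% / 90% split
counts = Counter(labels)

# Per-sample weight = 1 / size of its class, so both classes get equal total mass
sample_weights = [1.0 / counts[y] for y in labels]

rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=sample_weights, k=10_000)
drawn_counts = Counter(labels[i] for i in drawn)
print(drawn_counts)  # roughly 5000 of each class
```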

I believe it is a superior way to tackle the class imbalance problem. However, you always have to experiment, see what works best for you, and try to understand why.

Check out these links from the PyTorch docs:

I found a nice implementation and examples here:


#10

Hi,

I am also trying to train on imbalanced classes by implementing this class, but I am having problems integrating it into the current version of fastai. You can check out the code and discussion here: Paper on Imbalanced classes
