Hello guys!
I have an imbalanced dataset and I need to use class weights in the loss function. What is the correct way to use class weights in fastai library?
Learner.crit = your loss_func()
Some loss functions take class weights as input, e.g. torch's NLLLoss and CrossEntropyLoss via the parameter weight=tensor of weights.
Learner.crit = CrossEntropyLoss(weight=[…])
For further details see the PyTorch source:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
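For example, a minimal sketch of that (the weight values below are placeholders; the weights should be a tensor on the same device as the model):

import torch
from torch import nn

# one weight per class, passed as a tensor (placeholder values)
class_weights = torch.FloatTensor([1.0, 2.0]).cuda()
learn.crit = nn.CrossEntropyLoss(weight=class_weights)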
It works properly, thanks so much!
Hey,
I am a beginner and I am getting an error doing this during learn.fit().
I did this using the cat and dog example code. Do I have to set anything else?
The error is the following:
in <module>()
----> 1 learn.fit(1e-3, 4, cycle_len=5, cycle_mult=4)
~/fastai/courses/dl1/fastai/learner.py in fit(self, lrs, n_cycle, wds, **kwargs)
225 self.sched = None
226 layer_opt = self.get_layer_opt(lrs, wds)
--> 227 return self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)
228
229 def warm_up(self, lr, wds=None):
~/fastai/courses/dl1/fastai/learner.py in fit_gen(self, model, data, layer_opt, n_cycle, cycle_len, cycle_mult, cycle_save_name, best_save_name, use_clr, metrics, callbacks, use_wd_sched, norm_wds, wds_sched_mult, **kwargs)
172 n_epoch = sum_geom(cycle_len if cycle_len else 1, cycle_mult, n_cycle)
173 return fit(model, data, n_epoch, layer_opt.opt, self.crit,
--> 174 metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)
175
176 def get_layer_groups(self): return self.models.get_layer_groups()
~/fastai/courses/dl1/fastai/model.py in fit(model, data, epochs, opt, crit, metrics, callbacks, stepper, **kwargs)
94 batch_num += 1
95 for cb in callbacks: cb.on_batch_begin()
---> 96 loss = stepper.step(V(x), V(y), epoch)
97 avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
98 debias_loss = avg_loss / (1 - avg_mom**batch_num)
~/fastai/courses/dl1/fastai/model.py in step(self, xs, y, epoch)
41 if isinstance(output,tuple): output,*xtra = output
42 self.opt.zero_grad()
---> 43 loss = raw_loss = self.crit(output, y)
44 if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
45 loss.backward()
~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
323 for hook in self._forward_pre_hooks.values():
324 hook(self, input)
--> 325 result = self.forward(*input, **kwargs)
326 for hook in self._forward_hooks.values():
327 hook_result = hook(self, input, result)
~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
599 _assert_no_grad(target)
600 return F.cross_entropy(input, target, self.weight, self.size_average,
--> 601 self.ignore_index, self.reduce)
602
603
~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce)
1138 >>> loss.backward()
1139 """
-> 1140 return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
1141
1142
~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce)
1047 weight = Variable(weight)
1048 if dim == 2:
-> 1049 return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
1050 elif dim == 4:
1051 return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
RuntimeError: nll_loss(): argument 'weight' (position 3) must be Variable, not list
I am not sure if this should be posted in a different thread, but
has anyone benchmarked batch sampling (drawing the same number of images from each class in each batch) vs class weights?
Say you have 100 images of label “0” and 10 images of label “1”.
Approach a) would be using 5 images of “0” and 5 images of “1” in each batch.
Approach b) would be using approx 9 images of “0” with weight 1 and 1 image of “1” with weight 10.
I guess with data augmentation, approach a) seems more promising… does anyone have experience here?
Edit:
I should maybe add that in my case images of class “0” all look very much alike, whereas class “1” can be quite diverse.
Therefore my approach is to use data augmentation only for class “1”.
hi,
I’m trying to do the same thing. Would you mind sharing some code? Thank you so much in advance!
Happy holidays!
import torch
from torch import nn

# the weights must be a tensor (not a plain Python list) on the same device as the model
weights = [0.4, 1.0]
class_weights = torch.FloatTensor(weights).cuda()
learn.crit = nn.CrossEntropyLoss(weight=class_weights)
How do you figure out how much weight should be assigned to each of your classes? Is weight_classX = 1/(number of examples in classX) a good approach?
I guess there is no perfect answer for that. Let's say you have 2 classes with a 10% / 90% split. I would assume that 90% of the loss comes from class 2, so to balance that out class 1 should have weight 9 vs 1 for class 2:
10% * 9 = 90% * 1.
That would be my understanding of it. You just have to experiment to see what works best in your case.
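As a rough sketch of that heuristic, using the 10% / 90% split from the example above (whether this exact scaling is best for your data is something to experiment with):

import torch

counts = torch.tensor([10., 90.])   # images per class for the 10% / 90% split
weights = counts.max() / counts     # -> tensor([9., 1.]): the rare class gets weight 9
class_weights = weights.cuda()      # move to the GPU if that's where the model lives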
What I've experienced and read is that a better alternative for dealing with class imbalance is "oversampling".
You don't apply class weights to the loss, but adjust the dataloader to sample with class weights. In this case I believe you would want class weights of 50% and 50%, so the classes are sampled with equal probability.
I believe it is a superior method for tackling the class imbalance problem. However, you always have to experiment yourself, see what works best for you, and also try to understand why.
Check out these links from pytorch docs:
https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler
https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html#WeightedRandomSampler
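To make that concrete, here is a rough sketch with WeightedRandomSampler in plain PyTorch (the toy dataset and class counts are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# toy dataset: 90 samples of class 0, 10 samples of class 1
targets = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
ds = TensorDataset(torch.randn(100, 3), targets)

# per-sample weight = inverse of its class count, so both classes are drawn
# with roughly equal probability
class_counts = torch.bincount(targets).float()
sample_weights = 1.0 / class_counts[targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(ds), replacement=True)

dl = DataLoader(ds, batch_size=16, sampler=sampler)

With replacement=True the minority class gets drawn multiple times per epoch, which is exactly the oversampling effect described above.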
I found a nice implementation and examples here:
Hi,
I am also trying to train on imbalanced classes by implementing this class, but I'm having problems integrating it into the current version of fastai. You can check out the code and discussion here; Paper on Imbalanced Classes
Hey,
do you think working with a weighted loss function is the right approach if I want to deliberately imbalance the classes?
Example:
I have a two-class image classification problem where I cannot afford to miss an image of class 1 (anomaly), while having images of class 2 wrongly classified as class 1 is not that big of a problem.
Both classes have little data (~100 images each).
Also, any tips on how to adjust the weights for this kind of problem?
Or do you have any other ideas?
Hey,
I had the same problem. For me it actually helped to change the weights of the loss function. I also added a lot of wrongly classified data (much more than of the correctly labeled class). I did that because the wrongly classified images always looked very different from one another, so you could really split them into a lot of different classes, and to learn such a diverse group of classes you need a lot of samples. In my case, increasing the number of wrongly classified images was the key to success.
Hey Kathi,
thanks for your comment. By “adding a lot of wrongly classified data” did you get “more” data by oversampling or was there more data available in your case?
There was more data available.
hey @Kathi
After using the weights, the train loss seems to increase instead of decrease. Can you help?
Hi @Kathi, if you had 2 target classes, say [cat, dog], how do you know which of the weights is assigned to which target class? For example, if you wanted to assign the weight 0.4 to dog, how do you make sure that it is not being assigned to cat, and vice versa?
Generally you'd look at which index corresponds to which class. In fastai you can do this by looking at dls.vocab, and that order corresponds to indices 0-n.
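For example (a fastai v2 style sketch assuming an existing dls and learn; the class names and weight values are just placeholders):

import torch
from fastai.vision.all import CrossEntropyLossFlat

print(dls.vocab)                    # e.g. ['cat', 'dog'] -> index 0 is cat, index 1 is dog
weights = torch.tensor([1.0, 0.4])  # position 0 -> cat, position 1 -> dog (placeholder values)
learn.loss_func = CrossEntropyLossFlat(weight=weights.to(dls.device))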
Hi Mario!
So, you get two batches from 10 (5 + 5) images of label “1”?
What do you do with the remaining 90 images labeled “0”?
I've recently seen a paper using this approach but haven't tried it myself.
Hi Kathi! Why did you add “wrongly classified data”?
Hi Simon!
Given your example, I suppose you're trying to get the most out of a small dataset of ~200 images. If that's the case, a weighted loss function alone may not help much, or may not work at all. First of all, consider transfer learning, and also look at oversampling and data augmentation techniques.