Handling imbalanced classes (`crit` doesn't exist anymore?)

Hi,
I’m working with an imbalanced dataset of images for a binary classification task. I have already read a few questions and posts, and I want to use PyTorch's cross entropy with the weight parameter, as suggested in this question from 2018 on the forum:

weights = [w1, w2, w3, ...]
class_weights = torch.FloatTensor(weights)
learn.crit = nn.CrossEntropyLoss(weight=class_weights)

But checking the Learner class, I can see that the crit attribute no longer exists. I also tried reading the commit history of basic_train.py, with no luck.

My question is: is the following snippet a correct way to use this weighted cross entropy?

class_weights=torch.FloatTensor([4, 1]).cuda()
loss = nn.CrossEntropyLoss(weight=class_weights)
model = create_cnn(data_bunch, models.resnet34, metrics=error_rate, loss_func=loss)

Thanks in advance!

I have not tested this with create_cnn, but with other learners you create the learner and then set learn.loss_func.func equal to whatever you would like. For example, if you want to weight class 1 at 16x class 0:

learn.loss_func.func = nn.CrossEntropyLoss(weight=tensor([1.,16.]))

I have tested this, and it works in fastai v1:

learn.loss_func.func = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.2,0.8]).cuda())


Thanks @imanol and @bfarzin for the suggestion. I tried it and it did work. However, ClassificationInterpretation.from_learner(learner_obj) doesn’t seem to work after applying this change to learner_obj. It throws an error, Expected object of backend CPU but got backend CUDA for argument #3 'weight', from deep within the PyTorch library.

After going through the traces, I suspected there was something wrong with how the weighted loss gets attached to the learner object (total guess, I am really new to this library; will update once I find out). So, I tried passing the loss_func as an argument to the learner constructor, and it worked!

So, here is what solved the issue for me:

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], 
                    loss_func=nn.CrossEntropyLoss(weight=torch.FloatTensor([5., 1.]).cuda()))
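
For what it's worth, instead of hardcoding [5., 1.] you can derive the weights from the class frequencies. A rough sketch of that idea (train_labels is an assumed name here; substitute however you access the integer class indices of your training set):

import numpy as np
import torch
from torch import nn

# train_labels: array of integer class indices for the training set (assumed name)
counts = np.bincount(train_labels)                   # samples per class
weights = counts.sum() / (len(counts) * counts)      # inverse-frequency weights
class_weights = torch.FloatTensor(weights).cuda()
loss = nn.CrossEntropyLoss(weight=class_weights)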

Here is what I tried before, which caused the issue:

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy])
learn.loss_func.func = nn.CrossEntropyLoss(weight=torch.FloatTensor([5., 1.]).cuda())
...<some_learning>
show_results(learn)

And here is the error trace (it's quite big), if someone wants to work on it:

<ipython-input-162-538d0a4da532> in show_results(learn)
      1 def show_results(learn):
      2     learn.recorder.plot_losses()
----> 3     interp = ClassificationInterpretation.from_learner(learn)
      4     interp.plot_top_losses(9, figsize=(15,11), heatmap= False)
      5     print(interp.most_confused(min_val=1))

/opt/conda/lib/python3.6/site-packages/fastai/vision/learner.py in _cl_int_from_learner(cls, learn, ds_type, tta)
    126 def _cl_int_from_learner(cls, learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    127     "Create an instance of `ClassificationInterpretation`. `tta` indicates if we want to use Test Time Augmentation."
--> 128     preds = learn.TTA(ds_type=ds_type, with_loss=True) if tta else learn.get_preds(ds_type=ds_type, with_loss=True)
    129     return cls(learn, *preds, ds_type=ds_type)
    130 

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in get_preds(self, ds_type, with_loss, n_batch, pbar)
    334         lf = self.loss_func if with_loss else None
    335         return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
--> 336                          activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)
    337 
    338     def pred_batch(self, ds_type:DatasetType=DatasetType.Valid, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False) -> List[Tensor]:

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in get_preds(model, dl, pbar, cb_handler, activ, loss_func, n_batch)
     44            zip(*validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
     45     if loss_func is not None:
---> 46         with NoneReduceOnCPU(loss_func) as lf: res.append(lf(res[0], res[1]))
     47     if activ is not None: res[0] = activ(res[0])
     48     return res

/opt/conda/lib/python3.6/site-packages/fastai/layers.py in __call__(self, input, target, **kwargs)
    235         if self.floatify: target = target.float()
    236         input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
--> 237         return self.func.__call__(input, target.view(-1), **kwargs)
    238 
    239 def CrossEntropyFlat(*args, axis:int=-1, **kwargs):

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    940     def forward(self, input, target):
    941         return F.cross_entropy(input, target, weight=self.weight,
--> 942                                ignore_index=self.ignore_index, reduction=self.reduction)
    943 
    944 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2054     if size_average is not None or reduce is not None:
   2055         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2056     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2057 
   2058 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1869                          .format(input.size(0), target.size(0)))
   1870     if dim == 2:
-> 1871         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1872     elif dim == 4:
   1873         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'weight'

Your error tells you the problem. You have manually set the weight to go on the GPU with .cuda() in your statement, and inside basic_train.py the call is to NoneReduceOnCPU(loss_func), which produces a mismatch between CPU and GPU tensors.
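
To see the mismatch concretely, this is roughly what happens inside get_preds: the predictions and targets come back as CPU tensors, while the weight you attached lives on the GPU (the tensor values below are made up, only the devices matter):

import torch
import torch.nn.functional as F

logits  = torch.randn(4, 2)              # CPU tensor, as returned by get_preds
targets = torch.tensor([0, 1, 0, 1])     # CPU tensor
weight  = torch.tensor([5., 1.]).cuda()  # GPU tensor, as set on the loss

# Raises the same CPU/CUDA backend mismatch error for the 'weight' argument
F.cross_entropy(logits, targets, weight=weight)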


Thanks @bfarzin. Actually, I initially tried it the way you suggested, i.e. passing a CPU tensor to the loss function like this:

learn2 = cnn_learner(data2, models.resnet50, metrics=[error_rate, accuracy],
                      model_dir="/tmp/model/")
learn2.loss_func.func = nn.CrossEntropyLoss(weight=torch.FloatTensor([5., 1.]))

But I got a similar but inverted error during lr_find(), i.e. RuntimeError: Expected object of backend CUDA but got backend CPU for argument #3 'weight', which is why I passed a GPU tensor instead. Here is the full error trace:

RuntimeError                              Traceback (most recent call last)
<ipython-input-11-083e59c4fff8> in <module>
----> 1 learn2.lr_find()
      2 learn2.recorder.plot()

/opt/conda/lib/python3.6/site-packages/fastai/train.py in lr_find(learn, start_lr, end_lr, num_it, stop_div, wd)
     30     cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
     31     epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
---> 32     learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
     33 
     34 def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    199         if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201 
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
     99             for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
    100                 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101                 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102                 if cb_handler.on_batch_end(loss): break
    103 

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     28 
     29     if not loss_func: return to_detach(out), yb[0].detach()
---> 30     loss = loss_func(out, *yb)
     31 
     32     if opt is not None:

/opt/conda/lib/python3.6/site-packages/fastai/layers.py in __call__(self, input, target, **kwargs)
    235         if self.floatify: target = target.float()
    236         input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
--> 237         return self.func.__call__(input, target.view(-1), **kwargs)
    238 
    239 def CrossEntropyFlat(*args, axis:int=-1, **kwargs):

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    940     def forward(self, input, target):
    941         return F.cross_entropy(input, target, weight=self.weight,
--> 942                                ignore_index=self.ignore_index, reduction=self.reduction)
    943 
    944 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2054     if size_average is not None or reduce is not None:
   2055         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2056     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2057 
   2058 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1869                          .format(input.size(0), target.size(0)))
   1870     if dim == 2:
-> 1871         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1872     elif dim == 4:
   1873         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Expected object of backend CUDA but got backend CPU for argument #3 'weight'

Alright, I think I have tracked down the problem (it turned out to be a minor mistake).

Following your suggestion, I went through the Learner and DataBunch constructors to find out how the loss function is assigned. It turns out to be a property called loss_func that defaults to F.nll_loss. Therefore, it makes sense to assign the new callable loss-function instance to loss_func directly, instead of to loss_func.func.
And so, this is what I tried, and it worked perfectly fine. Please point out if I am doing something wrong.

learn.loss_func = nn.CrossEntropyLoss(weight=tensor([1.,16.]).cuda())

And as I said before, passing it to the constructor worked as well.
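
As a small side note, if you want to avoid hardcoding .cuda() (e.g. so the same notebook also runs on a CPU-only machine), here is a sketch that reads the device off the model instead:

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy])
device = next(learn.model.parameters()).device   # cuda:0 or cpu
learn.loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1., 16.], device=device))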

Thanks for your suggestions! I’m liking it here. :)


It’d be nice to check out the oversampling callback implemented by an awesome person!
People often say that oversampling or undersampling methods work better than changing the loss weights.
It was implemented just a week ago.
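
For anyone curious what such a callback does under the hood, the core idea is plain PyTorch: give the training DataLoader a WeightedRandomSampler so minority-class samples are drawn more often. A rough sketch of the idea only, not the callback's actual code (labels and train_ds are assumed names):

import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# labels: array of integer class indices for the training set (assumed)
counts = np.bincount(labels)
sample_weights = 1.0 / counts[labels]            # rarer class -> larger per-sample weight
sampler = WeightedRandomSampler(weights=torch.DoubleTensor(sample_weights),
                                num_samples=len(labels), replacement=True)
train_dl = DataLoader(train_ds, batch_size=64, sampler=sampler)   # train_ds: your Dataset (assumed)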


You’re right! It was my error to add the extra .func. Glad you sorted it out!


Awesome, I’ll try it out to see how it fares against the weighted loss. Thanks for this!

@Soo thanks for linking my callback over here. I was actually going to do so, and I saw you already did it!
I hope this callback will be helpful to others.

The code in that post is slightly outdated, so if you want to copy and paste the callback into your program, use the code from the file:

Let me know if you have any questions!


Oh, it was outdated.
I will definitely use it because my dataset is also imbalanced.
Thanks a lot!

The newer version of the oversampling callback is now in the fastai library and accepts weights too. It would be great if someone could add this to the library’s documentation.
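
For reference, here is a usage sketch of that callback; the exact import path and signature may differ between fastai versions, so treat it as a guess to adapt rather than the documented API:

from functools import partial
from fastai.vision import *
from fastai.callbacks import OverSamplingCallback   # location/name may vary by version

learn = cnn_learner(data, models.resnet34, metrics=error_rate,
                    callback_fns=[partial(OverSamplingCallback)])
learn.fit_one_cycle(4)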


Can we use this for multi-label images, or in image segmentation projects?