Loss function + data imbalance


(Ashwin) #1

I’m currently training an MLP for classification with a weighted cross-entropy loss and evaluating the model by mean per-class accuracy.
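
A minimal sketch of the two pieces described above, written in plain Python so it runs without any framework. It assumes the class weights are set to inverse class frequency, which is a common heuristic; the post doesn't say how its weights were chosen. Mean per-class accuracy is computed as the average of each class's recall. Both function names are hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Weight for class c = N / (num_classes * count_c); rarer classes get larger weights."""
    counts = Counter(labels)
    n = len(labels)
    # Classes absent from `labels` get weight 0.0 to avoid division by zero.
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]


def mean_per_class_accuracy(preds, targets, num_classes):
    """Average over classes of (correct predictions for class c) / (examples of class c)."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for p, t in zip(preds, targets):
        total[t] += 1
        correct[t] += int(p == t)
    # Only average over classes that actually appear in `targets`.
    accs = [correct[c] / total[c] for c in range(num_classes) if total[c]]
    return sum(accs) / len(accs)


if __name__ == "__main__":
    y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
    y_pred = [0, 0, 0, 0, 0, 0, 0, 1, 2, 0]
    print(inverse_frequency_weights(y_true, 3))
    print(mean_per_class_accuracy(y_pred, y_true, 3))
```

On this toy example plain accuracy is 8/10 = 0.8, while mean per-class accuracy is (1.0 + 0.5 + 0.5) / 3 ≈ 0.667. The majority class hides the minority-class errors in the plain metric, which is why the per-class version is the better target for imbalanced data.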

Is there a better loss function for imbalanced classes?


(Yuri Oliveira Galindo) #2

To change the loss function being optimized in an already initialized learner, use
learn.crit = new_loss

The learner class imports torch.nn.functional as F, and the available loss functions are listed at https://pytorch.org/docs/master/nn.html

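I would recommend nn.CrossEntropyLoss(weight=tensor_of_weight_for_each_class), where the weight is a float tensor with one entry per class. Note that CrossEntropyLoss is a class in torch.nn, not in torch.nn.functional; the functional form is F.cross_entropy(input, target, weight=tensor_of_weight_for_each_class).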
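
To make the effect of the weight argument concrete, here is a framework-free sketch of what a class-weighted cross entropy computes with the default mean reduction. Each sample's loss, -log softmax(x)[y], is scaled by the weight of its target class, and the sum is divided by the sum of those weights; this matches how PyTorch documents the weighted mean. The helper `weighted_cross_entropy` is hypothetical; in PyTorch you would pass a float tensor to nn.CrossEntropyLoss(weight=...) instead.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def weighted_cross_entropy(batch_logits, targets, class_weights):
    """Class-weighted cross entropy with a weighted-mean reduction.

    loss = sum_i w[y_i] * (-log_softmax(x_i)[y_i]) / sum_i w[y_i]
    """
    num = 0.0
    den = 0.0
    for logits, y in zip(batch_logits, targets):
        w = class_weights[y]
        num += w * -log_softmax(logits)[y]
        den += w
    return num / den


if __name__ == "__main__":
    logits = [[2.0, 0.5, 0.1], [0.3, 1.2, 0.4], [1.5, 0.2, 0.3]]
    targets = [0, 1, 2]
    print(weighted_cross_entropy(logits, targets, [1.0, 1.0, 1.0]))  # equal weights: plain mean CE
    print(weighted_cross_entropy(logits, targets, [0.5, 2.0, 2.0]))  # upweights classes 1 and 2
```

Because the denominator is the sum of the target weights rather than the batch size, the weights change the relative importance of samples without simply rescaling the loss.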


(Gene Sobolev) #3

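I would also try nll_loss; it’s the fastai default for structured data classification: https://pytorch.org/docs/master/nn.html#torch.nn.NLLLoss. Note that nll_loss expects log-probabilities, so it should be paired with a final log_softmax layer; together the two are equivalent to cross entropy, and nll_loss accepts the same weight argument.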
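
A plain-Python sketch of why nll_loss on log_softmax outputs is the same objective as cross entropy on raw logits. The helpers below are hypothetical stand-ins that mimic F.nll_loss and F.cross_entropy with mean reduction and no class weights.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def nll_loss(batch_log_probs, targets):
    """Mean negative log-likelihood: picks -log p[y] from rows that are already log-probabilities."""
    return -sum(lp[y] for lp, y in zip(batch_log_probs, targets)) / len(targets)


def cross_entropy(batch_logits, targets):
    """Mean cross entropy computed directly from the softmax probabilities of raw logits."""
    total = 0.0
    for logits, y in zip(batch_logits, targets):
        exps = [math.exp(z) for z in logits]
        total += -math.log(exps[y] / sum(exps))
    return total / len(targets)


if __name__ == "__main__":
    logits = [[2.0, 0.5, 0.1], [0.3, 1.2, 0.4]]
    targets = [0, 2]
    a = nll_loss([log_softmax(row) for row in logits], targets)
    b = cross_entropy(logits, targets)
    print(a, b)  # the two values agree up to floating-point rounding
```

So switching from cross entropy to nll_loss only changes where the log_softmax happens, not what is optimized; feeding raw logits straight into nll_loss would be a bug.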