How to use class weights in loss function for imbalanced dataset

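```python
import torch
from torch import nn

weights = [0.4, 1]  # one weight per class label: class 0 -> 0.4, class 1 -> 1
class_weights = torch.FloatTensor(weights).cuda()  # on the GPU, like the model
# `learn` is assumed to be an existing fastai (0.7-era) Learner, whose loss lives in `learn.crit`
learn.crit = nn.CrossEntropyLoss(weight=class_weights)
```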
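If you want to see what the `weight` argument does outside of fastai, here is a minimal standalone PyTorch sketch. The label counts are made up for illustration, and deriving the weights from inverse class frequency is just one common heuristic (an assumption here, not something stated above). The sketch also checks that with the default `reduction="mean"`, the weighted loss is divided by the sum of the targets' weights rather than by the batch size.

```python
import torch
from torch import nn

# Hypothetical label counts for an imbalanced 2-class dataset (illustrative only)
counts = torch.tensor([900.0, 100.0])

# Assumed heuristic: inverse-frequency weights, scaled so the rarest class gets 1.0
class_weights = counts.min() / counts  # roughly [0.111, 1.0]

device = "cuda" if torch.cuda.is_available() else "cpu"
class_weights = class_weights.to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy logits for 4 samples and their integer class targets
torch.manual_seed(0)
logits = torch.randn(4, 2, device=device)
targets = torch.tensor([0, 0, 1, 1], device=device)

loss = criterion(logits, targets)

# Recompute by hand: unweighted per-sample losses, scaled by each target's
# class weight, then divided by the total weight of the targets
per_sample = nn.functional.cross_entropy(logits, targets, reduction="none")
w = class_weights[targets]
manual = (w * per_sample).sum() / w.sum()

print(f"weighted loss: {loss.item():.4f}, manual: {manual.item():.4f}")
```

The same `criterion` can then be assigned as the learner's loss, as in the snippet above. Misclassifying a sample of the rare class now costs about nine times as much as misclassifying a sample of the common class.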
