FocalLoss with multi class

I found my error: in mySSELoss I was summing the losses over the minibatch rather than averaging them.
The PyTorch documentation (https://pytorch.org/docs/master/nn.html#crossentropyloss) says: "The losses are averaged across observations for each minibatch."
Here is the corrected code:

def one_hot_embedding(labels, num_classes):
    """Return one-hot rows for ``labels`` taken from an identity matrix.

    Args:
        labels: Tensor of integer class indices; it is copied to the CPU
            before being used as an index.
        num_classes: Size of the identity matrix, i.e. the length of each
            one-hot row.

    Returns:
        Tensor with one extra trailing dimension of size ``num_classes``
        compared to ``labels``; row ``i`` is the ``labels[i]``-th row of the
        ``num_classes`` x ``num_classes`` identity matrix.
    """
    identity = torch.eye(num_classes)
    cpu_indices = labels.data.cpu()
    # Indexing the identity matrix by class id selects the matching one-hot row.
    return identity[cpu_indices]

class myCCELoss(nn.Module):
    """Categorical cross-entropy loss averaged over the minibatch.

    For every sample the loss is ``-log(softmax(input)[target])``; the
    per-sample losses are then averaged.
    """

    def __init__(self):
        super(myCCELoss, self).__init__()

    def forward(self, input, target):
        """Return the mean cross-entropy of ``input`` against ``target``.

        Args:
            input: Raw (unnormalised) class scores with the classes along the
                last dimension, e.g. shape ``(batch, num_classes)``.
            target: Integer class indices, one per sample, e.g. shape
                ``(batch,)``.

        Returns:
            Scalar tensor holding the loss averaged over all samples.
        """
        # log_softmax is numerically stable. log(softmax(x)) underflows to
        # -inf for very unlikely classes, and multiplying that by the zeros of
        # a one-hot matrix yields NaN for the whole loss.
        log_prob = F.log_softmax(input, dim=-1)
        # Selecting the target column equals "one-hot * log_prob summed over
        # classes" without building the one-hot matrix on the CPU.
        index = target.unsqueeze(-1).to(input.device)
        target_log_prob = log_prob.gather(-1, index).squeeze(-1)
        return -target_log_prob.mean()
    
class FocalLoss(nn.Module):
    """Multi-class focal loss averaged over the minibatch.

    For every sample the loss is ``-(1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the target class. With
    ``gamma=0`` the modulating factor is 1 and the loss reduces to the
    (clamped) cross-entropy.
    """

    def __init__(self, gamma=0, eps=1e-7):
        """Store the focal-loss hyper-parameters.

        Args:
            gamma: Exponent of the ``(1 - p_t)`` modulating factor.
            eps: Probabilities are clamped to ``[eps, 1 - eps]`` before the
                logarithm so that ``log`` never receives 0.
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        """Return the mean focal loss of ``input`` against ``target``.

        Args:
            input: Raw (unnormalised) class scores with the classes along the
                last dimension, e.g. shape ``(batch, num_classes)``.
            target: Integer class indices, one per sample, e.g. shape
                ``(batch,)``.

        Returns:
            Scalar tensor holding the loss averaged over all samples.
        """
        # Explicit dim: the implicit choice is deprecated and picks the wrong
        # axis for inputs that are not 2-D.
        prob = F.softmax(input, dim=-1)
        # Only the target class contributes to the one-hot weighted sum, so
        # pick that column directly instead of building a one-hot matrix.
        index = target.unsqueeze(-1).to(input.device)
        p_t = prob.gather(-1, index).squeeze(-1)
        p_t = p_t.clamp(self.eps, 1. - self.eps)

        cross_entropy = -torch.log(p_t)
        # Down-weight well-classified samples (p_t close to 1).
        focal = cross_entropy * (1 - p_t) ** self.gamma
        return focal.mean()

# Use the focal loss with gamma=2.0 as the criterion of `learn`.
# NOTE(review): `learn` is defined outside this snippet — presumably a fastai
# learner whose `crit` attribute holds the loss function; confirm.
learn.crit = FocalLoss(gamma=2.0)
5 Likes