I found my error: in `myCCELoss` I was summing the losses over the minibatch rather than averaging them.
The PyTorch documentation at https://pytorch.org/docs/master/nn.html#crossentropyloss says: "The losses are averaged across observations for each minibatch."
Here is the corrected code:
def one_hot_embedding(labels, num_classes):
    """Convert integer class labels to one-hot vectors.

    Args:
        labels: (N,) tensor of class indices.
        num_classes: total number of classes C.

    Returns:
        (N, C) float tensor of one-hot rows, on the same device as ``labels``.
    """
    # Index rows of an identity matrix built on labels.device — the original
    # used ``labels.data.cpu()`` (deprecated Variable API) and always returned
    # a CPU tensor, which mismatches GPU logits downstream.
    return torch.eye(num_classes, device=labels.device)[labels]
class myCCELoss(nn.Module):
    """Categorical cross-entropy, averaged over the minibatch.

    Equivalent to ``nn.CrossEntropyLoss`` with mean reduction.

    forward(input, target):
        input:  (N, C) raw logits.
        target: (N,) integer class indices.
        returns: scalar tensor — mean per-sample cross entropy.
    """

    def __init__(self):
        super(myCCELoss, self).__init__()

    def forward(self, input, target):
        # log_softmax is numerically stabler than softmax followed by log
        # (the original could produce -inf with no eps clamp), and an
        # explicit dim avoids the deprecated implicit-dim softmax call.
        logp = F.log_softmax(input, dim=-1)
        # nll_loss picks -logp at the target index and averages over the
        # batch — same result as the one-hot multiply/sum/mean of the
        # original, without the deprecated Variable (V) wrapper.
        return F.nll_loss(logp, target)
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017): cross entropy scaled by (1 - p_t)**gamma.

    ``gamma=0`` reduces to plain cross entropy; larger gamma down-weights
    well-classified examples. The loss is averaged over the minibatch.

    Args:
        gamma: focusing parameter (>= 0).
        eps: clamp bound keeping ``log()`` finite.
    """

    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        # input: (N, C) logits; target: (N,) class indices.
        # Explicit dim avoids the deprecated implicit-dim softmax call.
        p = F.softmax(input, dim=-1)
        p = p.clamp(self.eps, 1. - self.eps)
        # One-hot mask built on the same device as the logits — the original
        # indexed a CPU identity matrix (and wrapped it in the deprecated
        # Variable V), which breaks on GPU inputs.
        y = torch.zeros_like(p).scatter_(1, target.unsqueeze(1), 1.)
        loss = -1 * y * torch.log(p)          # cross entropy
        loss = loss * (1 - p) ** self.gamma   # focal modulation
        return loss.sum(dim=1).mean()
# Use focal loss with gamma=2 as the training criterion.
# NOTE(review): `learn` is not defined in this snippet — presumably a
# fastai Learner created elsewhere; confirm before running.
learn.crit = FocalLoss(gamma=2.0)