I’m experimenting with using Dice loss in a multi-class semantic segmentation project, and I’m looking for feedback on my code/approach and advice on doing it better. I’m still a fastai2 novice, so I’m sure there are many things I’m missing.
Here’s my current version of the custom loss function:
import torch
from torch import nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean', eps=1e-7):
        super().__init__()
        self.eps, self.reduction = eps, reduction

    def forward(self, output, targ):
        """
        output is NCHW (logits), targ is NHW (class indices)
        """
        # convert target to one-hot NCHW
        targ_onehot = torch.eye(output.shape[1])[targ].permute(0, 3, 1, 2).float().cuda()
        # convert logits to probs
        pred = self.activation(output)
        # sum over HW
        inter = (pred * targ_onehot).sum(axis=[-1, -2])
        union = (pred + targ_onehot).sum(axis=[-1, -2])
        # mean over C
        loss = 1. - (2. * inter / (union + self.eps)).mean(axis=-1)
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'mean':
            return loss.mean()

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(1)
In particular, I noticed that this is about 2x slower than the default cross-entropy loss function, and I suspect it’s because of all the one-hot casting of the target. I was going to try to fix that by making the target masks that go into my databunch one-hot in the first place, but my understanding is that I’d have to replace MaskBlock for that to be possible. Does that seem right? (If I get to this before any responses, I’ll update with my findings…)
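In the meantime, one thing I may try before touching MaskBlock is building the one-hot target on the GPU inside forward via F.one_hot, instead of indexing torch.eye on the CPU and then copying with .cuda(). A rough, untested sketch of the idea (the helper name is just for illustration):

import torch.nn.functional as F

def onehot_like(output, targ):
    # targ: NHW long tensor of class indices, assumed to already be on the same device as output
    # F.one_hot returns NHWC, so move the class dim to position 1 to get NCHW
    return F.one_hot(targ.long(), num_classes=output.shape[1]).permute(0, 3, 1, 2).float()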
Thanks. I had found that repo as well. I’m having trouble with this loss function, though: when I train with loss_func=DiceLoss(), my loss stagnates and doesn’t change after a few batches in the first epoch. On the other hand, if I train against CrossEntropyLoss and watch dice_loss as a metric, it drops significantly within the first epoch. Any advice on how to debug this?
Here’s the current version of my code, with dice_loss factored out so I can use it as a metric as well. I got rid of the reduction logic to see if that was causing problems, but it still doesn’t work.
def dice_loss(output, target, eps=1e-7):
    # convert target to one-hot NCHW
    targ_onehot = torch.eye(output.shape[1])[target].permute(0, 3, 1, 2).float().cuda()
    # convert logits to probs
    pred = F.softmax(output, dim=1)
    # sum over NHW
    inter = (pred * targ_onehot).sum(axis=[0, 2, 3])
    union = (pred + targ_onehot).sum(axis=[0, 2, 3])
    # mean over C
    dice = (2. * inter / (union + eps)).mean()
    return 1. - dice
class DiceLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, output, targ):
        """
        output is NCHW (logits), targ is NHW (class indices)
        """
        return dice_loss(output, targ)

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(1)
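For reference, this is roughly how I’m wiring it up, with dice_loss doubling as the metric since fastai accepts any callable taking (preds, targs). This is just a sketch of my setup: it assumes a segmentation DataLoaders in dls, that targets come through as NHW class-index masks, and that unet_learner/resnet34 happen to be what I’m training with:

from fastai.vision.all import unet_learner, resnet34

learn = unet_learner(dls, resnet34, loss_func=DiceLoss(), metrics=[dice_loss])
learn.fit_one_cycle(1)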
This should still be accurate with the latest version.
You can also always use a plain PyTorch implementation of it; the only thing special about fastai’s is the activation and decodes, and those really aren’t needed at all.
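Something along these lines, for example — just a sketch with the same math as above, minus the fastai-specific activation/decodes hooks (class name is arbitrary):

import torch
from torch import nn
import torch.nn.functional as F

class PlainDiceLoss(nn.Module):
    # minimal pure-PyTorch version, no fastai-specific hooks
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, output, targ):
        # output: NCHW logits, targ: NHW class indices
        n_classes = output.shape[1]
        pred = F.softmax(output, dim=1)
        # build one-hot target on the same device as targ
        targ_onehot = F.one_hot(targ.long(), n_classes).permute(0, 3, 1, 2).float()
        # sum over NHW, then average the per-class Dice scores
        inter = (pred * targ_onehot).sum(dim=[0, 2, 3])
        union = (pred + targ_onehot).sum(dim=[0, 2, 3])
        return 1. - (2. * inter / (union + self.eps)).mean()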