I’m experimenting with using Dice loss in a multi-class semantic segmentation project, and am looking for any feedback on my code/approach, or advice on doing it better. I’m still a fastai2 novice, so I’m sure there are many things I’m missing.
Here’s my current version of the custom loss function:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean', eps=1e-7):
        super().__init__()
        self.eps, self.reduction = eps, reduction

    def forward(self, output, targ):
        """output is NCHW (logits), targ is NHW (class indices)"""
        # convert target to one-hot: NHW -> NCHW
        targ_onehot = torch.eye(output.shape[1])[targ].permute(0, 3, 1, 2).float().cuda()
        # convert logits to probs
        pred = self.activation(output)
        # sum over HW
        inter = (pred * targ_onehot).sum(axis=[-1, -2])
        union = (pred + targ_onehot).sum(axis=[-1, -2])
        # mean over C
        loss = 1. - (2. * inter / (union + self.eps)).mean(axis=-1)
        if self.reduction == 'none':   return loss
        elif self.reduction == 'sum':  return loss.sum()
        elif self.reduction == 'mean': return loss.mean()

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(1)
```
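For context, I’m just passing it in as the `loss_func` when building the learner — roughly like this (a sketch only; `unet_learner`, `resnet34`, and `dls` are stand-ins for whatever your actual setup uses):

```python
from fastai.vision.all import *

# dls is a segmentation DataLoaders; resnet34 is just an example backbone
learn = unet_learner(dls, resnet34, loss_func=DiceLoss())
learn.fine_tune(1)
```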
In particular, I noticed that this is about 2x slower than the default cross-entropy loss function, and I suspect it’s because of all the one-hot casting of the target. I was going to try to fix that by making the target masks that go into my databunch one-hot in the first place, but my understanding is that I’d have to replace MaskBlock for that to be possible. Does that seem right? (If I get to this before any responses, I’ll update with my findings…)
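One thing I might try first, before touching MaskBlock: build the one-hot target directly on the GPU with `F.one_hot` (or `scatter_`), instead of indexing into a CPU `torch.eye` matrix and copying the result over. Rough, untested sketch of just the forward-pass change:

```python
import torch
import torch.nn.functional as F

def one_hot_target(targ, num_classes):
    """targ: NHW integer mask -> NCHW float one-hot, built on targ's device."""
    # F.one_hot produces NHWC, so permute to NCHW to line up with the model output
    return F.one_hot(targ.long(), num_classes).permute(0, 3, 1, 2).float()

# inside DiceLoss.forward, instead of the torch.eye(...) line:
# targ_onehot = one_hot_target(targ, output.shape[1])
#
# or equivalently with scatter_, which also stays on the output's device/dtype:
# targ_onehot = torch.zeros_like(output).scatter_(1, targ.long().unsqueeze(1), 1.)
```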