I’m experimenting with Dice loss in a multi-class semantic segmentation project, and I’m looking for feedback on my code/approach and advice on doing it better. I’m still a fastai2 novice, so I’m sure there are many things I’m missing.

Here’s my current version of the custom loss function:

```
import torch
import torch.nn.functional as F
from torch import nn

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean', eps=1e-7):
        super().__init__()
        self.eps, self.reduction = eps, reduction

    def forward(self, output, targ):
        """
        output is NCHW (logits), targ is NHW (integer class indices)
        """
        # convert target to one-hot, NHW -> NCHW, on the same device as the logits
        targ_onehot = torch.eye(output.shape[1], device=output.device)[targ].permute(0, 3, 1, 2).float()
        # convert logits to probs
        pred = self.activation(output)
        # sum over H and W
        inter = (pred * targ_onehot).sum(dim=[-1, -2])
        union = (pred + targ_onehot).sum(dim=[-1, -2])
        # per-sample loss: mean over C of (1 - dice score)
        loss = 1. - (2. * inter / (union + self.eps)).mean(dim=-1)
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        return loss.mean()

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(dim=1)
```
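As a quick sanity check on the Dice arithmetic: a near-perfect prediction should drive the loss to ~0. Here’s a tiny standalone repro of the same operations (a hand-made 1×4×4 mask in which every class appears, and logits that overwhelmingly favour the true class):

```
import torch
import torch.nn.functional as F

# hand-made 1x4x4 integer mask in which every class (0, 1, 2) appears
targ = torch.tensor([[[0, 1, 2, 0],
                      [1, 2, 0, 1],
                      [2, 0, 1, 2],
                      [0, 1, 2, 0]]])
targ_onehot = F.one_hot(targ, num_classes=3).permute(0, 3, 1, 2).float()  # NCHW
output = targ_onehot * 100.      # logits overwhelmingly favouring the true class

pred = F.softmax(output, dim=1)  # ~exactly the one-hot target
inter = (pred * targ_onehot).sum(dim=[-1, -2])
union = (pred + targ_onehot).sum(dim=[-1, -2])
loss = 1. - (2. * inter / (union + 1e-7)).mean(dim=-1)
print(loss.mean().item())        # ~0 for a near-perfect prediction
```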

In particular, I noticed that this is about 2x slower than the default cross-entropy loss function, and I suspect it’s because of all the one-hot casting of the target. I was going to try to fix that by making the target masks that go into my databunch one-hot in the first place, but my understanding is that I’d have to replace `MaskBlock` for that to be possible. Does that seem right? (If I get to this sooner than responses, I’ll update with my findings…)
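One idea I may try before touching `MaskBlock`: keep integer masks in the databunch but build the one-hot on-device with `F.one_hot`, instead of indexing a CPU `torch.eye` and calling `.cuda()` afterwards. A sketch (shapes are made up for illustration):

```
import torch
import torch.nn.functional as F

def onehot_like(output, targ):
    # output: NCHW logits, targ: NHW integer mask; F.one_hot stays on
    # targ's device, so no explicit .cuda() call is needed
    return F.one_hot(targ.long(), num_classes=output.shape[1]) \
             .permute(0, 3, 1, 2).to(output.dtype)

output = torch.randn(2, 3, 8, 8)
targ = torch.randint(0, 3, (2, 8, 8))
oh = onehot_like(output, targ)
print(oh.shape)  # torch.Size([2, 3, 8, 8])
```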