Feedback on using custom dice loss in multi-class semantic segmentation

I’m experimenting with Dice loss in a multi-class semantic segmentation project and would welcome any feedback on my code and approach, or advice on doing it better. I’m still a fastai2 novice, so I’m sure there are many things I’m missing.

Here’s my current version of the custom loss function:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean', eps=1e-7):
        super().__init__()
        self.eps, self.reduction = eps, reduction

    def forward(self, output, targ):
        """
        output is NCHW logits, targ is NHW class indices (long)
        """
        # convert target to one-hot NCHW on the same device as the output
        targ_onehot = torch.eye(output.shape[1], device=output.device)[targ].permute(0,3,1,2).float()
        # convert logits to probs
        pred = self.activation(output)
        # sum over HW (one Dice score per image and class)
        inter = (pred * targ_onehot).sum(axis=[-1, -2])
        union = (pred + targ_onehot).sum(axis=[-1, -2])
        # mean over C, giving one loss value per image
        loss = 1. - (2. * inter / (union + self.eps)).mean(axis=-1)
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'mean':
            return loss.mean()
        raise ValueError(f"unknown reduction: {self.reduction}")

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(1)

In particular, I noticed that training got about 2x slower than with the default cross-entropy loss, and I suspect it’s because of all the one-hot conversion I’m doing on the target. I was going to try to fix that by making the target masks that go into my databunch one-hot in the first place, but my understanding is that I’d have to replace MaskBlock for that to be possible. Does that sound right? (If I get to this before anyone responds, I’ll update with my findings.)
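One way to cut the one-hot cost without touching the data pipeline, sketched under the assumption that the target is an integer NHW tensor already on the same device as the model output: build the one-hot directly with `torch.nn.functional.one_hot` on the target's device instead of indexing a `torch.eye` matrix. Whether this accounts for the whole 2x slowdown is untested here; `onehot_nchw` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def onehot_nchw(targ: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: integer NHW mask -> float NCHW one-hot on targ's device."""
    # F.one_hot needs int64 input and appends the class axis last: (N, H, W) -> (N, H, W, C)
    oh = F.one_hot(targ.long(), num_classes)
    # Move the class axis to position 1 to match the NCHW model output
    return oh.permute(0, 3, 1, 2).float()

# Usage: compare with the torch.eye indexing approach (on CPU here)
targ = torch.randint(0, 4, (2, 5, 6))
a = onehot_nchw(targ, 4)
b = torch.eye(4)[targ].permute(0, 3, 1, 2).float()
print(a.shape, torch.equal(a, b))
```

Because the result is created where `targ` already lives, there is no host-to-device copy per batch.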


I saw this implementation in the pytorch goodies repo:

import torch
import torch.nn.functional as F

def dice_loss(true, logits, eps=1e-7):
    """Computes the Sørensen–Dice loss.
    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the Dice coefficient, so we
    return 1 minus the Dice coefficient.
    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1, device=true.device)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes, device=true.device)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # sum over batch and spatial dims; use probas (always 4-D) so an NHW target also works
    dims = (0,) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice)
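The `num_classes == 1` branch above can look opaque. The idea, illustrated in this sketch (the `soft_dice` helper is hypothetical and assumes an NHW 0/1 target), is that a sigmoid on one logit `x` equals a softmax over the pair `[0, x]`, so the binary case is just the two-class case with the channels in a different order. Since the Dice score is averaged over classes, the order does not matter.

```python
import torch
import torch.nn.functional as F

def soft_dice(probas, onehot, eps=1e-7):
    """Per-class soft Dice summed over batch and spatial dims, averaged over classes."""
    dims = (0, 2, 3)
    inter = (probas * onehot).sum(dims)
    card = (probas + onehot).sum(dims)
    return (2. * inter / (card + eps)).mean()

torch.manual_seed(0)
x = torch.randn(2, 1, 4, 4)            # single-channel binary logits
true = torch.randint(0, 2, (2, 4, 4))  # 0 = background, 1 = foreground

# Binary branch: [pos, neg] probabilities from one sigmoid, one-hot in the same order
pos = torch.sigmoid(x)
probas_bin = torch.cat([pos, 1 - pos], dim=1)
onehot_bin = torch.stack([(true == 1), (true == 0)], dim=1).float()

# Two-class branch: softmax over logits [0, x] gives P(class 1) = sigmoid(x)
logits2 = torch.cat([torch.zeros_like(x), x], dim=1)
probas_mc = F.softmax(logits2, dim=1)
onehot_mc = F.one_hot(true, 2).permute(0, 3, 1, 2).float()

d_bin = soft_dice(probas_bin, onehot_bin)
d_mc = soft_dice(probas_mc, onehot_mc)
print(d_bin.item(), d_mc.item())
```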

Thanks, I had found that repo as well. I’m having trouble with this loss function, though: when I train with loss_func=DiceLoss(), the loss stagnates and stops changing after a few batches in the first epoch. On the other hand, if I train with CrossEntropyLoss and track dice_loss as a metric, it drops significantly within the first epoch. Any advice on how to debug this?
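One common sanity check, offered as a generic sketch rather than a diagnosis of this particular problem: confirm that the loss can drive a tiny model to fit one fixed batch. If even this stalls, suspect the loss or its gradients; if it works, look at the learning rate, class balance, or data instead. The 1x1 conv model, tensor sizes, learning rate, and step count below are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_loss(output, target, eps=1e-7):
    """1 - mean-over-classes soft Dice, summed over batch and spatial dims."""
    onehot = F.one_hot(target, output.shape[1]).permute(0, 3, 1, 2).float()
    pred = F.softmax(output, dim=1)
    inter = (pred * onehot).sum(dim=(0, 2, 3))
    union = (pred + onehot).sum(dim=(0, 2, 3))
    return 1. - (2. * inter / (union + eps)).mean()

torch.manual_seed(0)
x = torch.randn(4, 3, 16, 16)
y = x.argmax(dim=1)                     # a target a 1x1 conv can represent exactly
model = nn.Conv2d(3, 3, kernel_size=1)
opt = torch.optim.Adam(model.parameters(), lr=0.05)

losses = []
for step in range(100):
    opt.zero_grad()
    loss = dice_loss(model(x), y)
    loss.backward()
    # Gradients should be finite; NaN/inf here points at the loss itself
    assert all(torch.isfinite(p.grad).all() for p in model.parameters())
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```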

Here’s the current version of my code, with dice_loss factored out so I can also use it as a metric. I removed the reduction handling to see if it was causing problems, but it still doesn’t work.

import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_loss(output, target, eps=1e-7):
    # convert target to one-hot NCHW on the same device as the output
    targ_onehot = torch.eye(output.shape[1], device=output.device)[target].permute(0,3,1,2).float()
    # convert logits to probs
    pred = F.softmax(output, dim=1)
    # sum over N, H, W (one Dice score per class for the whole batch)
    inter = (pred * targ_onehot).sum(axis=[0,2,3])
    union = (pred + targ_onehot).sum(axis=[0,2,3])
    # mean over C
    dice = (2. * inter / (union + eps)).mean()
    return 1. - dice


class DiceLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        # reduction is stored but unused: dice_loss already returns a scalar
        self.reduction = reduction

    def forward(self, output, targ):
        """
        output is NCHW logits, targ is NHW class indices (long)
        """
        return dice_loss(output, targ)

    def activation(self, output):
        return F.softmax(output, dim=1)

    def decodes(self, output):
        return output.argmax(1)
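A detail worth checking, since the first DiceLoss in this thread sums over H and W only (a score per image and class) while this version also sums over the batch: with only a tiny `eps` in the denominator, a class that is absent from an image scores 0 for that image even when it is predicted perfectly. This toy example (hand-built tensors, hypothetical `mean_dice` helper) feeds a perfect prediction through both variants.

```python
import torch
import torch.nn.functional as F

def mean_dice(probs, onehot, dims, eps=1e-7):
    """Soft Dice summed over `dims`, then averaged over whatever axes remain."""
    inter = (probs * onehot).sum(dim=dims)
    card = (probs + onehot).sum(dim=dims)
    return (2. * inter / (card + eps)).mean()

# N=2 images, C=2 classes, 1x2 pixels each. Image 1 contains no class-1 pixels.
target = torch.tensor([[[0, 1]],
                       [[0, 0]]])
onehot = F.one_hot(target, 2).permute(0, 3, 1, 2).float()
probs = onehot.clone()  # a perfect prediction

per_sample = mean_dice(probs, onehot, dims=(2, 3))      # sum over H, W only
batch_level = mean_dice(probs, onehot, dims=(0, 2, 3))  # sum over N, H, W
print(per_sample.item(), batch_level.item())
```

The per-image variant gives about 0.75 (so a loss of about 0.25 for a perfect prediction), while the batch-level variant gives about 1.0. This is one motivation for adding a smoothing constant to both numerator and denominator: an empty class predicted as empty then scores 1 instead of 0.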

Did you (or anyone else) succeed with this?

from fastai.vision.all import *

__all__ = ['DiceLoss', 'CombinedLoss']

def _one_hot(x, classes, axis=1):
    "Target mask to one hot"
    return torch.stack([torch.where(x==c, 1,0) for c in range(classes)], axis=axis)

class DiceLoss:
    "Dice loss for multi-class segmentation targets"
    def __init__(self, axis=1, smooth=1):
        store_attr()
    def __call__(self, pred, targ):
        targ = _one_hot(targ, pred.shape[1])
        pred, targ = flatten_check(self.activation(pred), targ)
        inter = (pred*targ).sum()
        union = (pred+targ).sum()
        return 1 - (2. * inter + self.smooth)/(union + self.smooth)
    
    def activation(self, x): return F.softmax(x, dim=self.axis)
    def decodes(self, x):    return x.argmax(dim=self.axis)
    

class CombinedLoss:
    "Dice and Focal combined"
    def __init__(self, axis=1, smooth=1, alpha=1):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss =  DiceLoss(axis, smooth)
        
    def __call__(self, pred, targ):
        return self.focal_loss(pred, targ) + self.alpha * self.dice_loss(pred, targ)
    
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
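A note on the DiceLoss above, worked out from the code rather than measured: because pred and targ are flattened together before summing, the Dice is pooled over all classes at once. Softmax probabilities and one-hot targets each sum to 1 per pixel, so for M pixels the union is always 2M, and the loss reduces to (with `smooth=0`) one minus the mean probability the model assigns to the correct class. This sketch reimplements the flattened computation in plain PyTorch (the name `pooled_dice_loss` is hypothetical) to check that identity.

```python
import torch
import torch.nn.functional as F

def pooled_dice_loss(logits, target, smooth=1.0):
    """Mirror of the flattened computation: sum over every pixel and class at once."""
    pred = F.softmax(logits, dim=1)
    onehot = F.one_hot(target, logits.shape[1]).permute(0, 3, 1, 2).float()
    inter = (pred * onehot).sum()
    union = (pred + onehot).sum()
    return 1 - (2. * inter + smooth) / (union + smooth)

torch.manual_seed(0)
logits = torch.randn(2, 5, 8, 8)
target = torch.randint(0, 5, (2, 8, 8))

# Probability the model assigns to the correct class at each pixel
p_correct = F.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)

print(pooled_dice_loss(logits, target, smooth=0.).item(), (1 - p_correct.mean()).item())
```

If rare classes matter, the per-class mean used in the earlier posts weights each class equally, whereas this pooled form is dominated by whichever classes cover the most pixels.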

Just bumping this up. Has anybody figured out how to use a custom Dice + Focal loss with the latest fastai?

This should still be accurate with the latest version.

You can also always use a PyTorch implementation of it; the only thing special about fastai’s version is the activation and decodes methods, and those really aren’t needed at all.
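A plain-PyTorch sketch of the Dice + Focal combination, with no fastai dependency. The focal term here is the common multi-class form (1 - p_t)^gamma * CE with gamma=2 and no per-class alpha weights; it is written by hand and is not claimed to match FocalLossFlat exactly. The class name `DiceFocalLoss` and its defaults are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceFocalLoss(nn.Module):
    """Hypothetical combined loss: focal(logits, targ) + alpha * (1 - soft Dice)."""
    def __init__(self, gamma=2.0, alpha=1.0, smooth=1.0):
        super().__init__()
        self.gamma, self.alpha, self.smooth = gamma, alpha, smooth

    def focal(self, logits, targ):
        # Per-pixel cross entropy; exp(-ce) is the probability of the true class
        ce = F.cross_entropy(logits, targ, reduction='none')
        pt = torch.exp(-ce)
        return ((1 - pt) ** self.gamma * ce).mean()

    def dice(self, logits, targ):
        # Per-class soft Dice summed over batch and spatial dims, then averaged over classes
        pred = F.softmax(logits, dim=1)
        onehot = F.one_hot(targ, logits.shape[1]).permute(0, 3, 1, 2).float()
        inter = (pred * onehot).sum(dim=(0, 2, 3))
        union = (pred + onehot).sum(dim=(0, 2, 3))
        return 1 - ((2 * inter + self.smooth) / (union + self.smooth)).mean()

    def forward(self, logits, targ):
        # logits: (N, C, H, W) float; targ: (N, H, W) long class indices
        return self.focal(logits, targ) + self.alpha * self.dice(logits, targ)

torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8, requires_grad=True)
targ = torch.randint(0, 3, (2, 8, 8))
loss = DiceFocalLoss()(logits, targ)
loss.backward()
print(loss.item())
```

Per the reply above, this should work as `loss_func` without the activation and decodes methods; the trade-off is that predictions would then come back as raw logits rather than softmax probabilities.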


Yes, I thought so as well. I just did a 2-epoch run on 15k satellite images for segmenting buildings, and it is working very nicely!