Training UNet - Dice coefficient > 1 for Segmentation

Hi Team,

I am currently training a UNet for the Severstal Kaggle competition.

For the competition, I’ve chosen the Dice coefficient as the metric. When I create my learner and train it, the Dice coefficient is more often than not above 1.

The code is as follows:

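```python
from fastai.vision import *  # fastai v1: provides models, unet_learner and the built-in dice metric

arch = models.resnet34
learn = unet_learner(data, arch, metrics=[dice])  # data: the segmentation DataBunch built earlier

learn.fit_one_cycle(5, slice(lr), pct_start=0.9)  # lr: the learning rate chosen beforehand
```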

The output is:

| epoch | train_loss | valid_loss | dice | time |
|---|---|---|---|---|
| 0 | 0.105920 | 0.100106 | 0.556928 | 02:34 |
| 1 | 0.100063 | 0.094888 | 1.334559 | 02:30 |
| 2 | 0.090685 | 0.077091 | 1.150504 | 02:33 |
| 3 | 0.085975 | 0.071968 | 1.518277 | 02:31 |
| 4 | 0.066519 | 0.061476 | 1.759492 | 02:31 |

For the learner, I am using a single-channel mask with classes [0, 1, 2, 3, 4].

Why is this metric above 1? Does it need to be refactored to handle multiple classes?

Hi, I had the same problem with unet_learner, so I wrote another function for it.

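```python
from fastai.vision import *  # fastai v1: provides torch, Tensor, Rank0Tensor

def dice_(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    "Mean per-class Dice (or IoU) over the classes 1..targs.max() present in the batch."
    n_cls = int(targs.max())
    if n_cls == 0: return torch.tensor(1., device=targs.device)  # no defect pixels in this batch: avoid dividing by zero
    dice_sum = torch.zeros((), dtype=torch.float32, device=targs.device)
    targs_flatten = targs.flatten(1, -1)                 # (bs, H*W) integer class labels
    for i in range(1, n_cls + 1):
        # Channel i-1 of the output holds class i (no background channel), thresholded at 0.5.
        # The comparisons build new tensors, so `input` and `targs` are not modified in place.
        input_flatten = (input[:, i-1].flatten(1, -1) > 0.5).float()
        targ_i = (targs_flatten == i).float()
        intersect = (input_flatten * targ_i).sum(dim=1)
        union = (input_flatten + targ_i).sum(dim=1)
        if not iou: l = 2. * intersect / union
        else: l = intersect / (union - intersect + eps)
        l[union == 0.] = 1.                             # class absent in prediction and target: perfect score
        dice_sum += l.mean()
    return dice_sum / n_cls

learn.metrics = [dice_]
```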


Thank you @matlihan, I’ll digest this and see how it works.

I was also facing the same issue. After digging for a while, I found that the built-in dice metric is not correct for segmentation with more than two classes. I think it was built for binary classification of pixels, with the classes being 0 and 1: it multiplies and adds the raw predicted and target labels, so with labels up to 4 the intersection can grow faster than the union. So I made a new Dice metric by modifying the existing one.
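To see why the score can pass 1, here is a minimal sketch that applies a binary-style Dice formula (intersection = sum of label products, union = sum of label sums, which is my assumption about what the built-in metric does after `argmax`) directly to integer class labels. `binary_style_dice` is a hypothetical name, not the library function.

```python
import torch

def binary_style_dice(pred_labels: torch.Tensor, targ_labels: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: a Dice formula meant for {0, 1} masks,
    # applied unchanged to integer class labels (assumed to mirror the built-in metric).
    intersect = (pred_labels * targ_labels).sum().float()
    union = (pred_labels + targ_labels).sum().float()
    return 2. * intersect / union

# Perfect prediction on a binary mask: the score is 1, as expected.
binary_targ = torch.tensor([[0, 1, 1]])
print(binary_style_dice(binary_targ.clone(), binary_targ))  # tensor(1.)

# Perfect prediction on a multi-class mask (labels 0..4): 2 * (0 + 1 + 16) / (0 + 2 + 8) = 3.4
multi_targ = torch.tensor([[0, 1, 4]])
print(binary_style_dice(multi_targ.clone(), multi_targ))    # tensor(3.4000)
```

A pixel labelled 4 adds 16 to the intersection but only 8 to the union, so the higher defect classes push the score past 1 even when every pixel is right. Binarising the labels per class keeps every term in {0, 1}.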

Dice coefficient for multi-class segmentation:

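```python
from fastai.vision import *  # fastai v1: provides torch, Tensor, Rank0Tensor

def dice_multi(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    "Dice (or IoU) over non-background pixels; a pixel only counts as a match if its class is right."
    n = targs.shape[0]
    targs = targs.squeeze(1)                     # (bs, 1, H, W) -> (bs, H, W)
    input = input.argmax(dim=1).view(n,-1)       # predicted class per pixel
    targs = targs.view(n,-1)
    targs1 = (targs>0).float()                   # target foreground (any defect class)
    input1 = (input>0).float()                   # predicted foreground
    ss = (input == targs).float()                # pixels where the predicted class is correct
    intersect = (ss * targs1).sum(dim=1).float()
    union = (input1+targs1).sum(dim=1).float()
    if not iou: l = 2. * intersect / union
    else: l = intersect / (union-intersect+eps)
    l[union == 0.] = 1.                          # no foreground in prediction or target: perfect score
    return l.mean()
```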
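`dice_multi` pools all defect classes into one foreground score per image. If you would rather average Dice over each (image, class) pair, which I believe is closer to how the Severstal leaderboard metric is described, a per-class variant could look like the sketch below. `dice_per_class` and its `n_classes` argument are hypothetical; it uses plain PyTorch and assumes raw logits of shape `(bs, n_classes, H, W)` and integer masks of shape `(bs, 1, H, W)`.

```python
import torch

def dice_per_class(logits: torch.Tensor, targs: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical metric: mean Dice over every (image, foreground class) pair.
    logits: (bs, n_classes, H, W) raw model output; targs: (bs, 1, H, W) integer labels."""
    bs = targs.shape[0]
    preds = logits.argmax(dim=1).view(bs, -1)     # predicted class per pixel
    targs = targs.view(bs, -1)
    scores = []
    for c in range(1, n_classes):                 # skip background class 0
        p = (preds == c).float()                  # binary prediction mask for class c
        t = (targs == c).float()                  # binary target mask for class c
        intersect = (p * t).sum(dim=1)
        union = p.sum(dim=1) + t.sum(dim=1)
        d = 2. * intersect / union.clamp(min=1.)  # clamp avoids 0/0; those entries are overwritten below
        d[union == 0] = 1.                        # class absent in both prediction and target
        scores.append(d)
    return torch.stack(scores, dim=1).mean()      # average over images and classes

# Toy example: one 1x3 image, classes 0..4, target labels [0, 1, 4]
targs = torch.tensor([0, 1, 4]).view(1, 1, 1, 3)
perfect = torch.nn.functional.one_hot(targs.view(1, 1, 3), 5).permute(0, 3, 1, 2).float()
print(dice_per_class(perfect, targs, n_classes=5))  # tensor(1.)

# Predicting class 1 where the target is 4: class 1 -> 2/3, class 4 -> 0, classes 2-3 -> 1
wrong = torch.nn.functional.one_hot(torch.tensor([0, 1, 1]).view(1, 1, 3), 5).permute(0, 3, 1, 2).float()
print(dice_per_class(wrong, targs, n_classes=5))    # tensor(0.6667)
```

Like `l[union == 0.] = 1.` in the replies above, a class that is absent from both prediction and target scores 1, so images without a given defect still count towards the average.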
