Unet_learner for segmentation: DiceLoss makes no progress in training

Image segmentation with Unet / ResNet.

from fastai.vision.all import *  # unet_learner, resnet34, FocalLossFlat, DiceLoss, ...

src_learner = unet_learner(
    src_dataloader,
    resnet34,
    n_out=256,
    loss_func=FocalLossFlat(axis=1),
    # loss_func=DiceLoss(axis=1),
    # loss_func=CrossEntropyLossFlat(axis=1),
)

CrossEntropyLossFlat (the default) and FocalLossFlat both make good progress across training epochs and produce models that behave well.

However, DiceLoss makes no progress at all in training. Validation loss is frozen at some high value:

epoch	train_loss	valid_loss	time
0	1756.761841	1649.900269	00:04
1	1676.365967	1649.900269	00:04
2	1655.060669	1649.900269	00:04
3	1644.191162	1649.900269	00:04
4	1639.195435	1649.900269	00:04
5	1631.809692	1649.900269	00:04

The predictions generated by the model (when trained with the default loss function) are regular masks:

>>> zzz = src_learner.predict("/some/image/file")
>>> zzz_mask = zzz[0]
>>> zzz_mask.shape
torch.Size([224, 224])
>>> type(zzz_mask)
fastai.torch_core.TensorMask

A typical predicted mask may look like this:

[image: predicted mask]

However, the documentation suggests that DiceLoss can indeed be used for segmentation, and it even claims that DiceLoss can be combined with FocalLossFlat.

However, since DiceLoss does not budge at all during training, I’m not sure how this would work. Also, I guess DiceLoss would have to be used with reduction="mean" when combined with FocalLossFlat for this type of data, to keep their values comparable, since FocalLossFlat produces values between 0 and 1.
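For what it's worth, here is a rough sketch of how I imagine the combination would look, simply adding the two losses up. CombinedFocalDiceLoss and the alpha weight are my own names, and passing reduction="mean" to DiceLoss is my assumption to keep its scale comparable to FocalLossFlat:

from fastai.vision.all import *

class CombinedFocalDiceLoss:
    "FocalLossFlat + alpha * DiceLoss (sketch, not fastai API)"
    def __init__(self, axis=1, smooth=1e-6, alpha=1.0):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis=axis, smooth=smooth, reduction="mean")

    def __call__(self, pred, targ):
        return self.focal_loss(pred, targ) + self.alpha * self.dice_loss(pred, targ)

    # fastai calls these to turn raw outputs into predictions / probabilities
    def decodes(self, x): return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)

# src_learner = unet_learner(src_dataloader, resnet34, n_out=256,
#                            loss_func=CombinedFocalDiceLoss(axis=1))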

Since the training loss is improving while the validation loss isn’t, I believe this might be a case of the model overfitting on the training data.

Admittedly this might be a naive take on my part, since I don’t really have that much experience working with the U-net architecture.

I’m not sure, but you set n_out=256 and that might be the cause? The default is n_out=None.
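You could check it with something like this, assuming the dataloaders were built for segmentation with a codes list (then the class list should be available as vocab, and n_out should match its length):

# check whether n_out matches the number of segmentation classes
print(len(src_dataloader.vocab))  # compare this with n_out=256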

I doubt that’s the case, since the model makes good progress in training if I use FocalLossFlat or CrossEntropyLossFlat.

I often see a combination of Focal Loss and Dice Loss used (but I cannot remember ever seeing Dice loss alone). Simply put, they are just added up. For example, in this Kaggle notebook:

Search for the “Loss function” headline there, as it also provides an explanation: The author claims that Dice alone makes training unstable.

I see that as well. The thing is, in my case DiceLoss does not move at all in validation, at least when used alone. I could see how using a combination of functions may push the network in a good direction, but it’s still puzzling why the validation loss is the same at all epochs with DiceLoss alone.

I will try to create a custom method and explore what’s really going on when DiceLoss is invoked from my model.
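What I mean is a thin wrapper around DiceLoss that prints what it receives on each call, roughly like this (just a sketch; the printed diagnostics are only what I'd look at first):

from fastai.vision.all import *

class DebugDiceLoss(DiceLoss):
    "DiceLoss that prints shapes and values on every call (sketch)"
    def __call__(self, pred, targ):
        loss = super().__call__(pred, targ)
        print(f"pred {tuple(pred.shape)} targ {tuple(targ.shape)} "
              f"targ classes {targ.unique().numel()} loss {loss.item():.4f}")
        return loss

# src_learner.loss_func = DebugDiceLoss(axis=1)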

I’m also thinking of recreating DiceLoss from scratch, if I have time and the previous step does not provide an explanation.
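If it comes to that, I'd start from something minimal in plain PyTorch, along these lines (assuming pred holds (N, C, H, W) logits and targ holds (N, H, W) class ids; with fastai's tensor subclasses the inputs may need casting to plain tensors first):

import torch
import torch.nn.functional as F

def dice_loss(pred, targ, smooth=1e-6, reduction="mean"):
    # softmax over the class dimension to get per-class probabilities
    probs = F.softmax(pred, dim=1)                      # (N, C, H, W)
    # one-hot encode the target mask into the same layout
    targ_1hot = F.one_hot(targ.long(), pred.shape[1])   # (N, H, W, C)
    targ_1hot = targ_1hot.permute(0, 3, 1, 2).float()   # (N, C, H, W)
    dims = (2, 3)                                       # sum over spatial dims
    inter = (probs * targ_1hot).sum(dims)
    union = probs.sum(dims) + targ_1hot.sum(dims)
    dice = (2 * inter + smooth) / (union + smooth)      # (N, C) per-class scores
    loss = 1 - dice
    return loss.mean() if reduction == "mean" else loss.sum()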