I am doing image segmentation with a U-Net on a ResNet-34 backbone:
src_learner = unet_learner(
    src_dataloader,
    resnet34,
    n_out=256,
    loss_func=FocalLossFlat(axis=1),
    #loss_func=DiceLoss(axis=1),
    #loss_func=CrossEntropyLossFlat(axis=1),
)
CrossEntropyLossFlat (the default) and FocalLossFlat both make good progress across training epochs and produce models that behave well.
However, DiceLoss makes no real progress in training. The validation loss is frozen at a high value, and the training loss only drifts down slowly:
epoch  train_loss   valid_loss   time
0      1756.761841  1649.900269  00:04
1      1676.365967  1649.900269  00:04
2      1655.060669  1649.900269  00:04
3      1644.191162  1649.900269  00:04
4      1639.195435  1649.900269  00:04
5      1631.809692  1649.900269  00:04
The predictions generated by the model (when trained with the default loss function) are regular masks:
>>> zzz = src_learner.predict("/some/image/file")
>>> zzz_mask = zzz[0]
>>> zzz_mask.shape
torch.Size([224, 224])
>>> type(zzz_mask)
fastai.torch_core.TensorMask
A typical predicted mask may look like this:
However, the documentation suggests that DiceLoss could actually be used for segmentation. Moreover, it claims that it could be combined with FocalLossFlat:
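For orientation, a combination of this kind is usually just a weighted sum of the two terms. The following is a minimal, framework-free sketch and not the documentation's own example: `WeightedSumLoss`, `alpha`, and the two stand-in loss functions are hypothetical names used only for illustration.

```python
from typing import Callable, Sequence

# A loss here is any callable mapping (predictions, targets) to a float.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


class WeightedSumLoss:
    """Hypothetical helper: first(pred, targ) + alpha * second(pred, targ)."""

    def __init__(self, first: LossFn, second: LossFn, alpha: float = 1.0):
        self.first = first
        self.second = second
        self.alpha = alpha

    def __call__(self, pred, targ) -> float:
        return self.first(pred, targ) + self.alpha * self.second(pred, targ)


# Stand-ins for the two terms; real code would pass the actual loss objects.
def small_term(pred, targ):
    # Absolute error averaged over elements (stays small).
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)


def large_term(pred, targ):
    # Same error summed over elements (grows with the input size).
    return sum(abs(p - t) for p, t in zip(pred, targ))


combined = WeightedSumLoss(small_term, large_term, alpha=0.5)
value = combined([0.2, 0.8, 0.5, 0.5], [0.0, 1.0, 1.0, 0.0])
```

If one term is on a much larger scale than the other, it dominates the total unless `alpha` compensates, which is why the scale of each term matters.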
However, since DiceLoss does not budge at all during training, I’m not sure how this would work. Also, I guess DiceLoss would have to be used with reduction="mean" when combined with FocalLossFlat for this type of data, to keep their values comparable, since FocalLossFlat typically produces values roughly between 0 and 1 here.
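To make the scale question concrete, here is a minimal pure-Python sketch of a soft Dice loss under "sum" and "mean" reduction. It is an illustration under stated assumptions, not fastai's implementation: `soft_dice_losses` and `reduce_losses` are hypothetical helpers, the 256-class, 16-pixel input is made up, and it assumes one Dice term per class that is then summed or averaged.

```python
def soft_dice_losses(probs, target, n_classes, smooth=1e-6):
    """Return 1 - Dice for each class of a single image.

    probs:  probs[c][i] is the predicted probability of class c at pixel i
    target: target[i] is the integer class label at pixel i
    """
    losses = []
    for c in range(n_classes):
        one_hot = [1.0 if t == c else 0.0 for t in target]
        inter = sum(p * t for p, t in zip(probs[c], one_hot))
        union = sum(probs[c]) + sum(one_hot)
        dice = (2.0 * inter + smooth) / (union + smooth)
        losses.append(1.0 - dice)
    return losses


def reduce_losses(per_class, reduction):
    """Collapse the per-class losses into a single number."""
    if reduction == "sum":
        return sum(per_class)
    if reduction == "mean":
        return sum(per_class) / len(per_class)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Made-up input: 256 classes, 16 pixels, uniform predictions (an untrained
# model), and every pixel labelled as class 0.
n_classes, n_pixels = 256, 16
probs = [[1.0 / n_classes] * n_pixels for _ in range(n_classes)]
target = [0] * n_pixels

per_class = soft_dice_losses(probs, target, n_classes)
loss_sum = reduce_losses(per_class, "sum")    # close to n_classes
loss_mean = reduce_losses(per_class, "mean")  # stays within [0, 1]
```

In this sketch the summed loss sits just below the number of classes, and it would grow further if terms were also summed across the images in a batch, so values in the thousands are plausible with n_out=256. The mean-reduced loss stays between 0 and 1, which is the range needed to make it comparable to the other term.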