Specifying n_out of unet_learner for multi-target mask + category prediction

Hi everyone,

I’m working on a multi-target learner and have a specific data block setup that includes both masks and categories. Here’s a visualisation of my data block:

The DataLoaders work perfectly: when I call show_batch(), the images are displayed with overlaid masks and category titles.

My question is related to using unet_learner() with this setup. The MaskBlock outputs a 26-channel mask with dimensions 512x512, and the CategoryBlock outputs a single class out of 11 possible classes.

How should I define the n_out parameter for unet_learner() in this scenario? Should I set the output tensor size to (26x512x512) + 11? I’m concerned that this will result in a very long tensor. Alternatively, is it possible to specify two heads for unet_learner() so that it can yield both masks and categories separately?
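For reference: in fastai, n_out sets the number of output *channels* of the UNet's final layer, not the length of a flattened output. So n_out=26 gives a bs x 26 x 512 x 512 tensor, and the n_out=37 experiment later in this thread gives bs x 37 x 512 x 512. A category needs a different kind of output (one vector per image), which is why a separate head fits it naturally. Below is a minimal plain-PyTorch sketch of the two-head idea on a toy encoder-decoder. `TwoHeadNet` and all layer sizes are hypothetical and not part of fastai.

```python
import torch
import torch.nn as nn

class TwoHeadNet(nn.Module):
    """Toy encoder-decoder with a per-pixel mask head and a per-image class head (hypothetical)."""
    def __init__(self, n_mask_classes=26, n_view_classes=11, width=16):
        super().__init__()
        # Encoder: downsamples the image 4x and produces bottleneck features
        self.encoder = nn.Sequential(
            nn.Conv2d(3, width, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(width, width * 2, 3, stride=2, padding=1), nn.ReLU(),
        )
        # Decoder: upsamples back to the input resolution
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(width * 2, width, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(width, width, 2, stride=2), nn.ReLU(),
        )
        # Segmentation head: one logit channel per mask class, per pixel
        self.mask_head = nn.Conv2d(width, n_mask_classes, kernel_size=1)
        # Classification head: pools bottleneck features to one vector per image
        self.view_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(width * 2, n_view_classes)
        )

    def forward(self, x):
        feats = self.encoder(x)
        mask_logits = self.mask_head(self.decoder(feats))  # (bs, 26, H, W)
        view_logits = self.view_head(feats)                # (bs, 11)
        return mask_logits, view_logits

model = TwoHeadNet()
x = torch.randn(2, 3, 64, 64)
mask_logits, view_logits = model(x)
print(mask_logits.shape, view_logits.shape)
# torch.Size([2, 26, 64, 64]) torch.Size([2, 11])
```

With a tuple output like this, the loss function receives `(mask_logits, view_logits)` as its predictions and unpacks them, instead of slicing channels out of one large tensor.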

Any guidance or insights you can provide would be greatly appreciated.

Thank you in advance for your help!

Kind regards,
Bilal


I tried the following code to customise the UNet for multi-target learning, but it seems we can’t modify the DynamicUnet like that:

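from copy import deepcopy

m = learn.model                      # the trained DynamicUnet
avc = ArteryAndViewClassifier(m)     # custom wrapper meant to add mask + view heads
learn2 = deepcopy(learn)
learn2.model = avc
preds, targs = learn2.get_preds(dl=learn2.dls.valid)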

When I run the above code, it raises an error saying that DynamicUnet doesn’t allow changing layers, which makes sense, as it constructs the decoder automatically based on the encoder layers.

Any suggestions on where to make the change, and how to add two heads as the final layers of the DynamicUnet so that it can predict the mask and the category at once?
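One way to get a second head without touching the DynamicUnet's layers is to wrap the whole model instead: register a forward hook on its encoder to capture the bottleneck features during the normal forward pass, then feed them to a small classification head. This sketch assumes you can reach the encoder as a submodule and know its output channel count. In fastai's DynamicUnet the encoder is usually `model[0]`, but verify that for your version. `UnetWithViewHead` is a hypothetical name, and the toy `seg_model` only stands in for the real UNet.

```python
import torch
import torch.nn as nn

class UnetWithViewHead(nn.Module):
    """Wraps a segmentation model and adds a classification head on its encoder output (hypothetical)."""
    def __init__(self, seg_model, encoder, enc_channels, n_view_classes=11):
        super().__init__()
        self.seg_model = seg_model
        self.view_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(enc_channels, n_view_classes)
        )
        self._feats = None
        # The hook stores the encoder's output every time the segmentation model runs
        self._hook = encoder.register_forward_hook(self._store_feats)

    def _store_feats(self, module, inputs, output):
        self._feats = output

    def forward(self, x):
        mask_logits = self.seg_model(x)            # runs the encoder once, filling self._feats
        view_logits = self.view_head(self._feats)  # (bs, n_view_classes)
        return mask_logits, view_logits

# Toy stand-in for a UNet: an encoder followed by upsampling and a 26-channel output layer
encoder = nn.Sequential(nn.Conv2d(3, 32, 3, stride=4, padding=1), nn.ReLU())
seg_model = nn.Sequential(
    encoder,
    nn.Upsample(scale_factor=4),
    nn.Conv2d(32, 26, kernel_size=1),
)
model = UnetWithViewHead(seg_model, encoder, enc_channels=32)
mask_logits, view_logits = model(torch.randn(2, 3, 64, 64))
```

Because only the wrapper is new, the automatically constructed decoder stays as it is; the loss function then receives a `(mask_logits, view_logits)` tuple.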

I am not fully sure I understood your setup correctly, but I don’t see how the UNET would give you categories at all. For multi-target learning there is a good example in Jeremy’s paddy disease Kaggle notebooks; search for “multi-target” there.

Thank you, @Archaeologist, for your response. Allow me to provide a brief overview of the task at hand. Currently, I am working on predicting coronary artery segments from angiograms using fastai’s unet_learner, and it has been performing well. While Jeremy’s videos on multi-target learning have been incredibly informative, my use case differs slightly: it involves image segmentation rather than classification.

The main idea behind multi-target learning is to improve the model’s predictions across all things it is predicting. Therefore, I decided to explore its potential in our project to enhance segmentation results. To achieve this, I obtained the angles corresponding to the angiogram views and aimed to incorporate them into the unet_learner to predict masks and categories simultaneously. However, it seems that this task is not as straightforward as anticipated, as DynamicUnet doesn’t allow modifications like that. I attempted this by creating a separate class ArteryAndViewClassifier() to configure two heads, but unfortunately, it didn’t work as expected.

My latest attempt uses the n_out parameter of unet_learner. By setting it to 37 (26 segmentation classes + 11 view classes), I obtained a bs x 37 x 512 x 512 output, which allowed me to replicate the model’s segmentation part through custom loss and metric functions that use only the first 26 channels of the UNET’s output. However, I am currently struggling with handling the remaining 11 channels in the cross-entropy calculation. My current approach sums each of those channels over all pixels to obtain one output per class, then applies sqrt() followed by argmax() to choose one of the 11 outputs.
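A quick way to see why the sum, sqrt(), argmax() reduction is problematic as a training signal: raw logits can be negative, so sqrt() of their sum can be NaN, and argmax() returns integer class indices (`torch.int64`) that carry no gradient. This sketch uses random tensors at a reduced spatial size purely to inspect shapes and dtypes; the variable names are illustrative.

```python
import torch

bs, n_mask, n_view, H, W = 2, 26, 11, 32, 32
# Stand-in for the 37-channel UNet output
preds = torch.randn(bs, n_mask + n_view, H, W, requires_grad=True)

seg_logits = preds[:, :n_mask]   # (bs, 26, H, W): used for the mask loss
view_maps = preds[:, n_mask:]    # (bs, 11, H, W): the extra "view" channels

summed = view_maps.sum(dim=(2, 3))      # (bs, 11); entries can be negative
rooted = torch.sqrt(summed)             # negative sums become NaN
view_idx = torch.argmax(rooted, dim=1)  # (bs,) integer indices, detached from the graph

print(seg_logits.shape, summed.shape, view_idx.dtype, view_idx.requires_grad)
```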

Here is the code snippet for the custom loss functions and the combo_loss:

def custom_focal_loss(preds, msk_targs, vu_targs, **kwargs):
    return FocalLossFlat(axis=1)(preds[:, :26], msk_targs, **kwargs)

def custom_celf_loss(preds, msk_targs, vu_targs, **kwargs):
    cls_logits = preds[:, 26:]
    cls_preds = torch.argmax(torch.sqrt(torch.sum(cls_logits, dim=(2, 3))), dim=1)[:, None]
    return CrossEntropyLossFlat(reduction='mean')(cls_preds, vu_targs, **kwargs)

def combo_loss(preds, msk_targs, vu_targs, **kwargs):
    alpha = 0.5
    return (alpha * custom_focal_loss(preds, msk_targs, vu_targs, **kwargs)) + (
                (1 - alpha) * custom_celf_loss(preds, msk_targs, vu_targs, **kwargs))

However, I have encountered an error: RuntimeError: “log_softmax_lastdim_kernel_impl” not implemented for ‘Long’. Does this strategy of squeezing the last 11 channels into one prediction make sense? Any suggestions for resolving this error or implementing a better approach would be welcome.
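The error is consistent with the code above: torch.argmax() returns integer (Long) indices, while cross-entropy applies log_softmax to its input, which must be floating-point logits of shape (bs, n_classes). A minimal sketch of a fix, in plain PyTorch rather than fastai's CrossEntropyLossFlat, pools the 11 view channels into per-image logits and skips sqrt()/argmax() entirely. `view_loss_from_channels` is a hypothetical helper, and mean pooling is only one possible reduction.

```python
import torch
import torch.nn.functional as F

def view_loss_from_channels(preds, vu_targs, n_mask=26):
    """Cross-entropy on the view channels of a (bs, n_mask + n_view, H, W) output (hypothetical helper)."""
    view_maps = preds[:, n_mask:]             # (bs, n_view, H, W), float logits
    view_logits = view_maps.mean(dim=(2, 3))  # (bs, n_view), still float and differentiable
    # .long().view(-1) accepts targets shaped (bs,) or (bs, 1)
    return F.cross_entropy(view_logits, vu_targs.long().view(-1))

preds = torch.randn(2, 37, 32, 32, requires_grad=True)
vu_targs = torch.tensor([3, 7])  # one integer view class per image
loss = view_loss_from_channels(preds, vu_targs)
loss.backward()
```

Only the 11 view channels receive gradient from this loss; the 26 mask channels are trained by the segmentation term.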

Thank you again for your valuable help and support!

Understood. This is a complex topic and I do not have an answer for you, unfortunately. What comes to mind is that Mask R-CNN takes a similar approach in some sense: its output is a segmentation mask and a classification result. Note that its objectives are different, but perhaps its code could give you some ideas for your own implementation.

The only other idea I had was moving away from the fastai UNET and trying a different library such as Segmentation Models Pytorch. Perhaps it offers more flexibility? fastai-unet-starter📕 | Kaggle

Thank you, @Archaeologist, for taking the time to point me to Mask R-CNN and the Segmentation Models Pytorch library. I am glad to see the example of how to use other models with fastai for training segmentation models.

I have made some adjustments to the Cross Entropy Loss part of my ComboLoss. Here’s the refined implementation:

# Custom Focal Loss
def segment_focal_loss(preds, msk_targs, vu_targs, **kwargs):
    # Extract the logits for the segmentation task (26 classes)
    seg_logits = preds[:, :26]
    return FocalLossFlat(axis=1)(seg_logits, msk_targs, **kwargs)

# Custom Cross Entropy Loss for the view classification task
def view_cross_entropy_loss(preds, msk_targs, vu_targs, **kwargs):
    # Extract the logits for the view classification task (remaining 11 channels);
    # preds is already on the model's device, so no .cuda() call is needed
    cls_logits = preds[:, 26:]

    # Reduce spatial dimensions (2D) to get a 1D tensor for each sample
    cls_preds = torch.mean(cls_logits, dim=(2, 3)).float()

    return CrossEntropyLossFlat(reduction='mean')(cls_preds, vu_targs, **kwargs)

# Combined Loss (Focal Loss + Cross Entropy Loss)
def combo_loss(preds, msk_targs, vu_targs, **kwargs):
    # Weight factor for the focal loss term (adjust as needed)
    alpha = 0.6

    # Calculate the combined loss using a weighted sum of the two losses
    return (alpha * segment_focal_loss(preds, msk_targs, vu_targs, **kwargs)) + (
        (1 - alpha) * view_cross_entropy_loss(preds, msk_targs, vu_targs, **kwargs)
    )

The model has started to train, and the initial results seem promising. However, I’m still facing some challenges with the Cross Entropy Loss part. My approach of collapsing each 512x512 view channel into a single score by taking the mean of its values, hoping this would guide the segmentation model to predict class scores, is not working very well.

The View Cross Entropy Loss is raising the overall loss, which makes it difficult for the model to learn both tasks simultaneously. I’d appreciate suggestions on how to improve the Cross Entropy Loss logic.

Again, this is not my specialty, but if you take the mean, you are weighting all pixel classes equally, including the background class, which you are probably not interested in. I wonder what happens if you give more weight to the classes you are interested in?
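One way to try that suggestion, sketched under assumptions: weight each pixel of the view channels by how likely it is to belong to a non-background segmentation class, so the pooled view score is dominated by artery pixels rather than background. This assumes mask class 0 is background, and `weighted_view_logits` is a hypothetical helper. Detaching the weights keeps the view loss from pushing the mask predictions around, which may or may not be what you want.

```python
import torch
import torch.nn.functional as F

def weighted_view_logits(preds, n_mask=26, background_idx=0, eps=1e-6):
    """Pool view channels with per-pixel weights = predicted foreground probability (hypothetical)."""
    seg_probs = preds[:, :n_mask].softmax(dim=1)            # (bs, 26, H, W)
    fg_weight = 1.0 - seg_probs[:, background_idx]          # (bs, H, W); near 1 = likely artery
    fg_weight = fg_weight.detach().unsqueeze(1)             # (bs, 1, H, W); no gradient into mask channels
    view_maps = preds[:, n_mask:]                           # (bs, n_view, H, W)
    weighted_sum = (view_maps * fg_weight).sum(dim=(2, 3))  # (bs, n_view)
    return weighted_sum / (fg_weight.sum(dim=(2, 3)) + eps) # weighted mean per view channel

preds = torch.randn(2, 37, 32, 32, requires_grad=True)
view_logits = weighted_view_logits(preds)
loss = F.cross_entropy(view_logits, torch.tensor([1, 4]))
loss.backward()
```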


Great observation! Taking the mean of the channels reserved for the class labels might not be the most suitable way to predict the view of the angiogram, as it condenses the information so much that it becomes hard for the DynamicUnet model to guide parameter updates effectively. I am considering another direction: using two independent models, one for segmentation and another for classification, and combining them through a combiner model. The role of the combiner model would be to fine-tune both models for a few epochs so that they learn from each other, using a combo loss to guide the parameter updates of each individual model within the combiner class. I am not sure whether this will improve the models. While I do not need the classification and combiner models themselves, I can use them to enhance the performance of the segmentation model through multi-target learning.
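A minimal sketch of the combiner idea, under assumptions the thread doesn't spell out: `Combiner` (hypothetical) holds a segmentation model and a classification model, and, so that the two can actually influence each other, the classifier here sees the image concatenated with the softmaxed mask prediction. With that wiring, one weighted loss sends view-loss gradients back into the segmentation model too. The toy models are stand-ins, and plain cross-entropy replaces the focal loss for the mask term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Combiner(nn.Module):
    """Joint wrapper: classifier input = image + predicted mask probabilities (hypothetical wiring)."""
    def __init__(self, seg_model, cls_model):
        super().__init__()
        self.seg_model = seg_model
        self.cls_model = cls_model

    def forward(self, x):
        mask_logits = self.seg_model(x)                             # (bs, 26, H, W)
        cls_in = torch.cat([x, mask_logits.softmax(dim=1)], dim=1)  # (bs, 3 + 26, H, W)
        view_logits = self.cls_model(cls_in)                        # (bs, 11)
        return mask_logits, view_logits

def combiner_loss(preds, msk_targs, vu_targs, alpha=0.6):
    mask_logits, view_logits = preds
    seg_loss = F.cross_entropy(mask_logits, msk_targs)  # stand-in for the focal loss
    view_loss = F.cross_entropy(view_logits, vu_targs)
    return alpha * seg_loss + (1 - alpha) * view_loss

seg_model = nn.Conv2d(3, 26, kernel_size=3, padding=1)  # toy segmentation model
cls_model = nn.Sequential(                               # toy classifier over 3 + 26 channels
    nn.Conv2d(3 + 26, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 11),
)
model = Combiner(seg_model, cls_model)

x = torch.randn(2, 3, 32, 32)
msk_targs = torch.randint(0, 26, (2, 32, 32))
vu_targs = torch.tensor([0, 5])
loss = combiner_loss(model(x), msk_targs, vu_targs)
loss.backward()
```

If the two models only shared a summed loss but never exchanged features, each would receive gradients from its own term alone; some connection like the one above is what lets the view task shape the segmentation model.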

The combiner model technique worked well, enabling me to fine-tune both the segmentation and classification models simultaneously. Although the model isn’t perfect, it has made significant progress on the problem of predicting incorrect segments for specific views.
