Image Segmentation Shape Mismatch Error during Loss Evaluation

Hi,

I’m currently working on a multi-label image segmentation problem that involves predicting masks for 43 classes. I have the following vocab:

I’ve defined my DataBlock with the following specification:

The DataBlock successfully picks up the images and their respective masks, as shown here:

The input and output shapes of the dataloaders are as follows:

However, when I run the code to train the model:

I encounter a shape mismatch error during evaluation of the loss function:

I would greatly appreciate any guidance on resolving this error.
Thank you.
Bilal

Try CrossEntropyLossFlat(axis=1). This loss function should be computed over the channel (class) dimension, which is dimension 1 for your U-Net output of shape NxCxHxW. By default, axis is set to -1 (the last dimension). That is correct for standard image classification models, whose output has shape NxC, but not for U-Nets.
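To make the axis argument concrete, here is a minimal NumPy sketch of a "flattened" cross-entropy. It illustrates the idea only and is not fastai's actual source. The class axis is moved to the end, all other axes are flattened into rows, and the usual cross-entropy is averaged over every pixel. The helper flat_cross_entropy and the toy shapes are hypothetical. With axis=1 the 43 channels are treated as the classes. With axis=-1 the image width would be treated as the class dimension instead, so the number of flattened predictions no longer matches the number of target pixels.

```python
import numpy as np

def flat_cross_entropy(logits, targets, axis=-1):
    """Mean cross-entropy with class scores on `axis` (hypothetical helper)."""
    # Move the class axis to the end, then flatten every other axis into rows.
    moved = np.moveaxis(logits, axis, -1)
    n_classes = moved.shape[-1]
    scores = moved.reshape(-1, n_classes)
    labels = targets.reshape(-1)
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(f"{scores.shape[0]} predictions vs {labels.shape[0]} targets")
    # Numerically stable log-softmax over the class axis.
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Average the negative log-probability of each pixel's target class.
    return -log_probs[np.arange(labels.shape[0]), labels].mean()

# Toy U-Net-style output N x C x H x W, with integer masks of shape N x H x W.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 43, 8, 8))
masks = rng.integers(0, 43, size=(2, 8, 8))

print(flat_cross_entropy(logits, masks, axis=1))  # channels as classes: works
try:
    flat_cross_entropy(logits, masks, axis=-1)     # width (8) taken as classes
except ValueError as err:
    print("shape mismatch:", err)
```

The mismatch raised for axis=-1 mirrors the shape error in the original post: the predictions get flattened along the wrong dimension before they are compared with the masks.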

Unet example:

Thanks, Mat, for the guidance.

I tried the following code:

from fastai.vision.all import *
learn = unet_learner(dls, resnet34, loss_func=CrossEntropyLossFlat(axis=1))
learn.fine_tune(10)

but now I’m getting the following CUDA assertion error:

/opt/conda/conda-bld/pytorch_1666642969563/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion t >= 0 && t < n_classes failed.

Any insight on what might be wrong?

The shapes or values of the tensors going into your loss function might not be what you expect. You can print them out with something like the following for further inspection. If I had to guess, the number of output channels of your model and the target indexes may be off by one.

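from fastai.vision.all import *

loss_fn = CrossEntropyLossFlat(axis=1)
def loss_wrapped(xb, yb):
    # xb: N x C x H x W predictions; yb: N x H x W targets; yb.max() must be < C
    print(xb.shape, yb.shape, yb.max())
    return loss_fn(xb, yb)

learn = unet_learner(dls, resnet34, loss_func=loss_wrapped)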
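The assertion t >= 0 && t < n_classes means that at least one target value in the masks is not a valid class index for a model with n_classes output channels. Because that CUDA message is hard to read, it can help to scan the mask values on the CPU before training. Below is a minimal NumPy sketch; out_of_range_codes is a hypothetical helper and the toy mask is made up for illustration.

```python
import numpy as np

def out_of_range_codes(masks, n_classes):
    """Return the sorted unique mask values that violate 0 <= t < n_classes."""
    codes = np.unique(np.asarray(masks))
    return codes[(codes < 0) | (codes >= n_classes)]

# Hypothetical example: a model with 43 output channels, but a mask contains code 47.
n_classes = 43
mask = np.array([[0, 5, 42],
                 [47, 12, 3]])
print(out_of_range_codes(mask, n_classes))  # [47]
```

An empty result for every mask in the dataset means the targets are compatible with the number of classes in the vocab.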

Thank you, Mat, for your valuable advice. The error was caused by missing class codes: the unique codes in the masks did not form a consecutive sequence matching the vocab list, because codes 27, 35, 36, 38, and 39 were missing. To resolve this, I created a dictionary that maps the original codes to a sequence of integers starting from 0, and then changed the label_func to replace the original codes with the mapped ones. This solution worked perfectly.
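A minimal sketch of the remapping described above. It assumes, purely for illustration, that the original codes ran from 0 to 47 with the five gaps listed, which leaves 43 codes. build_code_map and remap_mask are hypothetical names. In the actual DataBlock this logic would live in whatever function loads each mask, which is not shown here.

```python
import numpy as np

def build_code_map(codes):
    """Map each original class code to a consecutive index starting at 0."""
    return {code: idx for idx, code in enumerate(sorted(codes))}

def remap_mask(mask, code_map):
    """Replace the original codes in an integer mask with their consecutive indices."""
    mask = np.asarray(mask)
    # Lookup table indexed by original code; -1 marks codes with no mapping.
    lut = np.full(max(code_map) + 1, -1, dtype=np.int64)
    for code, idx in code_map.items():
        lut[code] = idx
    if mask.min() < 0 or mask.max() >= lut.size or (lut[mask] < 0).any():
        raise ValueError("mask contains codes that are not in the code map")
    return lut[mask]

# Hypothetical original codes: 0..47 without the gaps reported in the thread.
codes = [c for c in range(48) if c not in {27, 35, 36, 38, 39}]
code_map = build_code_map(codes)
print(len(code_map))                                # 43
print(remap_mask([[0, 28, 47]], code_map).tolist())  # [[0, 27, 42]]
```

A lookup table keeps the remapping vectorized over the whole mask instead of looping over pixels. The class names passed as codes to MaskBlock would then need to be listed in the same new index order.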

I truly appreciate your guidance and assistance.

Best regards,
Bilal

No worries! Good luck on your project!