I'm currently working on a multi-label image segmentation problem that involves predicting masks for 43 classes. I have the following
I've defined my datablock with the following specification:
The datablock is successfully picking up the images and their respective masks, as shown here:
The input and output shapes of the dataloaders are as follows:
However, when I run the code to train the model:
I encounter a shape mismatch error during evaluation of the loss function:
I would greatly appreciate any guidance on resolving this error.
Use CrossEntropyLossFlat(axis=1). The loss should be computed over the channel dimension, which is axis 1 for the NxCxHxW output of your U-Net. By default the axis is set to -1. The last dimension is correct for standard image classification models, but not for U-Nets.
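To make the axis argument concrete, here is a minimal sketch in plain PyTorch of what computing cross-entropy over the channel axis amounts to. It is not fastai's implementation, and the tensor sizes are made up for illustration; only the 43 classes come from the question.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: batch of 2, 43 classes, 8x8 masks
n, c, h, w = 2, 43, 8, 8
logits = torch.randn(n, c, h, w)          # U-Net output: N x C x H x W
target = torch.randint(0, c, (n, h, w))   # mask of class indices: N x H x W

# The class scores live on axis 1, so move that axis last before flattening
flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, c)  # (N*H*W, C)
flat_target = target.reshape(-1)                         # (N*H*W,)
loss_flat = F.cross_entropy(flat_logits, flat_target)

# PyTorch's cross_entropy also accepts N x C x H x W directly (class axis = 1)
loss_direct = F.cross_entropy(logits, target)

print(flat_logits.shape, flat_target.shape)
```

Flattening along the last axis instead would treat the image width as the class dimension, which is where the original shape mismatch comes from.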
Thanks, Mat, for the guidance.
I tried the following code:
learn = unet_learner(dls, resnet34, loss_func=CrossEntropyLossFlat(axis=1))
but I am now getting the following CUDA assertion error:
- /opt/conda/conda-bld/pytorch_1666642969563/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
Any insight on what might be wrong?
The shapes or values of the tensors going into your loss function might not be right. You should be able to print them out using something like this for further inspection. If I had to guess, the number of output channels of your model and the target indices in your masks may be off by one.
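def loss_func(xb, yb):  # wrapper name is illustrative; pass it as loss_func to the learner
    loss_fn = CrossEntropyLossFlat(axis=1)
    # Print prediction shape, target shape, and the largest target index
    print(xb.shape, yb.shape, yb.max())
    return loss_fn(xb, yb)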
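The assertion `t >= 0 && t < n_classes` means every target index must lie in the range 0 to C-1, where C is the number of output channels. A rough way to list the offending values is sketched below; `check_targets` is a hypothetical helper, not a fastai or PyTorch function, and the tensors are toy data.

```python
import torch

def check_targets(preds, targets):
    """Return the target values that fall outside [0, n_classes).

    n_classes is taken from the channel axis (axis 1) of the predictions.
    """
    n_classes = preds.shape[1]
    values = torch.unique(targets)
    bad = values[(values < 0) | (values >= n_classes)]
    return bad.tolist()

# Toy example: 43 output channels, but one mask pixel carries code 47
preds = torch.zeros(1, 43, 4, 4)
targets = torch.zeros(1, 4, 4, dtype=torch.long)
targets[0, 1, 1] = 42   # valid: largest allowed index
targets[0, 0, 0] = 47   # invalid: 47 >= 43

print(check_targets(preds, targets))
```

An empty list means the mask values are compatible with the number of classes the model predicts.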
Thank you, Mat, for your valuable advice. The error was caused by a mismatch between the unique class codes in the masks and the vocab list: codes 27, 35, 36, 38, and 39 were missing. To resolve this, I created a dictionary that maps the class codes to a sequence of consecutive integers starting from 0, then changed label_func to replace the original codes with the mapped ones. This solution worked perfectly.
I truly appreciate your guidance and assistance.
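For anyone hitting the same issue, the remapping described above can be sketched roughly as follows. This is an illustration, not the code from my project: it assumes the original codes run from 0 to 47 with the five codes listed above unused (which leaves 43 classes), and `remap_mask` is a hypothetical helper that would be called on the mask array inside label_func.

```python
import numpy as np

# Assumption: original codes run 0..47, and these five never occur
missing = {27, 35, 36, 38, 39}
present = [code for code in range(48) if code not in missing]  # 43 codes

# Map each original code to a consecutive index 0..42
code_to_index = {code: i for i, code in enumerate(present)}

# Lookup table for vectorised remapping; entries for unused codes stay 0
lut = np.zeros(max(present) + 1, dtype=np.uint8)
for code, index in code_to_index.items():
    lut[code] = index

def remap_mask(mask):
    """Replace original codes in an integer mask array with consecutive indices."""
    return lut[mask]

mask = np.array([[0, 26, 28], [34, 37, 47]])
print(remap_mask(mask))
```

After remapping, the largest value in any mask is 42, so it stays below the 43 output channels of the model.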
No worries! Good luck on your project!