Semantic Segmentation runtime error at loss function

I am using a custom model for segmentation (SETRModel). The model output shape is (nBatch, 256, 256), as the code below confirms (note that the channel dimension is squeezed out). The target shape is the same (it's a PILMask).

When I start training, I get a runtime error (see below) related to the loss function. What am I doing wrong?


    size = 480
    half = (256, 256)
    splitter = FuncSplitter(lambda o: Path(o).parent.name == 'validation')

    dblock = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_relevant_images,
                   splitter=splitter,
                   get_y=get_mask, 
                   item_tfms=Resize((size,size)),
                   batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])

    dls = dblock.dataloaders(path/'images', bs=4)

    model = SETRModel(patch_size=(32, 32), 
                in_channels=3, 
                out_channels=1, 
                hidden_size=1024, 
                num_hidden_layers=8, 
                num_attention_heads=16, 
                decode_features=[512, 256, 128, 64])


    # Create a Learner using a custom model
    loss = nn.BCEWithLogitsLoss()
    learn = Learner(dls, model, loss_func=loss, lr=1.0e-4, cbs=callbacks, metrics=[Dice()])


    # Let's test and make sure the loss function is happy with its inputs
    learn.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    t1 = torch.rand(4, 3, 256, 256).to(device)
    print("input: " + str(t1.shape))

    pred = learn.model(t1).to(device)
    print("output: " + str(pred.shape))

    # prints this:
    # input: torch.Size([4, 3, 256, 256])
    # output: torch.Size([4, 256, 256])

    target = next(iter(learn.dls.train))[1]
    target = target.type(torch.float32).to(device)
    target.size(), pred.size()

    # prints this:
    # (torch.Size([4, 256, 256]), torch.Size([4, 256, 256]))

    loss(pred, target)

    # prints this:
    # TensorMask(0.6844, device='cuda:0', grad_fn=<AliasBackward>)

    # so, the loss function is happy with its inputs

    learn.fine_tune(50)

    # prints this:
    # ---------------------------------------------------------------------------
    # RuntimeError                              Traceback (most recent call last)
    # <ipython-input-114-0e514c73651a> in <module>()
    # ----> 1 learn.fine_tune(50)

    # 19 frames
    # /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
    #    2827 pixel_shuffle = _add_docstr(torch.pixel_shuffle, r"""
    #    2828 Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
    # -> 2829 tensor of shape :math:`(*, C, H \times r, W \times r)`.
    #    2830 
    #    2831 See :class:`~torch.nn.PixelShuffle` for details.

    # RuntimeError: result type Float can't be cast to the desired output type Long
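The same failure can be reproduced in plain PyTorch, without fastai or the custom model. This is a minimal sketch assuming, as with fastai's `MaskBlock` (whose batch transforms turn masks into integer tensors), that the target reaches the loss as a `torch.long` tensor during training; the random tensors only stand in for a real batch.

```python
import torch
from torch import nn

loss_fn = nn.BCEWithLogitsLoss()

# Logits shaped like the model output: (batch, height, width)
pred = torch.randn(4, 256, 256)

# Binary 0/1 mask stored as an integer (Long) tensor, like the training targets
target_long = torch.randint(0, 2, (4, 256, 256))

try:
    loss_fn(pred, target_long)
except RuntimeError as err:
    # e.g. "result type Float can't be cast to the desired output type Long"
    print("RuntimeError:", err)

# Casting the target to float makes the same call succeed
loss = loss_fn(pred, target_long.float())
print(loss.item())
```

Only the target's dtype differs between the two calls; the shapes match in both.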

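Answering my own question (hoping it helps someone else): I had to use a custom loss function. During training the masks produced by MaskBlock reach the loss as integer (Long) tensors, but BCEWithLogitsLoss needs float targets. My sanity check above passed only because I cast the target to float32 by hand before calling the loss.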

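    def my_loss(input, target):
        # Masks arrive as integer (Long) tensors; BCEWithLogitsLoss needs float targets.
        # .float() changes only the dtype, so both tensors stay on their current device (e.g. the GPU).
        return nn.BCEWithLogitsLoss()(input.float(), target.float())

    learn = Learner(dls, model, loss_func=my_loss, lr=1.0e-4, cbs=callbacks, metrics=[Dice()])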
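The same idea can also be packaged as an `nn.Module`, which keeps the cast next to the loss and can be passed as `loss_func` like a built-in loss. This is a minimal sketch: the class name `FloatTargetBCEWithLogits` is hypothetical, and casting with `target.to(device=input.device, dtype=input.dtype)` is one choice that keeps the target on the logits' device and dtype. The check at the bottom feeds an integer target straight in, as training does, rather than casting it by hand first, and confirms that gradients reach the logits.

```python
import torch
from torch import nn


class FloatTargetBCEWithLogits(nn.Module):
    """BCEWithLogitsLoss that also accepts integer (e.g. Long) mask targets.

    Hypothetical helper: the target is cast to the logits' dtype and device
    before the standard loss is computed.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # kwargs are forwarded unchanged, e.g. reduction="mean" or pos_weight=...
        self.bce = nn.BCEWithLogitsLoss(**kwargs)

    def forward(self, input, target):
        return self.bce(input, target.to(device=input.device, dtype=input.dtype))


loss_fn = FloatTargetBCEWithLogits()

# Random stand-ins for one batch: logits (N, H, W) and an integer 0/1 mask
logits = torch.randn(4, 256, 256, requires_grad=True)
mask = torch.randint(0, 2, (4, 256, 256))

loss = loss_fn(logits, mask)
loss.backward()
print(loss.item(), logits.grad.shape)

# Possible usage with the Learner from the question (not run here):
# learn = Learner(dls, model, loss_func=FloatTargetBCEWithLogits(), ...)
```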