Learner.decode_batch for unet models

I am trying to speed up inference for a unet model, based on this post.

This is what I am doing:

loaders = learner.dls.test_dl(items)
for batch in loaders:
    with learner.no_bar(), learner.no_logging(), torch.no_grad():
        # the first element of the batch holds the input images
        inputs = batch[0]
        # forward pass without gradient tracking
        masks = learner.model(inputs)
        masks = learner.dls.decode_batch((inputs, masks), max_n=len(inputs))

The resulting masks is an L (fastai's list type) returned by decode_batch, with one item per input, so its length equals the batch size.

However, every item in masks has 2 elements: a TensorImage and a TensorMask. The TensorMask has dtype torch.float32 and does not contain the expected values.


Credit to @muellerzr for answering my question.
The TensorImage holds, for each pixel, the probability that it belongs to each of the 3 channels in the image (one channel per segmentation class).
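
For anyone who lands here with the same problem, this is roughly how a per-class tensor can be turned into a class-index mask. It is only a sketch, assuming the model output has shape (batch, classes, height, width). The random tensor is a stand-in for learner.model(inputs), and logits_to_mask is a name I made up for illustration, not a fastai function. If the output holds raw scores rather than probabilities, softmax converts them, and the argmax is the same either way.

```python
import torch

def logits_to_mask(scores: torch.Tensor) -> torch.Tensor:
    """Turn per-class scores (batch, classes, H, W) into class-index masks (batch, H, W)."""
    # softmax over the channel dimension gives per-pixel class probabilities
    probs = torch.softmax(scores, dim=1)
    # the predicted class is the channel with the highest probability
    return probs.argmax(dim=1)

# Stand-in for `learner.model(inputs)`: 2 images, 3 classes, 4x4 pixels (assumed sizes)
torch.manual_seed(0)
fake_output = torch.randn(2, 3, 4, 4)

masks = logits_to_mask(fake_output)
print(masks.shape, masks.dtype)
```

The resulting tensor has integer dtype, with each pixel set to the index of its most likely class, which is what I expected from the mask in the first place.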