(UNet) How to get 4-channel output?

Hello fastai community,

I’m having trouble getting a Unet model to output 4-channel images rather than 1-channel outputs.

Input images are RGB (3 channel). My labels are 4 non-disjoint bitmasks, so I put 0/1 pixel labels in 4 channels.

Input image shape: (3, height, width)
Target label shape: (4, height, width)

My loss function is crashing because its input is a 1-channel tensor of shape (1, height, width), which doesn't match the target's shape of (4, height, width).

How can I get my unet to give 4-channel output to match my labels?

I’m a bit rusty on this topic but:

  • wouldn’t you just have to modify the last conv layer to output 4 channels instead of 1?

  • if using dynamic unet, does the n_classes parameter take care of that?
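To illustrate the first point: the number of output channels of a segmentation head is just the out_channels of its final (typically 1×1) convolution, so a 4-channel target needs a final conv with 4 output channels. Here is a minimal numpy sketch of a 1×1 convolution (hypothetical weights, no fastai involved) that makes the channel arithmetic explicit:

```python
import numpy as np

def conv1x1(x, weight):
    """Apply a 1x1 convolution: x is (C_in, H, W), weight is (C_out, C_in).
    Each output channel is a per-pixel linear combination of input channels."""
    return np.einsum('oi,ihw->ohw', weight, x)

features = np.random.rand(64, 8, 8)   # (C_in, H, W) feature map from the decoder
w1 = np.random.rand(1, 64)            # final conv producing 1-class output
w4 = np.random.rand(4, 64)            # final conv producing 4-class output

print(conv1x1(features, w1).shape)    # (1, 8, 8) -- mismatches a (4, H, W) target
print(conv1x1(features, w4).shape)    # (4, 8, 8) -- matches the multi-label target
```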

As discussed in this thread, the fastai SegmentationItemList doesn’t seem to handle multi-label (i.e. non-disjoint) masks. The model will have one output channel per class, but then an argmax is applied to collapse them to a single channel for comparison with the target, which is expected to be single-channel with values in range(0, n_classes).
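To make concrete why that argmax step breaks non-disjoint masks, here is a small numpy illustration (an invented example, not fastai code): once two labels are both active at a pixel, argmax can only keep one of them, while per-channel thresholding keeps both:

```python
import numpy as np

# Per-class logits for a tiny 2x2 "image" with 3 classes; at pixel (0, 0)
# both class 0 and class 1 are strongly active (overlapping labels).
logits = np.array([
    [[4.0, -2.0], [-2.0, -2.0]],   # class 0
    [[3.0,  5.0], [-2.0, -2.0]],   # class 1
    [[-2.0, -2.0], [6.0, -2.0]],   # class 2
])

# Single-label route (what SegmentationItemList effectively assumes):
single = logits.argmax(axis=0)           # shape (2, 2), one class id per pixel
print(single)                            # pixel (0,0) is forced to class 0 only

# Multi-label route: threshold each channel independently.
multi = (logits > 0).astype(np.uint8)    # shape (3, 2, 2), 0/1 per class
print(multi[:, 0, 0])                    # pixel (0,0) keeps BOTH class 0 and 1
```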
I have now enhanced the code I posted there to support correct display of multi-label images. The current code I have is below (with stuff specific to my use removed, so it may not work as is):

# Can't use cross_entropy_loss as in SegmentationLabelList
# If model has a final sigmoid layer then use the non-logit BCE instead
def bce_logits_floatify(input, target, reduction='mean'):
    return F.binary_cross_entropy_with_logits(input, target.float(), reduction=reduction)

class MultiLabelMaskList(SegmentationLabelList):
    def __init__(self, items:Iterator, classes, **kwargs):
        super().__init__(items, classes, **kwargs)
        self.loss_func = bce_logits_floatify

    def open(self, fn):
        im = open_mask(fn)
        return MultiLabelMask(im.px)

    def analyze_pred(self, pred, thresh:float=0.5):
        raise NotImplementedError()
        #return (pred.sigmoid()>thresh)[None]

    def reconstruct(self, t:Tensor): return MultiLabelMask(t)

from matplotlib.colors import ListedColormap
class MultiLabelMask(ImageSegment):
    CMAP='tab10' # Default color map to use

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             mask_cmap:str=None, alpha:float=0.5, **kwargs):
        "Show the `MultiLabelMask` on `ax` with color map `cmap`, each layer uses a color from `cmap`."
        if ax is None: fig,ax = plt.subplots(figsize=figsize)
        n_channels = self.px.shape[0]
        cmap = ifnone(mask_cmap, self.CMAP)
        if isinstance(cmap, str): cmap = plt.cm.get_cmap(cmap)
        assert cmap.N > n_channels, f"Not enough colors in color map ({cmap.N}) for classes in data ({n_channels})."
        for i in range(n_channels):
            cm = ListedColormap(['#00000000', cmap.colors[i]])
            ax.imshow(image2np(self.px[i:i+1]), cmap=cm, alpha=alpha, **kwargs)
        if hide_axis: ax.axis('off')
        if title: ax.set_title(title)
    def _repr_image_format(self, format_str):
        with BytesIO() as str_buffer:
            fig, ax = plt.subplots()
            self.show(ax=ax)  # Draw the mask; without this the saved figure is blank
            fig.savefig(str_buffer, format=format_str)
            plt.close(fig)  # Prevent display in jupyter
            return str_buffer.getvalue()

I haven’t actually used image files, so MultiLabelMaskList.open is just off the top of my head and may not work; it should be replaced if you are using RLE. In particular, if you do use files, check that open_mask isn’t altering them. Your masks should be Channel x Height x Width float tensors (they have to be floats, despite holding only 0/1 values, for the transforms to work; I need to look at this and try converting only when needed). The transform handling of ImageSegment is inherited, so it should otherwise be sensible for masks.
This works for display and training. I haven’t yet sorted out interpretation, hence the NotImplementedError in analyze_pred; the commented-out code is roughly what I think would be needed, but it is untested.
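For reference, the behaviour of bce_logits_floatify can be checked against a plain numpy re-implementation of binary cross-entropy with logits (my own sketch, not the fastai/PyTorch code): it operates elementwise, which is why the prediction and the 4-channel target must have identical shapes, and why the 0/1 target has to be float:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(logits, target, reduction='mean'):
    """Elementwise binary cross-entropy on raw logits.
    logits and target must have the same shape, e.g. (4, H, W)."""
    p = sigmoid(logits)
    loss = -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))
    return loss.mean() if reduction == 'mean' else loss

pred = np.zeros((4, 2, 2))     # logit 0 -> predicted probability 0.5 everywhere
target = np.ones((4, 2, 2))    # 4-channel 0/1 mask, as float
loss = bce_with_logits(pred, target)
print(round(float(loss), 4))   # -log(0.5) ~= 0.6931 per pixel
```

This is only a numerical sketch; the real F.binary_cross_entropy_with_logits uses a numerically stable formulation.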


Thanks, Tom! Your loss function fixed my issue. Cross entropy was flattening my channel dimension, which led to a size mismatch by a factor of num_channels.

P.S. Here’s my run-length-encoding open function for multi-channel labels, for reference:

def open(self, id_rles):
    image_id, rles = id_rles[0], id_rles[1:]
    shape = open_image(self.path/image_id).shape[-2:]
    final_mask = torch.zeros((4, *shape))
    for k, rle in enumerate(rles):
        if isinstance(rle, str):
            mask = open_mask_rle(rle, shape).px.permute(0, 2, 1)
            # Assign each label type to its own channel
            final_mask[k, :] = mask
    return ImageSegment(final_mask)

open function adapted from this kaggle kernel by Mayur Kulkarni: https://www.kaggle.com/mayurkulkarni/fastai-simple-model-0-88-lb
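As a side note, the permute(0, 2, 1) is there because Kaggle-style RLE enumerates pixels column by column (top to bottom, then left to right). A standalone numpy decoder for that convention (my own sketch, not the fastai open_mask_rle implementation) shows where the transpose comes from:

```python
import numpy as np

def decode_rle(rle, shape):
    """Decode a Kaggle-style 'start length start length ...' string into an
    (H, W) 0/1 mask. Starts are 1-based and run down columns, so we fill a
    flat buffer and reshape to (W, H) before transposing to (H, W)."""
    h, w = shape
    flat = np.zeros(h * w, dtype=np.uint8)
    nums = [int(n) for n in rle.split()]
    for start, length in zip(nums[0::2], nums[1::2]):
        flat[start - 1:start - 1 + length] = 1
    return flat.reshape(w, h).T   # column-major pixel order -> transpose

# A run of 3 pixels starting at position 1 covers the top of the FIRST COLUMN
# of a 4x2 mask, not the first row.
mask = decode_rle('1 3', (4, 2))
print(mask)
```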

That will likely cause errors when you attempt to display the masks, or at least display them incorrectly. ImageSegment doesn’t properly handle display of 4-channel masks (this doesn’t affect training): the mask is passed to matplotlib’s imshow, which interprets a 4-channel image as RGBA and will then raise an error about non-contiguous data (at least in some cases). MultiLabelMask resolves this and displays the masks clearly.

I created an implementation of multi-label masks based on your code which works properly for my use case. Look at the classes MultiLabelSegmentationLabelList and MultiLabelImageSegment in this kernel if you are interested.