As discussed in this thread, the fastai SegmentationItemList doesn’t seem to handle multi-label (i.e. non-disjoint) masks. The model has one output per class, but an argmax is then applied to convert the prediction to a single channel for comparison with the target, which is expected to be single channel with values in range(0, n_classes).
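To make the limitation concrete, here is a small PyTorch sketch. The tensor values are made up for illustration and nothing here comes from fastai itself. With two overlapping classes, argmax keeps only one class per pixel, while a per-channel threshold keeps both.

```python
import torch

# Hypothetical logits for 2 classes on a 1x3 image (channels x height x width).
# The middle pixel belongs to both classes.
logits = torch.tensor([[[ 4.0, 3.0, -4.0]],   # class 0
                       [[-4.0, 2.0,  4.0]]])  # class 1

# Single-label view: argmax over the class dimension gives one class index per pixel.
single = logits.argmax(dim=0)             # shape (1, 3)

# Multi-label view: threshold each channel independently.
multi = (logits.sigmoid() > 0.5).float()  # shape (2, 1, 3)

print(single.tolist())  # [[0, 0, 1]]
print(multi.tolist())   # [[[1.0, 1.0, 0.0]], [[0.0, 1.0, 1.0]]]
```

The middle pixel is reported as class 0 only in the argmax result, so its membership in class 1 is lost.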
I have now enhanced the code I posted there to support correct display of multi-label images. The current code is below (I removed the parts specific to my use case, so it may not work as is):
from fastai.vision import *  # assumes fastai v1: SegmentationLabelList, open_mask, F, Tensor, ...

# Can't use cross_entropy_loss as in SegmentationLabelList
# If model has a final sigmoid layer then use the non-logit BCE instead
def bce_logits_floatify(input, target, reduction='mean'):
    return F.binary_cross_entropy_with_logits(input, target.float(), reduction=reduction)

class MultiLabelMaskList(SegmentationLabelList):
    def __init__(self, items:Iterator, classes, **kwargs):
        super().__init__(items, classes, **kwargs)
        self.loss_func = bce_logits_floatify

    def open(self, fn):
        im = open_mask(fn)
        return MultiLabelMask(im.px)

    def analyze_pred(self, pred, thresh:float=0.5):
        raise NotImplementedError()
        #return (pred.sigmoid()>thresh)[None]

    def reconstruct(self, t:Tensor): return MultiLabelMask(t)

from io import BytesIO
from matplotlib.colors import ListedColormap

class MultiLabelMask(ImageSegment):
    CMAP='tab10' # Default color map to use

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             mask_cmap:str=None, alpha:float=0.5, **kwargs):
        "Show the `MultiLabelMask` on `ax` with color map `mask_cmap`, each layer uses a color from the map."
        if ax is None: fig,ax = plt.subplots(figsize=figsize)
        n_channels = self.px.shape[0]
        cmap = ifnone(mask_cmap, self.CMAP)
        if isinstance(cmap, str): cmap = plt.cm.get_cmap(cmap)
        # One color per channel is needed, so N colors are enough for N channels
        assert cmap.N >= n_channels, f"Not enough colors in color map ({cmap.N}) for classes in data ({n_channels})."
        for i in range(n_channels):
            # Two-entry map per channel: transparent for 0, the channel's color for 1
            cm = ListedColormap(['#00000000', cmap.colors[i]])
            ax.imshow(image2np(self.px[i:i+1]), cmap=cm, alpha=alpha, **kwargs)
        if hide_axis: ax.axis('off')
        if title: ax.set_title(title)

    def _repr_image_format(self, format_str):
        with BytesIO() as str_buffer:
            fig, ax = plt.subplots()
            self.show(ax)
            fig.savefig(str_buffer, format=format_str)
            plt.close(fig) # Prevent display in jupyter
            return str_buffer.getvalue()

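The comment above bce_logits_floatify says to switch to the non-logit BCE if the model ends in a sigmoid. As a rough check of why, the sketch below (random tensors, shapes chosen only for illustration) compares the two PyTorch losses. Applying the sigmoid first and then using plain BCE should give the same value as the logits version on raw outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)            # batch x classes x H x W, raw model outputs
target = torch.randint(0, 2, (2, 3, 4, 4))  # 0/1 integer masks, one channel per class

# Model without a final sigmoid: use the logits version (as in bce_logits_floatify).
loss_logits = F.binary_cross_entropy_with_logits(logits, target.float())

# Model with a final sigmoid: its outputs are probabilities, so use plain BCE.
probs = torch.sigmoid(logits)
loss_probs = F.binary_cross_entropy(probs, target.float())

# The two losses should agree up to floating point error.
print(torch.allclose(loss_logits, loss_probs, atol=1e-5))
```

Feeding already-sigmoided outputs into the logits version would apply the sigmoid twice, which is why the choice of loss has to match the last layer of the model.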
I haven’t actually used image files, so MultiLabelMaskList.open is just off the top of my head and may not work, and it should be replaced if you are using RLE. In particular, if you are using files, make sure open_mask isn’t messing with them. Your masks should be Channel x Height x Width float tensors. They have to be floats, in spite of just being 0/1, for the transforms to work; I need to look at this and try to convert only when needed. The transform handling of ImageSegment is inherited, so it should otherwise be sensible for masks.
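If the masks come as RLE strings, something along these lines could build the Channel x Height x Width float array. This is only a sketch: the helper names rle_decode and stack_masks are made up, and the RLE convention (space-separated "start length" pairs, 1-indexed, pixels numbered down columns first) is an assumption that has to be checked against the actual data.

```python
import numpy as np

def rle_decode(rle: str, shape: tuple) -> np.ndarray:
    "Decode a 'start length start length ...' string into a (height, width) 0/1 float mask."
    # Assumed convention: 1-indexed starts, pixels numbered down columns first.
    h, w = shape
    flat = np.zeros(h * w, dtype=np.float32)
    nums = [int(x) for x in rle.split()]
    for start, length in zip(nums[0::2], nums[1::2]):
        flat[start - 1:start - 1 + length] = 1.0
    return flat.reshape((w, h)).T  # undo the column-first numbering

def stack_masks(rles, shape) -> np.ndarray:
    "Stack one decoded mask per class into a (channels, height, width) float array."
    return np.stack([rle_decode(r, shape) for r in rles])

# Two classes on a 2x3 image; the masks overlap at row 1, column 0.
masks = stack_masks(["1 2", "2 3"], shape=(2, 3))
print(masks.shape, masks.dtype)  # (2, 2, 3) float32
```

The result can be turned into a tensor with torch.from_numpy(masks) before being wrapped in MultiLabelMask.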
This works for display and training. I haven’t yet sorted out interpretation, hence the NotImplementedError in analyze_pred. The commented-out code is roughly what I think would be needed, but it is not tested.
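For reference, the thresholding idea in that commented line can be tried outside fastai as a standalone function. The name threshold_pred is hypothetical, and the sketch assumes the model outputs raw logits (no final sigmoid). The `[None]` from the commented line is left out on the assumption that the class dimension already serves as the channel dimension; that is untested inside fastai.

```python
import torch

def threshold_pred(pred: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
    "Turn raw per-class logits (channels x H x W) into a 0/1 float mask per class."
    # Hypothetical helper: assumes the model has no final sigmoid.
    return (pred.sigmoid() > thresh).float()

pred = torch.tensor([[[2.0, -1.0]], [[0.5, 3.0]]])  # 2 classes, 1x2 image
mask = threshold_pred(pred)
print(mask.tolist())  # [[[1.0, 0.0]], [[1.0, 1.0]]]
```

Returning floats keeps the result in the same format as the target masks, so it could be passed to MultiLabelMask for display.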