Image Segmentation - understanding inference

I am running a single prediction on a U-Net learner for multi-class segmentation, as below:

pred_class,pred_idx,outputs = learn.predict(image)

Can anybody help me understand how to relate the output probability map to the classes?

Suppose I have 11 classes. How do I label the segments in the output probability map with the class names?

How are the classes mapped to the pixel values in the output probability map?

Hi,

Typically, the classes are numbered, in your case 0 to 10. You can either produce a single-channel image whose pixel values lie only in the range 0 to 10, or produce 11 binary (gray-level) images where, in image k, a pixel value of 1 means that pixel belongs to class k.
Some of this depends on your training data. How is your training data defined?
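A minimal NumPy sketch of the two representations. It assumes the probability map has shape (n_classes, H, W), with the class dimension first; the array below is random, illustrative data, not real model output, and the helper names are hypothetical.

```python
import numpy as np

def to_label_map(probs):
    """Collapse an (n_classes, H, W) probability map into an (H, W) map of class ids."""
    # each pixel gets the id of its most probable class: 0..n_classes-1
    return probs.argmax(axis=0)

def to_one_hot(label_map, n_classes):
    """Split an (H, W) label map into n_classes binary images, one per class."""
    # masks[k] is 1 wherever the pixel was assigned class k, else 0
    return (label_map[None, :, :] == np.arange(n_classes)[:, None, None]).astype(np.uint8)

# illustrative data: 11 classes on a 4x5 image, softmax-normalized per pixel
rng = np.random.default_rng(0)
logits = rng.normal(size=(11, 4, 5))
probs = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)

label_map = to_label_map(probs)      # single-channel image, values in 0..10
masks = to_one_hot(label_map, 11)    # 11 binary images
```

Since every pixel is assigned exactly one class, the 11 binary images sum to 1 at every pixel, and taking their argmax gives back the label map.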

Thank you for looking into this.

pred_class,pred_idx,outputs = learn.predict(image)

pred_class looks like this:

How do I map the class names to the various segments?

Taking the argmax of outputs over the class dimension gives you the class id of each pixel, and learn.data.classes gives you the ordered class names for those ids (assuming you followed the lesson notebooks to create the dataset).

For example, in ll = ill.label_from_func(get_y_func, classes=codes), codes is the ordered list of classes to predict.
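Putting the two together, here is a small sketch that looks up the name of every class id present in a predicted label map. The `codes` list is a hypothetical stand-in for learn.data.classes (or the list you passed as classes=codes), and `label_map` is assumed to be an (H, W) array of class ids such as the argmax of outputs.

```python
import numpy as np

# hypothetical ordered class list; in practice use learn.data.classes
codes = ["background", "skirt", "pants", "shirt", "dress", "shoes",
         "bag", "hat", "belt", "scarf", "jacket"]

def segment_names(label_map, codes):
    """Return {class_name: pixel_count} for every class id present in label_map."""
    ids, counts = np.unique(label_map, return_counts=True)
    # position i in codes is the name of class id i
    return {codes[i]: int(n) for i, n in zip(ids, counts)}

label_map = np.array([[0, 0, 1],
                      [2, 1, 1]])
print(segment_names(label_map, codes))
# {'background': 2, 'skirt': 3, 'pants': 1}
```

The only thing tying ids to names is list position, so the order of codes must match the order used when the masks were created.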

Revisiting this question while working on the new Interpretation module. I guess you were asking for something like this:

Please feel free to give feedback so we can create a better interpretation workflow in fastai.

This is amazing, thank you for sharing. Could you share a link to this notebook?

Taking this interpretation further, I need to separate these segments and save each of them in a different label directory, e.g. a skirt directory with all the skirt segments, a pants directory with all the pants segments, etc. Thanks!
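A minimal sketch of that step, assuming you already have the input image as an (H, W, 3) array, the predicted label map as an (H, W) array of class ids, and the ordered class names. The function name and the out_dir/<class_name>/<stem> layout are hypothetical choices. It writes .npy files to stay dependency-free.

```python
from pathlib import Path
import numpy as np

def save_segments(image, label_map, codes, out_dir, stem):
    """Save one masked copy of `image` per class present in `label_map`.

    image:     (H, W, 3) array
    label_map: (H, W) array of class ids
    codes:     ordered class names; codes[i] is the name of id i
    Files go to out_dir/<class_name>/<stem>.npy; pixels of other classes are zeroed.
    """
    saved = []
    for class_id in np.unique(label_map):
        class_dir = Path(out_dir) / codes[class_id]
        class_dir.mkdir(parents=True, exist_ok=True)
        keep = (label_map == class_id)[..., None]   # (H, W, 1), broadcasts over channels
        segment = np.where(keep, image, 0)          # zero out everything not in this class
        path = class_dir / f"{stem}.npy"
        np.save(path, segment)
        saved.append(path)
    return saved
```

To get PNGs instead, you could replace the np.save call with PIL.Image.fromarray(segment.astype(np.uint8)).save(class_dir / f"{stem}.png"). You may also want to skip the background id in the loop.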

I will share it as soon as possible.

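import math
import numpy as np
import matplotlib.pyplot as plt
from fastai.vision import image2np  # fastai v1

# assumes: t = predicted mask tensor (1,H,W), classes = ordered class names, sz = figure size
cmap='tab20'
c = len(classes)
n = math.ceil(np.sqrt(c))
fig,axes=plt.subplots(1,2,figsize=(sz,sz))

#image (vmin/vmax fixed to the full id range so its colors match the legend)
axes[0].imshow(image2np(t), cmap=cmap, vmin=0, vmax=c-1)

#labels: pad with NaN so the grid is n x n even when c is not a perfect square
grid = np.array(list(range(c)) + [np.nan]*(n**2-c)).reshape(n,n)
axes[1].imshow(grid, cmap=cmap, vmin=0, vmax=c-1)
for i,l in enumerate(classes):
    div,mod=divmod(i,n)
    axes[1].text(mod, div, f"{l}", ha='center', color='white', fontdict={'size':sz})
axes[1].set_yticks([]);axes[1].set_xticks([]);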

This code will generate:

Or, to show only a chosen subset of classes, use this:

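import math
from typing import Collection
import numpy as np
import torch
import matplotlib.pyplot as plt
from fastai.vision import ImageSegment, image2np  # fastai v1

# c2i maps class name -> class id, e.g. c2i = {c:i for i,c in enumerate(codes)};
# it must be defined before this def because it is used as the default argument
def interp_show(ims:ImageSegment, classes:Collection, sz:int=20, cmap='tab20', c2i:dict=c2i):
    'show ImageSegment restricted to the given classes, next to a color legend'
    fig,axes=plt.subplots(1,2,figsize=(sz,sz))

    #image: keep pixels whose id is in `classes`; all others are set to 0,
    #so id 0 may appear in the legend even if it is not in `classes`
    mask = (torch.cat([ims.data==i for i in [c2i[c] for c in classes]])
            .max(dim=0)[0][None,:]).long()
    masked_im = image2np(ims.data*mask)
    axes[0].imshow(masked_im, cmap=cmap)

    #labels: built from the same unique values as masked_im, so imshow
    #normalizes both images identically and the colors match
    labels = list(np.unique(masked_im))
    c = len(labels); n = math.ceil(np.sqrt(c))
    label_im = labels + [np.nan]*(n**2-c)   # pad to a full n x n grid
    label_im = np.array(label_im).reshape(n,n)
    axes[1].imshow(label_im, cmap=cmap)

    i2c = {i:c for c,i in c2i.items()}      # invert: class id -> class name

    for i,l in enumerate([i2c[l] for l in labels]):
        div,mod=divmod(i,n)
        axes[1].text(mod, div, f"{l}", ha='center', color='white', fontdict={'size':sz})
    axes[1].set_yticks([]);axes[1].set_xticks([]);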

Note: the tricky part is making the mask image and the label image contain the same unique values; otherwise imshow normalizes each one differently and the cmap assigns different colors. Also, the cmap has a fixed number of colors and values are scaled into that range, so if len(classes) > 20, classes with close ids, say 4 and 5, will be mapped to the same color.
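A different way around the normalization issue, sketched here as an alternative to the unique-values trick rather than what interp_show does: skip imshow's automatic scaling entirely and look each class id up in a fixed palette, so a given id always gets the same RGB color regardless of which other ids are present. This assumes integer class ids and a listed (qualitative) colormap such as `tab20`, whose colors are exposed as `.colors`. The helper name is hypothetical.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; unnecessary in a notebook
import matplotlib.pyplot as plt

def colorize(label_map, cmap_name="tab20"):
    """Map an integer array of class ids to RGB using a fixed palette.

    The color depends only on the id, never on the range of ids present,
    so a mask image and a legend image colored this way always agree.
    """
    palette = np.asarray(plt.get_cmap(cmap_name).colors)  # (20, 3) for tab20
    # ids >= len(palette) wrap around, i.e. they reuse earlier colors
    return palette[np.asarray(label_map) % len(palette)]

mask = np.array([[0, 3], [3, 10]])
legend = np.array([[3, 10]])
rgb_mask, rgb_legend = colorize(mask), colorize(legend)
```

The result is an RGB array, so it can be passed straight to imshow without a cmap, e.g. axes[0].imshow(colorize(label_map)), and the legend cells can be colored with the same function.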

Agreed, I'm finding it tricky to keep the cmap values consistent between the mask image and the label image. My colors are not coming out correctly.

With the code I shared, the colors should be mapped correctly.

Another note: the largest qualitative colormaps in matplotlib (tab20, tab20b, tab20c) have 20 colors. So if you have 21 classes, at least 2 classes will be mapped to the same color.
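If you need more than 20 colors, one option (my assumption, not something the interpretation code does) is to concatenate several qualitative colormaps. tab20, tab20b and tab20c together give 60 colors, though some of them are similar, so neighbouring classes may still be hard to tell apart. The helper name is hypothetical.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; unnecessary in a notebook
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def big_palette(names=("tab20", "tab20b", "tab20c")):
    """Stack the colors of several listed colormaps into one (N, 3) palette."""
    return np.concatenate([np.asarray(plt.get_cmap(n).colors) for n in names])

palette = big_palette()            # 60 RGB colors
cmap60 = ListedColormap(palette)   # can be passed as cmap=... to imshow

# coloring ids directly by indexing avoids any vmin/vmax scaling
label_map = np.array([[0, 25], [59, 3]])
rgb = palette[label_map]           # (2, 2, 3)
```

If you use cmap60 with imshow instead, pass vmin=0 and vmax=len(palette)-1 so each integer id lands on its own color.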

Here is a notebook showing the upcoming SegmentationInterpretation. It is still in the PR process: https://github.com/fastai/fastai/pull/2115.

In the function interp_show(), am I right that ImageSegment(t) is the predicted image and ImageSegment(interp.y_true[457]) is the ground-truth image?

Hi @kcturgutlu, nice to read your post. I've spent some time trying to figure out the cmap issue, but I don't understand how you handle it in your code. Could you explain more? Thank you!