Image Segmentation: Image-level loss interpretation

First post! I took the DL1 course (I’ve been meaning to take DL2 but have been busy applying what I learned at work), and I’ve been putting it to use full-time on my company’s AI team. Thanks for the great course!

I’m using the lesson 3 CamVid notebook to build an automatic avatar-cropping feature for our product (users can submit hand-drawn or computer-drawn images). Currently, when necessary (e.g., for a photo of a drawing), the design team creates these masks by hand, alpha-masking the background pixels.

Luckily, they’ve saved all the submitted files and the resulting PNG outputs. I’ve used these to generate labels for this dataset and adapted the CamVid notebook to it.

Here is an example where everything works correctly (input + mask on the left, predicted segmentation on the right):

I’ve been able to train a segmentation model that achieves 94% accuracy! Interestingly, most of the loss seems to come from mismatches between the automatically generated labels and the raw inputs. The two end up misaligned when the input isn’t square, or when the input image has a lot of “padding” around the actual avatar to be selected.

example:



I’m trying to figure out a way to cull those images from my dataset to get a better idea of its real accuracy. Ideally I’d like something like the Classification Cleaner Widget from lesson 1, but even just outputting a ranked list of losses would suffice.

Additionally, this would help me catch the few cases where my auto-generated labels seem to be picking up the wrong file:

Interestingly, these badly labeled images seem sparse enough that the model still segments them correctly; however, I suspect these mislabeled images are hurting the model’s accuracy metric.
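Since the accuracy metric is what the mislabeled images drag down, a per-image pixel accuracy could also flag them: images whose labels disagree with the model’s prediction sink to the bottom of the ranking. A minimal NumPy sketch, assuming predicted and target masks are integer class arrays of shape (n_images, H, W). `per_image_accuracy` and its `void_code` argument are hypothetical names; `void_code` loosely mirrors how the CamVid notebook’s metric ignores void pixels.

```python
import numpy as np

def per_image_accuracy(pred_masks, true_masks, void_code=None):
    """Fraction of correctly classified pixels for each image (hypothetical helper).

    pred_masks, true_masks: integer arrays of shape (n_images, H, W).
    void_code: optional target label whose pixels are ignored.
    """
    pred = np.asarray(pred_masks)
    true = np.asarray(true_masks)
    if pred.shape != true.shape:
        raise ValueError("prediction and target shapes differ")
    n = pred.shape[0]
    # Flatten each image to one row of pixels
    pred = pred.reshape(n, -1)
    true = true.reshape(n, -1)
    # Pixels that count towards the score
    if void_code is None:
        valid = np.ones_like(true, dtype=bool)
    else:
        valid = true != void_code
    correct = (pred == true) & valid
    counts = valid.sum(axis=1)
    # Images with no valid pixels get NaN instead of raising a division error
    with np.errstate(invalid="ignore", divide="ignore"):
        return correct.sum(axis=1) / counts
```

Sorting with `np.argsort(acc)` puts the lowest-accuracy images first. With fastai, the predicted masks would presumably come from an `argmax` over the class dimension of the predictions; treat that as an assumption about your setup.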

My issue is that when I create an interpretation using
interp = SegmentationInterpretation.from_learner(learn)
and then run
top_losses, top_idxs = interp.top_losses()
the results seem to be at the pixel level rather than aggregated to the image level. I imagine I could write a function that sums the losses of the pixels belonging to each image, but before I do that I want to make sure I’m not missing something simpler…
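For what it’s worth, the aggregation itself is only a reshape and a mean. A minimal NumPy sketch, assuming the flat loss vector lists every pixel of image 0, then every pixel of image 1, and so on, and that all validation images share the same size (which should hold when they are resized to a fixed shape, as I believe the CamVid notebook does). `image_level_losses` is a hypothetical helper, not a fastai function. I believe fastai v1’s `Interpretation` keeps the flat per-pixel losses in `interp.losses` and that `len(learn.data.valid_ds)` gives the image count, but treat both as assumptions to check against your version.

```python
import numpy as np

def image_level_losses(pixel_losses, n_images, k=None):
    """Aggregate a flat vector of per-pixel losses into one loss per image.

    Hypothetical helper. Assumes pixel_losses is ordered image by image
    and that every image contributes the same number of pixels.
    Returns (losses, indices) sorted from highest to lowest mean loss.
    """
    per_pixel = np.asarray(pixel_losses, dtype=float).ravel()
    if per_pixel.size % n_images != 0:
        raise ValueError("pixel count is not divisible by the number of images")
    # One row per image, one column per pixel, then average each row
    per_image = per_pixel.reshape(n_images, -1).mean(axis=1)
    # Image indices ordered from worst (highest loss) to best
    order = np.argsort(-per_image, kind="stable")
    if k is not None:
        order = order[:k]
    return per_image[order], order
```

Because every image has the same number of pixels, ranking by mean or by sum gives the same order. If the losses are a PyTorch tensor, the same idea is roughly `losses.view(n_images, -1).mean(1).topk(k)`; the returned indices should then refer to positions in the validation set, assuming it is not shuffled.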
