I wanted to come up with a segmentation problem to practice the topics covered in Lesson 3. After a bit of thought I realized I had a dataset from work that could be adapted pretty easily.
I work for a company that makes games. We have about 5000 images of creatures posed in front of the egg they hatched from, which look like this:
For a significant number (~3500) of them we also have a matching “Who’s that Pokemon” style image that looks like this:
Unfortunately, we have lost the source files for many of the 1500 images that are missing the outlined version, so it’s not easy to separate the creature from its egg (or to generate a “Who’s that Pokemon” style image with the creature and its egg in different colors).
Well, this looks a lot like a segmentation problem! If I can train a model to classify each pixel of the image as “background”, “creature”, or “egg”, then I can use the generated masks to extract the creature from the background image and to create “Who’s that Pokemon” style images for the part of the dataset that is missing them. In fact, the segmentation mask itself is a pretty close approximation of what the “Who’s that Pokemon” style image would look like!
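To make that idea concrete, here is a minimal sketch of how a predicted mask could be turned into both outputs. It assumes the mask is an (H, W) NumPy array of class indices with 0 = background, 1 = creature, 2 = egg; that ordering and the palette colors are hypothetical choices for illustration, not what our actual pipeline uses.

```python
import numpy as np

# Hypothetical class indices; the real class order may differ.
BACKGROUND, CREATURE, EGG = 0, 1, 2

# Hypothetical silhouette palette: one flat RGB color per class index.
PALETTE = np.array([
    [255, 255, 255],  # background -> white
    [20, 20, 20],     # creature -> near-black
    [90, 90, 160],    # egg -> muted blue
], dtype=np.uint8)

def mask_to_silhouette(mask: np.ndarray) -> np.ndarray:
    """Turn an (H, W) array of class indices into an (H, W, 3) RGB image
    by looking up each pixel's class in the palette."""
    return PALETTE[mask]

def extract_creature(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return an RGBA copy of `image` where only creature pixels are opaque."""
    alpha = np.where(mask == CREATURE, 255, 0).astype(np.uint8)
    return np.dstack([image[..., :3], alpha])
```

Because the silhouette is just a palette lookup, giving the creature and the egg different colors only means changing two rows of `PALETTE`.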
I used the lesson3-camvid notebook as a jumping-off point, and after a full day of work I now have a model that correctly predicts about 97% of the pixels.
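For reference, “correctly predicts about 97% of the pixels” is presumably plain pixel accuracy: take the highest-scoring class at each pixel and count how often it matches the ground-truth mask. Below is a minimal NumPy sketch of that calculation. The camvid notebook’s `acc_camvid` metric does the same thing in PyTorch while skipping the Void class, which the optional `ignore_index` parameter here imitates; the function name and signature are my own, not fastai’s.

```python
from typing import Optional

import numpy as np

def pixel_accuracy(pred_logits: np.ndarray, target: np.ndarray,
                   ignore_index: Optional[int] = None) -> float:
    """pred_logits: (C, H, W) per-class scores; target: (H, W) class indices.

    Returns the fraction of (non-ignored) pixels whose argmax class
    matches the target.
    """
    pred = pred_logits.argmax(axis=0)  # (H, W) predicted class per pixel
    if ignore_index is None:
        valid = np.ones(target.shape, dtype=bool)
    else:
        valid = target != ignore_index  # drop pixels with the ignored label
    return float((pred[valid] == target[valid]).mean())
```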
Example image input:
Example output from the validation set (ground truth / model output):
You can see there are a few detail areas in the image that it doesn’t quite get right (whiskers, between the neck and the paper, around the base of the tail) but it looks pretty good in most cases!
There are also some images it really struggles with, which I think is largely because there aren’t enough similar training examples. Here’s one such example:
This particular image comes from a set of creatures that come out of balloons instead of eggs. Here is the input image the model was working from (I think the drop shadow also made things a bit hard for it):
The hardest part was getting my images to conform to the format that fastai was expecting. The original source images had a 4th channel (alpha transparency) and also had lots of transparent pixels whose RGB channels contained extraneous data that was invisible because of the alpha channel. The mask images had anti-aliasing and an alpha channel, both of which had to be removed.
I had a lot of trouble making those transformations happen on the fly (independently transforming the x and y images doesn’t seem to be supported in fastai yet), so I ended up running the transforms up front and writing the results to disk before training my model.
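Roughly, the up-front cleanup amounts to two steps, sketched below in NumPy under a few assumptions. It expects RGBA uint8 arrays already loaded from disk (for example with PIL’s `Image.open(...)` followed by `np.asarray`). Creature images are composited onto a solid background so the hidden RGB data under fully transparent pixels disappears, and each anti-aliased mask pixel is snapped to the nearest of a list of class colors, with mostly transparent mask pixels treated as background. The white background, the 128 alpha cutoff, and the class colors are all hypothetical.

```python
import numpy as np

def flatten_rgba(rgba: np.ndarray, background=(255, 255, 255)) -> np.ndarray:
    """Composite an (H, W, 4) uint8 image onto a solid background and drop alpha.

    Fully transparent pixels become exactly the background color, so any
    leftover data hidden in their RGB channels is discarded.
    """
    rgb = rgba[..., :3].astype(np.float32)
    alpha = rgba[..., 3:4].astype(np.float32) / 255.0  # (H, W, 1), in [0, 1]
    bg = np.array(background, dtype=np.float32)
    out = rgb * alpha + bg * (1.0 - alpha)
    return np.round(out).astype(np.uint8)

def rgba_mask_to_classes(mask_rgba: np.ndarray, class_colors: np.ndarray,
                         background_index: int = 0,
                         alpha_cutoff: int = 128) -> np.ndarray:
    """Convert an anti-aliased (H, W, 4) mask into an (H, W) array of class indices.

    class_colors: (C, 3) array where row i is the (hypothetical) color of class i.
    Each pixel gets the class whose color is nearest in RGB space; pixels
    with alpha below `alpha_cutoff` are assigned `background_index`.
    """
    rgb = mask_rgba[..., :3].astype(np.int32)                  # (H, W, 3)
    colors = class_colors.astype(np.int32)[None, None]         # (1, 1, C, 3)
    dists = ((rgb[..., None, :] - colors) ** 2).sum(axis=-1)   # (H, W, C)
    classes = dists.argmin(axis=-1).astype(np.uint8)
    classes[mask_rgba[..., 3] < alpha_cutoff] = background_index
    return classes
```

The outputs can then be saved with Pillow’s `Image.fromarray(...).save(...)`, which gives a 3-channel image and a single-channel index mask that a camvid-style pipeline can read.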
I may try to put together a pull request that makes adding independent custom transforms easier for people in the future (it looks like fastai v0.7 may have supported this with an is_y param passed to the transform function). But there’s a bit more to do than that, because the masks are already shrunk to a single channel before the transformations run.
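To illustrate what “independent” transforms means here: geometric transforms such as flips have to be applied identically to the image and its mask, while pixel-level ones such as brightness changes should only touch the image. The sketch below is plain NumPy; the `is_y` flag is a hypothetical stand-in for that idea, not fastai’s actual transform API.

```python
import numpy as np

def random_flip_pair(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Apply the same random horizontal flip to an image and its mask,
    so the labels stay aligned with the pixels they describe."""
    if rng.random() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]  # flip along the width axis
    return image, mask

def brighten(arr: np.ndarray, factor: float, is_y: bool = False) -> np.ndarray:
    """Hypothetical pixel-level transform: masks (is_y=True) pass through
    unchanged, since scaling class indices would corrupt the labels."""
    if is_y:
        return arr
    return np.clip(arr.astype(np.float32) * factor, 0, 255).astype(np.uint8)
```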
Edit: I finally got a chance to try it on some of the images where we’re missing the “ground truth” to see how it does.
In some cases it does really well!
Unfortunately, some of our older input images aren’t styled the same way… It turns out that one of the things the model learned is the layer style applied to our images (in all of the training data, the egg has a semitransparent white overlay on it). So on the old images without this overlay, it predicts everything as “creature”.