DataBlock transforms for Image Inpainting

Hey everyone,

I am trying to build a DataBlock for an image inpainting GAN (the generator, to be precise).
For image inpainting, parts of the input image are erased while the target image stays unchanged.
My DataBlock looks like this:

dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   get_y=lambda x: path/"images"/x.name,
                   splitter=RandomSplitter(),
                   item_tfms=Resize(size),
                   batch_tfms=[*aug_transforms(),
                               Normalize.from_stats(*imagenet_stats),
                               RandomErasing(p=1.)])
dblock.summary(path/"images")

A couple of questions:

  1. Is this the right approach to generate a diverse dataset, or should I create my masked images beforehand?
  2. The random erasing augmentation is currently applied to both input and output. Can I restrict it (and only this augmentation) to the input?
  3. The docs currently state that random erasing should be done after normalization. Am I right to assume this means it should come after Normalize.from_stats(*imagenet_stats) in batch_tfms?
  4. Is there a way to make RandomErasing fill the erased region with a single color (e.g. black = 0) instead of noise?

Thanks for helping me get on the right track here, and especially to @muellerzr for his notebook on GANs in fastai (2).
Thanks where thanks is due! :kissing_heart:

Hello everyone,

I think it would be helpful to give an update. I learned that transforms are applied to both x and y. But I saw this neat little trick whereby, if you give the input its own dummy class and type the transform's encodes on that class, the transformation is applied only to the input. A cutout transform then looks like this:

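import numpy as np
from PIL import Image
from fastai.vision.all import PILImage, RandTransform  # assumes fastai v2, released as `fastai`


class PILImageInput(PILImage):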
    pass


class RandomCutout(RandTransform):
    "Zero out a random number of rectangular holes in the input image"
    split_idx = None

    def __init__(self, min_n_holes=2, max_n_holes=5, min_length=0.05, max_length=0.4, **kwargs):
        super().__init__(**kwargs)
        self.min_n_holes = min_n_holes
        self.max_n_holes = max_n_holes
        self.min_length = min_length
        self.max_length = max_length

    def encodes(self, x: PILImageInput):
        """
        Note that we're accepting our dummy PILImageInput class
        fastai2 will only pass images of this type to our encoder.
        This means that our transform will only be applied to input images and won't
        be run against output images.
        """

        # np.random.randint excludes the upper bound, so add 1 to allow max_n_holes
        n_holes = np.random.randint(self.min_n_holes, self.max_n_holes + 1)
        pixels = np.array(x)  # Convert to mutable numpy array. FeelsBadMan
        h, w = pixels.shape[:2]

        for _ in range(n_holes):
            # Hole size in pixels; cast the scaled bounds to int explicitly
            h_length = np.random.randint(int(self.min_length*h), int(self.max_length*h))
            w_length = np.random.randint(int(self.min_length*w), int(self.max_length*w))
            h_y = np.random.randint(0, h)
            h_x = np.random.randint(0, w)
            y1 = int(np.clip(h_y - h_length / 2, 0, h))
            y2 = int(np.clip(h_y + h_length / 2, 0, h))
            x1 = int(np.clip(h_x - w_length / 2, 0, w))
            x2 = int(np.clip(h_x + w_length / 2, 0, w))

            pixels[y1:y2, x1:x2, :] = 0

        return Image.fromarray(pixels, mode='RGB')

(Credit: @JoshVarty)
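
To make the trick easier to see in isolation, here is a minimal sketch of the same idea using only the standard library. It is an analogy, not fastai's actual dispatch machinery: `functools.singledispatch` stands in for the type-based dispatch of `encodes`, and `Img`, `ImgInput`, `cutout` and `apply_to_pair` are hypothetical names made up for the illustration.

```python
from functools import singledispatch


class Img:
    "Stand-in for an image type (hypothetical, not fastai's PILImage)."
    def __init__(self, pixels):
        self.pixels = list(pixels)


class ImgInput(Img):
    "Dummy subclass that marks the model input, like PILImageInput above."


@singledispatch
def cutout(x):
    # Fallback: any type without a registered implementation passes through unchanged.
    return x


@cutout.register
def _(x: ImgInput):
    # Only the input type reaches this branch: zero out the first half of the pixels.
    half = len(x.pixels) // 2
    return ImgInput([0] * half + x.pixels[half:])


def apply_to_pair(pair):
    "Apply the transform to every element of an (input, target) tuple."
    return tuple(cutout(item) for item in pair)


x, y = apply_to_pair((ImgInput([1, 2, 3, 4]), Img([1, 2, 3, 4])))
print(x.pixels)  # [0, 0, 3, 4]
print(y.pixels)  # [1, 2, 3, 4]
```

The target goes through the same call as the input but comes back untouched, because no implementation is registered for its type.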

Hey everyone,

I guess I need to step up my game. This time I plan on using partial convolutions, which means the network's input is both the image with the random cutout and the generated mask.
Can anyone help me with this? Is this something that should be implemented as a RandTransform, or should I think of something else?

How would I structure my DataBlock with a random transform returning both the cutout image and the cutout mask?

What I have so far returns both the masked image and the mask from the transform (return (masked, mask)). I am stuck on combining this with the DataBlock:


dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   get_x=[get_image_files, noop],  # first input: the augmented image, second input: the mask
                   n_inp=2,
                   get_y=lambda x: x,  # essentially the same file
                   splitter=RandomSplitter(),
                   item_tfms=[Resize(512), RandomResizedCrop(255, min_scale=0.35), FlipItem(0.5), RandomMask()],
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(path/"images", bs=4, path=path)
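
For the (masked, mask) part, here is a minimal NumPy-only sketch of what such a transform body could compute. It is not tied to fastai and does not solve the DataBlock wiring. The function name `random_cutout_with_mask`, its parameters, and the mask convention (1 = pixel kept, 0 = hole) are assumptions made for the illustration.

```python
import numpy as np


def random_cutout_with_mask(pixels, n_holes=3, min_length=0.05, max_length=0.4, rng=None):
    """Return (masked, mask) for an H x W x C uint8 array.

    mask is H x W with 1 where the pixel is kept and 0 inside a hole
    (assumed convention for this sketch).
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = pixels.shape[:2]
    mask = np.ones((h, w), dtype=np.uint8)
    for _ in range(n_holes):
        # Hole size as a fraction of the image size, at least 1 pixel.
        h_len = max(1, int(rng.uniform(min_length, max_length) * h))
        w_len = max(1, int(rng.uniform(min_length, max_length) * w))
        # Top-left corner chosen so the hole stays inside the image.
        y1 = int(rng.integers(0, h - h_len + 1))
        x1 = int(rng.integers(0, w - w_len + 1))
        mask[y1:y1 + h_len, x1:x1 + w_len] = 0
    # Broadcast the 2D mask over the channel axis; the input array is not modified.
    masked = pixels * mask[..., None]
    return masked, mask


img = np.full((64, 64, 3), 255, dtype=np.uint8)
masked, mask = random_cutout_with_mask(img, rng=np.random.default_rng(0))
print(masked.shape, mask.shape)
```

Building the mask first and deriving the masked image from it keeps the two outputs consistent by construction, which is what a partial-convolution network would rely on.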