Can I choose not to apply a transform to my y values in fastai v2?

Occasionally it’s useful to apply a transform only to our inputs (x) and not to our outputs (y). For example, I would like to cut out squares from an image and have a model try to fill in the missing pieces.

My inputs would be the original image minus the cropped portions.

My outputs would be the original image.

I’ve created my databunch as follows, where RandomCutout randomly drops pixels from an image:

databunch = data.databunch(pascal_path/'train', 
                           item_tfms=[RandomResizedCrop(460, min_scale=0.75), RandomCutout()], 
                           batch_tfms=[*aug_transforms(size=224, max_warp=0, max_rotate=0)])
# HACK: We're predicting pixel values, so we're just going to predict an output for each RGB channel
databunch.vocab = ['R', 'G', 'B']

However, RandomCutout ends up being applied to both the x and y images, so my batches show the cutout squares in the inputs and the targets alike.

In fastai v1 we were able to restrict transforms by using the tfm_y=False parameter. Is there a corresponding approach in fastai v2?

See the tfm_y parameter in the fastai v1 docs for reference.

You need to have a different type for your input. I’d suggest creating an ImageTensorInput type, for instance, and applying RandomCutout only to that type.
Another way is to define a transform that takes the tuple as a whole and applies the change only to the first element.
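
For future readers, here is roughly what that second suggestion looks like. This is a plain-Python sketch of the idea (the names random_cutout and cutout_first_of_tuple are mine, and a 2-D list of pixel values stands in for an image); in fastai v2 itself you would subclass ItemTransform, whose encodes receives the whole (x, y) tuple.

```python
import random

def random_cutout(img, n_holes=3, length=4):
    # Zero out n_holes square patches of side `length` in a 2-D image
    # (a list of lists of pixel values); returns a modified copy
    h, w = len(img), len(img[0])
    out = [row[:] for row in img]
    for _ in range(n_holes):
        y = random.randint(0, h - length)
        x = random.randint(0, w - length)
        for r in range(y, y + length):
            for c in range(x, x + length):
                out[r][c] = 0
    return out

def cutout_first_of_tuple(item):
    # Apply the cutout to the input only; the target passes through
    # unchanged, mirroring a tuple-level ItemTransform.encodes
    x, y = item
    return random_cutout(x), y

img = [[1] * 8 for _ in range(8)]
x, y = cutout_first_of_tuple((img, img))
print(sum(map(sum, x)) < 64, y is img)  # True True
```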

Great idea, thanks!

In my case I’ve been applying the transformation to each PILImage one at a time (as opposed to a batch of tensors). So for any future readers, the changes I made were:

Introduce a new PILImage type; note that your pipeline must actually produce inputs of this type, e.g. via ImageBlock(cls=PILImageInput) if you build your data with the DataBlock API:

class PILImageInput(PILImage): pass

Make sure my transform handles that type in its encodes() method:

class RandomCutout(RandTransform):
    def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs):
        super().__init__(**kwargs)
        store_attr()

    def encodes(self, x:PILImageInput):
        # Here we accept only PILImageInput: fastai v2 dispatches encodes on
        # this type annotation, so the plain PILImage target is never touched
        px = np.array(x)
        for _ in range(np.random.randint(self.min_n_holes, self.max_n_holes)):
            h, w = np.random.randint(self.min_length, self.max_length, size=2)
            y0, x0 = np.random.randint(0, px.shape[0]-h), np.random.randint(0, px.shape[1]-w)
            px[y0:y0+h, x0:x0+w] = 0
        return PILImageInput.create(px)
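
Why does an empty subclass change anything? fastai v2's Transform dispatches encodes on the annotated type (via fastcore's TypeDispatch), so an encodes registered for PILImageInput simply never fires for a plain PILImage. The same idea can be sketched with the standard library's functools.singledispatch (toy stand-in classes, not fastai code):

```python
from functools import singledispatch

class Img: pass              # stands in for PILImage (the target)
class ImgInput(Img): pass    # stands in for PILImageInput (the input)

@singledispatch
def encodes(x):
    return "untouched"       # default: item passes through unchanged

@encodes.register
def _(x: ImgInput):
    return "cutout applied"  # fires only for the input subclass

print(encodes(Img()), encodes(ImgInput()))  # untouched cutout applied
```

Here encodes(Img()) falls through to the default, exactly the way the PILImage target does in the real pipeline.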

Then just pass this transform to item_tfms when creating a databunch:

databunch = data.databunch(images_path, 
                           item_tfms=[RandomCutout()], # Pass RandomCutout to item_tfms
                           batch_tfms=[*aug_transforms(size=160, max_warp=0, max_rotate=0)])