Perform different augmentations on each ImageBlock

Hello everyone,

I am using fastai to create a dataset in which each sample contains two identical images and a label. I want to apply a different augmentation to each image, but my current code applies the same augmentation to both. How can I fix this? Here is my code:

from fastai.vision.all import *

aug_batch_tfms = [Contrast(p=0.9, draw=0.8)]

# get_x defaults to the file itself, so both ImageBlocks load the same image
ds = DataBlock(
    blocks=(ImageBlock, ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    batch_tfms=[*aug_batch_tfms, Normalize.from_stats(*cifar_stats)])

train_path = untar_data(URLs.CIFAR)
device = torch.device('cuda:0')

dls = ds.dataloaders(train_path/'train', bs=64, num_workers=0, device=device)

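Unfortunately, you've discovered one of the main weaknesses of fastai transforms: a random transform draws its parameters once per call and applies that same transform to every input in the tuple. This works great for image segmentation, where we want our label masks to stay aligned with the images, but not when each input should be augmented independently.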
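To make the behaviour concrete, here is a toy model in plain Python. The class and function names are hypothetical and this is not fastai's implementation; it only mimics the behaviour described above (one random draw per call, reused for every input in the tuple) and contrasts it with calling the transform once per input.

```python
import random


class ToyRandTransform:
    """Hypothetical toy model (not fastai code) of a random transform that
    draws its parameter once per call and reuses it for every tuple element."""

    def __init__(self, max_change=0.5, seed=None):
        self.max_change = max_change
        self.rng = random.Random(seed)

    def before_call(self):
        # a single draw per call, shared by everything in the tuple
        self.factor = 1 + self.rng.uniform(-self.max_change, self.max_change)

    def __call__(self, inputs):
        self.before_call()
        return tuple([v * self.factor for v in x] for x in inputs)


def augment_independently(tfm, inputs):
    """Call the transform once per input so each input gets its own draw."""
    return tuple(tfm((x,))[0] for x in inputs)


img = [0.2, 0.4, 0.6]  # stand-in for pixel values

shared = ToyRandTransform(seed=0)((img, img))
separate = augment_independently(ToyRandTransform(seed=0), (img, img))
```

Calling the transform once per input, so that each image gets its own draw, is the common idea behind both workarounds discussed in this thread.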

There is a workaround: instead of applying the transforms in the after_batch step of the fastai DataLoader, pass them to a Learner callback and apply them there. My CutMixUpAugment callback does this to support different levels of augmentation for MixUp and non-MixUp images.

Here’s some pseudocode for how you’d do it.

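from fastai.vision.all import *

class SeparateAugs(Callback):  # hypothetical name
    def __init__(self, aug_batch_tfms):
        # convert a list of augmentations into a fastai transforms Pipeline.
        # Pipeline takes the list itself (not *args); split_idx=0 marks it as the
        # training pipeline, otherwise training-only transforms return x unchanged
        self.aug_pipe = Pipeline(aug_batch_tfms, split_idx=0)

    def before_batch(self):
        if not self.training: return
        # apply the same set of augs to each batch of images, with fresh random
        # parameters on every call. fastai batches are tuples, so detuplify,
        # augment, then retuplify before reassigning them to the batch (learn.xb).
        # These run after the DataLoader's batch_tfms, so keep Normalize out of
        # batch_tfms if you use lighting augs (they expect values in [0, 1]).
        self.learn.xb = (self.aug_pipe(self.xb[0]), self.aug_pipe(self.xb[1]))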
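When building a Pipeline by hand, as in the callback above, its split_idx matters. The sketch below is a toy model in plain Python (hypothetical names, not the fastcore source) of the gating rule fastcore's Transform appears to follow: a transform whose split_idx is not None returns its input unchanged when it is called with a different split_idx. fastai's RandTransform defaults to split_idx=0 (training only), so transforms that keep that default may be silently skipped by a Pipeline created without a split_idx.

```python
class ToyTransform:
    """Hypothetical stand-in for a fastcore Transform.
    split_idx=None means "run on every split"; 0 means "training set only"."""
    split_idx = None

    def encodes(self, x):
        return x

    def __call__(self, x, split_idx=None):
        # skip unless the caller's split_idx matches, or the transform
        # is not restricted to a particular split
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.encodes(x)


class ToyPipeline:
    """Hypothetical stand-in for a fastcore Pipeline: it passes its own
    split_idx to every transform it calls."""

    def __init__(self, tfms, split_idx=None):
        self.tfms, self.split_idx = list(tfms), split_idx

    def __call__(self, x):
        for t in self.tfms:
            x = t(x, split_idx=self.split_idx)
        return x


class AddOne(ToyTransform):
    split_idx = 0  # like a training-only random augmentation

    def encodes(self, x):
        return x + 1


no_split = ToyPipeline([AddOne()])(10)                  # 10: transform skipped
train_split = ToyPipeline([AddOne()], split_idx=0)(10)  # 11: transform applied
```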

The source for CutMixUpAugment is here, if you want to look at a working example.

You could also build two pipelines with different sets of augmentations and apply one to each image.
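A minimal plain-Python sketch of the two-pipeline idea, assuming the batch is a tuple whose first two elements are the image inputs. The helper names are hypothetical and the "augmentations" are stand-in functions on numbers; in fastai the two lists would hold real transforms wrapped in two Pipeline objects.

```python
from functools import reduce


def compose(tfms):
    """Chain a list of callables into one function, applied left to right."""
    return lambda x: reduce(lambda acc, f: f(acc), tfms, x)


def two_pipeline_augment(tfms_a, tfms_b):
    """Hypothetical helper: augment the first input with tfms_a and the
    second with tfms_b, passing any remaining elements through untouched."""
    pipe_a, pipe_b = compose(tfms_a), compose(tfms_b)

    def apply(xb):
        x0, x1, *rest = xb
        return (pipe_a(x0), pipe_b(x1), *rest)

    return apply


# stand-in augmentations, just to show which input goes through which pipeline
augment = two_pipeline_augment([lambda x: x * 2], [lambda x: x + 100, lambda x: -x])
out = augment((1, 1, "cat"))  # (2, -101, 'cat')
```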


Hi Bwarner!
Thanks for your reply.

I found that wrapping these augmentations in an ItemTransform does produce the expected outcome, although it isn't very elegant. Here's the code:

from fastai.vision.all import *

@typedispatch
def show_batch(x:tuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    imgs0, imgs1 = x
    print('custom show_batch')
    line = imgs0.new_zeros(32, 5, 3) # black separator strip, 32 px tall and 5 px wide (HWC)
    for i,ctx in enumerate(ctxs):
        # place the two views side by side (HWC, concatenated along the width)
        show_image(torch.cat([imgs0[i].permute(1,2,0), line, imgs1[i].permute(1,2,0)], dim=1), ctx=ctx, title=dls.vocab[y[i].item()])

class TwoImageAugment(ItemTransform):
    def __init__(self, tfms1, tfms2):
        super().__init__()
        self.tfms1 = tfms1 # a list of transforms for the first image
        self.tfms2 = tfms2 # a list of transforms for the second image
    def encodes(self, x):  # x is the whole item: (PILImage, PILImage, label)
        img1, img2, y = x # unpack the images and the label
        img1 = ToTensor()(img1) # PIL -> uint8 TensorImage of shape (C, H, W)
        img2 = ToTensor()(img2)
        shape_img = img1.shape
        # add a batch dimension, since the GPU augmentations expect (N, C, H, W)
        img1 = img1.view(1, *shape_img)
        img2 = img2.view(1, *shape_img)
        img1 = Pipeline(self.tfms1)(img1) # apply the first list of transforms to the first image
        img2 = Pipeline(self.tfms2)(img2) # apply the second list of transforms to the second image
        img1 = img1.saturation(p=0.9, draw=torch.rand(1)).hue(p=0.9, draw=torch.rand(1)).brightness(draw=torch.rand(1)*0.5+0.3, p=0.9)
        img2 = img2.saturation(p=0.9, draw=torch.rand(1)).hue(p=0.9, draw=torch.rand(1)).brightness(draw=torch.rand(1)*0.5+0.3, p=0.9)

        img1 = img1.view(*shape_img) # drop the batch dimension again
        img2 = img2.view(*shape_img)
        # scale back to 0-255, because ImageBlock's IntToFloatTensor batch transform divides by 255
        return (img1*255, img2*255, y) # return the augmented images and the label

aug_batch_tfms = [IntToFloatTensor(),
                  RandomResizedCropGPU(size=32,
                                       min_scale=0.8,
                                       max_scale=1,
                                       ratio=(1, 1)),
                  Zoom(p=0.7, min_zoom=0.8, max_zoom=1.2),
                  Rotate(p=0.7, max_deg=60),
                  Flip(p=0.7)]

total_augment = TwoImageAugment(aug_batch_tfms, aug_batch_tfms)  # assumed: same list for both images; each image still gets its own random draws

ds = DataBlock(
    blocks=(ImageBlock, ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=total_augment,
    # batch_tfms=[Saturation(p=1., draw=0.1)]
)

train_path = untar_data(URLs.CIFAR)
device = torch.device('cuda:0')

dls = ds.dataloaders(train_path/'train', bs = 64, num_workers=32, device=device)
# one_batch = dls.one_batch()
dls.show_batch(max_n = 8, figsize = (5, 6.5))
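Why does an ItemTransform help here? Below is a toy model in plain Python (hypothetical class names, and a simplification: real fastcore transforms also dispatch encodes on the element's type). A plain Transform applied to a tuple calls encodes on each element separately, whereas an ItemTransform passes the whole tuple to encodes, which is what lets TwoImageAugment unpack img1, img2, y and treat each position differently.

```python
class ToyElementwise:
    """Hypothetical stand-in for a plain Transform: given a tuple,
    it applies encodes to each element separately."""

    def encodes(self, x):
        return x

    def __call__(self, x):
        if isinstance(x, tuple):
            return tuple(self.encodes(e) for e in x)
        return self.encodes(x)


class ToyItemTransform(ToyElementwise):
    """Hypothetical stand-in for an ItemTransform: encodes gets the whole tuple."""

    def __call__(self, x):
        return self.encodes(x)


class Tag(ToyElementwise):
    def encodes(self, x):
        return f"{x}+aug"


class TagEach(ToyItemTransform):
    def encodes(self, x):
        img1, img2, y = x  # same unpacking as TwoImageAugment
        return (f"{img1}+augA", f"{img2}+augB", y)


sample = ("img", "img", "label")
per_element = Tag()(sample)     # ('img+aug', 'img+aug', 'label+aug')
whole_item = TagEach()(sample)  # ('img+augA', 'img+augB', 'label')
```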

Interestingly, the lighting transforms, such as Brightness and Hue, did not work as expected when I added them to aug_batch_tfms, yet they do work when called as methods, for example img1.saturation(p=0.9, draw=torch.rand(1)). Do you have any insight into why they behave differently? Your expertise would be greatly appreciated.