[solved] How to create a generic Augmentation Transform that handles training aug and validation aug

I am trying to wrap my head around the mid-level API, specifically creating a custom RandTransform. This is the code given in the docs for creating a custom one:

import numpy as np
from fastai.vision.all import *

class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()

    def before_call(self, b, split_idx):
        self.idx = split_idx

    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)
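
For context, this is roughly how such a transform gets wired in; a minimal sketch for a classification setup, where `path`, `get_train_aug`, and `get_valid_aug` are placeholders (the aug functions would return albumentations.Compose pipelines like the ones further down). fastai passes the current split_idx into before_call, which is how the right branch gets picked:

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(),
    item_tfms=[Resize(256), AlbumentationsTransform(get_train_aug(), get_valid_aug())],
)
dls = dblock.dataloaders(path, bs=8)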

How would I modify it to work for segmentation? This particular encodes only takes in a PILImage. I have tried creating another encodes that takes in a PILMask, but Albumentations expects us to pass in the image and mask at the same time. Something like:

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

But this is now an ItemTransform, whose encodes takes in the whole tuple. This is my current hacky solution:

class TrainTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

class ValidTransform(ItemTransform):
    split_idx = 1
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

TL;DR: How do I get the same behavior as my hacky solution, but with a RandTransform instead of an ItemTransform?

Have you tried adding a 2nd encodes() method in your class AlbumentationsTransform with a different signature?

Like

def encodes(self, img: PILMask):
…

See: Tutorial - Custom transforms | fastai

The types in these methods' signatures seem to matter.


RandTransforms are dispatch-dependent, meaning they are applied based on a single type dispatch and always on a single input. TypeDispatch over groups of items isn't possible, so this can't be done.
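
To make that concrete, here is a tiny illustration (a sketch, assuming standard fastcore dispatch behavior; Demo is a hypothetical class):

import numpy as np
from fastai.vision.all import *

class Demo(Transform):
    def encodes(self, x: PILImage): return "got an image"
    def encodes(self, x: PILMask):  return "got a mask"

img = PILImage.create(np.zeros((8, 8, 3), dtype=np.uint8))
msk = PILMask.create(np.zeros((8, 8), dtype=np.uint8))

# Applied to a tuple, the transform maps over the elements: each one is
# dispatched on its own type, so the image and the mask arrive in two
# separate encodes calls and can never reach Albumentations together.
print(Demo()((img, msk)))  # ('got an image', 'got a mask')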

You have the right approach here with the Train and Valid transforms.


That makes sense. Thank you for the help!

Hi,

Yes, I did…something like:

class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()
    
    def before_call(self, b, split_idx):
        self.idx = split_idx
    
    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

    def encodes(self, mask: PILMask):
        if self.idx == 0:
            aug_mask = self.train_aug(mask=np.array(mask))['mask']
        else:
            aug_mask = self.valid_aug(mask=np.array(mask))['mask']
        return PILMask.create(aug_mask)

This solution won’t work due to how Albumentations augments images and masks together:

augs = augmentations(image=np.array(image), mask=np.array(mask))

They need to be passed in together in one call so that the mask is transformed the same way as the image; you then access the individual items like so:

image = augs['image']
mask = augs['mask']
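
Put together, a self-contained version of that joint call (a sketch with dummy data; HorizontalFlip is just an example transform):

import numpy as np
import albumentations as A

augmentations = A.Compose([A.HorizontalFlip(p=1.0)])
image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (64, 64), dtype=np.uint8)

# One call augments both: the mask gets exactly the same spatial flip
# as the image, so the two stay aligned.
augs = augmentations(image=image, mask=mask)
image, mask = augs['image'], augs['mask']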

@jimmiemunyi Do you have an example showing your “hacky solution” being used? I have the exact same problem: doing image segmentation, and I want different transforms to be applied in training vs validation. It is not clear to me how I would use the two separate transforms you’re showing there.

This is the code I have. I create a Pandas dataframe with image and mask paths and attributes. I have two getters that perform low-level transformations (pixel bit depth, merging of masks). I want to create two separate sets of transforms for training vs validation. The transform for training needs to be more complex, including image augmentation techniques. My example obviously won't work because it can't distinguish between training and validation. How do I modify it to use two different transforms?

import albumentations
import numpy as np
from fastai.vision.all import *


class AlbumentationsTransform(ItemTransform):
    "A transform handler for multiple `Albumentation` transforms"
    #split_idx,order=0,2

    def __init__(self, train_aug, valid_aug):
        store_attr()

    def encodes(self, x):
        img, mask = x
        if self.split_idx == 0:
            aug = self.train_aug(image=np.array(img), mask=np.array(mask))
        else:
            aug = self.valid_aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])


def get_train_aug():
    return albumentations.Compose(
        [
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.Resize(224, 224),
        ]
    )


def get_valid_aug():
    return albumentations.Compose(
        [
            albumentations.Resize(224, 224),
        ],
    )


item_tfms = [
    Resize((720, 960), method="squish"),
    AlbumentationsTransform(get_train_aug(), get_valid_aug()),
]


def make_mask(row):
    """
    Is called by DataBlock(getters=).
    Takes a list of paths to mask files from a Pandas column.
    Makes sure all masks are 8 bits per pixel.
    If there are multiple masks, merges them.
    Returns a PILMask.create() mask image.
    """
    f = ColReader("mask")
    # PILMask.create() probably forces 8 bits per pixel.
    all_images = [np.asarray(PILMask.create(x)) for x in f(row)]
    image_stack = np.stack(all_images)
    image_union = np.amax(image_stack, axis=0)
    return PILMask.create(image_union)
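

# Quick sanity check of the merge with hypothetical 2x2 masks: the per-pixel
# max acts as a union, so a pixel is foreground if any mask marks it.
#   a = np.array([[0, 1], [0, 0]]); b = np.array([[0, 0], [1, 0]])
#   np.amax(np.stack([a, b]), axis=0)  ->  [[0, 1], [1, 0]]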


def make_image(row):
    """
    Receives a Pandas row. Gets an image path from the "image" column.
    Makes sure all images are 8 bits per color channel.
    (There may be multiple color channels.)
    Returns a PILImage.create() image.
    """
    f = ColReader("image")
    # PILImage.create() probably forces 8 bits per color channel.
    image_array = np.asarray(PILImage.create(f(row)))
    return PILImage.create(image_array)


# Most images are 960 x 720. A few images are much larger. So we resize to 960 x 720 first.
# The final resize is to the desired image size for the model.
crop_datablock = DataBlock(
    blocks=(ImageBlock, MaskBlock),
    getters=[make_image, make_mask],
    splitter=TrainTestSplitter(stratify=crop_df["dataset"].to_list()),
    item_tfms=item_tfms,
)

# https://docs.fast.ai/tutorial.albumentations.html#using-different-transform-pipelines-and-the-datablock-api
crop_dataloader = crop_datablock.dataloaders(crop_df, bs=8)
crop_dataloader.train.show_batch()
crop_dataloader.valid.show_batch()

Hey, this was a while ago. Let me try and track down my code, then I'll get back to you.

Hello, apologies it took some time. This is how I use my hacky solution. Maybe it will help and give you some insight.

So, I have some augmentations defined that I want to use on my training set and validation set respectively:

import albumentations as A

def get_train_aug():
    return A.Compose([
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.VerticalFlip(p=0.5),
    ])

def get_valid_aug():
    return A.Compose([A.Resize(256, 256)], p=1.)

I then define the Transforms to handle them, using split_idx to let each transform know when to apply (training vs validation):

class TrainTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

class ValidTransform(ItemTransform):
    split_idx = 1
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

I then instantiate them:

train_tfms = TrainTransform(get_train_aug())
valid_tfms = ValidTransform(get_valid_aug())

basic_augmentations = [train_tfms, valid_tfms]

And use them with the mid-level fastai API:

dsets = ... # some code that generates Datasets

dls = dsets.dataloaders(bs=8, after_item=[Resize(256), ToTensor(),
                        IntToFloatTensor(), *basic_augmentations])
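
If you want to keep the DataBlock pipeline from your snippet, the same two-class idea should slot straight into item_tfms. An untested sketch (TrainSegTransform/ValidSegTransform are hypothetical names; they are just the two-element version of my classes above, since your items are (image, mask) pairs without the extra feature):

class TrainSegTransform(ItemTransform):
    split_idx = 0  # applied only on the training set
    order = 2      # run after fastai's Resize
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

class ValidSegTransform(TrainSegTransform):
    split_idx = 1  # applied only on the validation set

item_tfms = [
    Resize((720, 960), method="squish"),
    TrainSegTransform(get_train_aug()),
    ValidSegTransform(get_valid_aug()),
]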

Please note that I wrote this code some time ago and it might not be optimal, but it worked for me at the time. If I were to look at it now I would probably research how best to do it, by looking at the fastai notebooks and other people’s work.
