[solved] How to create a generic augmentation transform that handles both training and validation augmentations

I am trying to wrap my head around the mid-level API, specifically creating a custom RandTransform. This is the code given for creating a custom one:

class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()
    
    def before_call(self, b, split_idx):
        self.idx = split_idx
    
    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

How would I modify it to work for segmentation? This particular encodes only takes in a PILImage. I have tried creating another encodes that takes in a PILMask, but Albumentations expects the image and mask to be passed in at the same time. Something like:

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

But this is now an ItemTransform, whose encodes takes in the whole tuple. This is my current hacky solution:

class TrainTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

class ValidTransform(ItemTransform):
    split_idx = 1
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask, feature = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"]), feature

TLDR: How can I have the same code as in my hacky solution, but with RandTransform instead of ItemTransform?

Have you tried adding a 2nd encodes() method in your class AlbumentationsTransform with a different signature?

Like

def encodes(self, img: PILMask):
…

See: Tutorial - Custom transforms | fastai

The type annotations on these methods seem to matter.

RandTransforms are dispatch-dependent: they are applied through type dispatch on a single type, and always to a single input. Type dispatch over a group of items isn’t possible, so this can’t be done with a RandTransform.

You have the right approach here with the Train and Valid transforms.
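
To illustrate the difference between per-item type dispatch and a tuple-level transform, here is a minimal sketch in plain Python. It does not use fastai: `functools.singledispatch` stands in for fastai's type dispatch, and `Image`, `Mask`, `apply_per_item` and `apply_whole_item` are hypothetical names made up for this example.

```python
from functools import singledispatch


class Image(list):
    """Hypothetical stand-in for PILImage."""


class Mask(list):
    """Hypothetical stand-in for PILMask."""


@singledispatch
def per_item(x):
    # Fallback for elements with no registered type: pass through unchanged.
    return x


@per_item.register
def _(x: Image):
    # Only the image is visible here; the mask is out of reach.
    return ("image-only", type(x).__name__)


@per_item.register
def _(x: Mask):
    # Only the mask is visible here; the image is out of reach.
    return ("mask-only", type(x).__name__)


def apply_per_item(sample):
    # Type-dispatched style: each tuple element is handled separately.
    return tuple(per_item(x) for x in sample)


def apply_whole_item(sample):
    # Tuple-level style: one call receives the image and mask together.
    img, mask, *rest = sample
    return ("joint", type(img).__name__, type(mask).__name__, *rest)


sample = (Image([1, 2]), Mask([0, 1]), "feature")
per_item_result = apply_per_item(sample)
whole_item_result = apply_whole_item(sample)
```

In the per-item version, no single call ever sees both the image and the mask, which is why a joint augmentation has to live in something that receives the whole tuple.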

That makes sense. Thank you for the help!

Hi,

Yes, I did…something like:

class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()
    
    def before_call(self, b, split_idx):
        self.idx = split_idx
    
    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

    def encodes(self, mask: PILMask):
        if self.idx == 0:
            aug_mask = self.train_aug(mask=np.array(mask))['mask']
        else:
            aug_mask = self.valid_aug(mask=np.array(mask))['mask']
        return PILMask.create(aug_mask)

This solution won’t work due to how Albumentations augments images and masks together:

augs = augmentations(image=np.array(image), mask=np.array(mask))

They need to be passed in together so that the mask receives the same transformation as the image. You then access the individual items like so:

image = augs['image']
mask = augs['mask']
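
As a rough illustration of why the two must go through one call, here is a toy sketch in plain Python. It does not use Albumentations: `random_hflip` and `joint_hflip` are hypothetical helpers that mimic a random horizontal flip on nested lists, and `aligned` is a made-up check for this example.

```python
import random


def random_hflip(rows, rng, p=0.5):
    """Flip a 2D list left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in rows]
    return rows


def joint_hflip(image, mask, rng, p=0.5):
    """Draw the random decision once and apply it to both inputs."""
    if rng.random() < p:
        return [r[::-1] for r in image], [r[::-1] for r in mask]
    return image, mask


image = [[10, 20, 30], [40, 50, 60]]
mask = [[1, 0, 0], [1, 0, 0]]  # marks the pixels 10 and 40 (left column)


def aligned(img, msk):
    # The mask is aligned if it still marks the pixels 10 and 40.
    marked = [v for ri, rm in zip(img, msk) for v, m in zip(ri, rm) if m]
    return sorted(marked) == [10, 40]


# Joint call: one random draw per sample, shared by image and mask.
rng = random.Random(0)
joint_ok = all(aligned(*joint_hflip(image, mask, rng)) for _ in range(100))

# Separate calls: image and mask each get their own random draw.
rng = random.Random(0)
independent_ok = all(
    aligned(random_hflip(image, rng), random_hflip(mask, rng))
    for _ in range(100)
)
```

With a shared random draw the mask always follows the image. With separate draws the two disagree as soon as one is flipped and the other is not, which is what happens when the image and mask go through two different encodes calls.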