I am trying to wrap my head around the Mid-Level API, specifically creating a custom `RandTransform`.
This is the example code given for creating a custom transform:
class AlbumentationsTransform(RandTransform):
    """A transform handler for multiple `Albumentation` transforms.

    Holds one pipeline for the training split and one for every other split,
    and picks between them each time it is called.
    """
    split_idx, order = None, 2

    def __init__(self, train_aug, valid_aug):
        # Saves `train_aug` and `valid_aug` as attributes on the instance.
        store_attr()

    def before_call(self, b, split_idx):
        # Record the split for this call; `encodes` uses it to choose a pipeline.
        self.idx = split_idx

    def encodes(self, img: PILImage):
        # Split 0 gets the training pipeline; any other value (including None)
        # falls through to the validation pipeline.
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        result = pipeline(image=np.array(img))
        return PILImage.create(result['image'])
How would I modify it to work for segmentation? This particular `encodes` only takes in a `PILImage`.
I have tried creating another `encodes` that takes in a `PILMask`,
but Albumentations expects the image and the mask to be passed in at the same time, something like:
class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply one Albumentations pipeline jointly to an (image, mask, ...) item.

    The image and the mask are handed to `aug` in a single call, because
    Albumentations expects both at the same time. Any elements after the mask
    (e.g. an extra target) are returned unchanged, so the transform works on
    items with more than two elements as well as on plain (image, mask) pairs.
    """
    split_idx = 0  # only run on split 0 (the training split)

    def __init__(self, aug):
        # `aug` is called as aug(image=ndarray, mask=ndarray) and its result is
        # indexed with "image" and "mask".
        self.aug = aug

    def encodes(self, x):
        img, mask, *rest = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        # Rebuild the item in its original order; trailing elements pass through.
        return (PILImage.create(aug["image"]), PILMask.create(aug["mask"]), *rest)
But this is now an `ItemTransform`, whose `encodes` takes in the whole tuple. This is my current hacky solution:
class _PairAlbumentationsTransform(ItemTransform):
    """Shared base: run `aug` jointly on the image and mask of an (image, mask, ...) item.

    Subclasses only choose `split_idx`; the augmentation logic lives here once
    instead of being duplicated per split.
    """

    def __init__(self, aug):
        # `aug` is called as aug(image=ndarray, mask=ndarray) and its result is
        # indexed with "image" and "mask".
        self.aug = aug

    def encodes(self, x):
        # Elements after the mask (e.g. a `feature` target) are passed through untouched.
        img, mask, *rest = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return (PILImage.create(aug["image"]), PILMask.create(aug["mask"]), *rest)


class TrainTransform(_PairAlbumentationsTransform):
    "Joint image/mask augmentation restricted to split 0 (training)."
    split_idx = 0


class ValidTransform(_PairAlbumentationsTransform):
    "Joint image/mask augmentation restricted to split 1 (validation)."
    split_idx = 1
TL;DR: How can I achieve the same thing as my hacky solution, but with `RandTransform`
instead of `ItemTransform`?