Great idea, thanks!
In my case I've been applying the transformation to `PILImage`s one at a time (as opposed to a batch of tensors). So for any future readers, the changes I made were:
Introduce a new `PILImage` subtype:
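```python
from fastai2.vision.all import *

# Same behaviour as PILImage, but a distinct type that encodes() can dispatch on
class PILImageInput(PILImage): pass
```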
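Why a subclass helps: fastai picks which `encodes` to run based on the item's type, and an item whose type has no matching `encodes` is passed through unchanged. The sketch below mimics that idea with the standard library's `functools.singledispatch`, used here only as an analogy (it is not fastai's dispatch mechanism). `Image` and `ImageInput` are hypothetical stand-ins for `PILImage` and `PILImageInput`, and it assumes inputs are created as the subclass while targets keep the base type.

```python
from functools import singledispatch


class Image:
    """Hypothetical stand-in for fastai's PILImage."""
    def __init__(self, name):
        self.name = name


class ImageInput(Image):
    """Hypothetical stand-in for PILImageInput: same data, distinct type."""
    pass


@singledispatch
def encodes(x):
    # Fallback: types without a registered handler pass through unchanged
    return x


@encodes.register
def _(x: ImageInput):
    # Only ImageInput objects reach this branch
    return ImageInput(x.name + "+cutout")


inp, target = ImageInput("photo"), Image("mask")
print(encodes(inp).name)     # photo+cutout
print(encodes(target).name)  # mask
```

Because `ImageInput` is a subclass, it still behaves like the base type everywhere else; only the transform that asks for it specifically treats it differently.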
Make sure my transform handles that type in `encodes()`:
```python
class RandomCutout(RandTransform):
    def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs):
        super().__init__(**kwargs)
        ...  # store the hole-count and hole-length ranges (elided)

    def encodes(self, x:PILImageInput):
        # Here we accept only PILImageInput (this type matters in fastai2)
        ...
```
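The body of `encodes` is elided above. For readers who want something concrete, here is a minimal, hypothetical cutout in plain Python; it is not the code from this post. It assumes a grayscale image stored as a list of rows and square holes, and reuses the parameter names and defaults from `__init__`. Inside a real transform you would convert the `PILImage` to and from such an array yourself.

```python
import random


def random_cutout(pixels, min_n_holes=5, max_n_holes=10,
                  min_length=5, max_length=50, fill=0, rng=random):
    """Hypothetical cutout: fill a random number of random squares.

    `pixels` is a grayscale image as a list of rows (lists of ints).
    Returns a new image; the input is left unmodified.
    """
    h, w = len(pixels), len(pixels[0])
    out = [row[:] for row in pixels]  # copy so the original is untouched
    n_holes = rng.randint(min_n_holes, max_n_holes)  # bounds are inclusive
    for _ in range(n_holes):
        length = rng.randint(min_length, max_length)
        # Top-left corner anywhere in the image; holes that run past the
        # border are clipped rather than shifted inward
        top, left = rng.randrange(h), rng.randrange(w)
        for r in range(top, min(top + length, h)):
            for c in range(left, min(left + length, w)):
                out[r][c] = fill
    return out


img = [[255] * 64 for _ in range(64)]
cut = random_cutout(img, rng=random.Random(0))
print(sum(v == 0 for row in cut for v in row), "pixels cut out")
```

Passing a seeded `random.Random` makes the result reproducible, which is handy when checking the transform visually.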
Then just pass this transform to `item_tfms` when creating a databunch:
```python
databunch = data.databunch(images_path,
                           bs=10,
                           item_tfms=[RandomCutout()],  # Pass RandomCutout to item_tfms
                           batch_tfms=[*aug_transforms(size=160, max_warp=0, max_rotate=0)])
```
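The reason this goes in `item_tfms` rather than `batch_tfms`: item transforms see each item individually (here, one image at a time) before items are collated into a batch, while batch transforms see the whole collated batch. The toy model below sketches that ordering in plain Python; `run_pipeline`, `tag_item` and `tag_batch` are hypothetical names, and real fastai batches are collated tensors rather than lists.

```python
def run_pipeline(items, item_tfms, batch_tfms, bs):
    """Toy model of transform ordering (hypothetical, not fastai's code).

    item_tfms run on each item individually, then items are grouped
    into batches of size bs, then batch_tfms run on each whole batch.
    """
    processed = []
    for x in items:
        for tfm in item_tfms:
            x = tfm(x)
        processed.append(x)
    batches = [processed[i:i + bs] for i in range(0, len(processed), bs)]
    for tfm in batch_tfms:
        batches = [tfm(b) for b in batches]
    return batches


def tag_item(x):
    # Sees a single item, like RandomCutout
    return f"{x}:cutout"


def tag_batch(b):
    # Sees a whole batch, like aug_transforms
    return [f"{x}:aug" for x in b]


print(run_pipeline(["a", "b", "c"], [tag_item], [tag_batch], bs=2))
# [['a:cutout:aug', 'b:cutout:aug'], ['c:cutout:aug']]
```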