Siamese Item Transforms

I’m playing around with a Siamese (twin) network and got stuck on the augmentation part. How do I rotate images per item?

The default dataloader looks like this:

dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

I’m trying to add rotation and brightness augmentation as item transforms, because I would like the two images in a pair to look more different even when they come from the same image (e.g., one is slightly rotated and darker, but it’s still the same image).

I can add the transform to the batch transforms:

dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
after_batch=[IntToFloatTensor, Rotate(max_deg=10), Normalize.from_stats(*imagenet_stats)])

but this rotates every pair in the batch the same way.
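
To make the distinction concrete, here is a minimal sketch in plain Python of the two sampling strategies: drawing one rotation angle for a whole batch versus drawing a fresh angle for every item. It is not fastai's implementation; the helper names `angles_per_batch` and `angles_per_item` are made up for illustration, and only the standard `random` module is assumed.

```python
import random

def angles_per_batch(batch_size, max_deg, rng):
    # One random draw, reused for every item in the batch
    angle = rng.uniform(-max_deg, max_deg)
    return [angle] * batch_size

def angles_per_item(batch_size, max_deg, rng):
    # A fresh random draw for each item in the batch
    return [rng.uniform(-max_deg, max_deg) for _ in range(batch_size)]

rng = random.Random(0)  # seeded only to make the sketch reproducible
shared = angles_per_batch(4, 10, rng)
independent = angles_per_item(4, 10, rng)

print("per batch:", shared)
print("per item: ", independent)
```

The second strategy is what an item-level transform gives you: each image gets its own angle, so two copies of the same image can end up looking different.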

If I move the transform to “after_item”:

dls = tls.dataloaders(after_item=[Resize(160), Rotate(max_deg=10), ToTensor],
after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)], bs=100)

I get the error message “Could not do one pass in your dataloader, there is something wrong in it”.

What is the easiest way to rotate the images per item?

[solution]
I skipped the fastai transforms and used the torchvision transforms instead.

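from fastai.vision.all import *  # Transform, TensorImage, ToTensor, ...
from PIL import Image

class TorchItemTransforms(Transform):
    "Wrapper to use torchvision transforms within the fastai framework"

    def __init__(self, transform):
        super().__init__()
        self.trans = transform  # any callable torchvision transform or pipeline

    def encodes(self, x:(TensorImage, Image.Image)):
        # Runs once per matching item, so the wrapped transform
        # draws its random parameters separately for each image
        return self.trans(x)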

and then used it:

import torch
from torchvision import transforms

torch_tfm = torch.nn.Sequential(
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=(0.3, 1.2), contrast=(0.2, 1.5), saturation=(0.8, 1.2), hue=(-0.05, 0.05)))

dls = tls.dataloaders(after_item=[RandomResizedCrop2(256), ToTensor, TorchItemTransforms(torch_tfm)],
                      after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)], bs=100)

The transforms work fine, but they are a bit slow…