I am developing a custom data augmentation method that requires access to the data at the batch level. Following advice elsewhere in this blog, I wrote a function as shown below:
from typing import Collection
from torch import Tensor   # both are also pulled in by `from fastai.vision import *`

def aug_tfms(b:Collection[Tensor]):
    xb, yb = b                        # batch arrives as (inputs, targets)
    for i in range(xb.shape[0]):      # loop over the samples in the batch
        ...                           # (Augmentation code)
    return [xb, yb]
In the DataBunch definition, I added dl_tfms = aug_tfms, like this:
.databunch(bs=bs, collate_fn=bb_pad_collate, dl_tfms=aug_tfms)
.normalize(imagenet_stats)
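For reference, the full data block chain looks roughly like the sketch below. The ObjectItemList source, split, and labelling calls are placeholders standing in for my actual pipeline (get_y_func is not a real function here); only the databunch line with dl_tfms is the relevant part:

# Minimal sketch of where dl_tfms sits in the chain (fastai v1).
# ObjectItemList / from_folder / label_from_func are placeholders, not my real code.
data = (ObjectItemList.from_folder(path)
        .split_by_rand_pct(0.2)
        .label_from_func(get_y_func)   # get_y_func: placeholder labelling function
        .databunch(bs=bs, collate_fn=bb_pad_collate, dl_tfms=aug_tfms)
        .normalize(imagenet_stats))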
The function aug_tfms does its job and creates the augmentation. However, the transform is applied to both the train_dl and valid_dl dataloaders, and I want it to apply only to train_dl. How can I accomplish this?
Your help is appreciated.