Help using dl_tfms in DataBunch

I am developing a custom data augmentation method that needs access to the data at the batch level. Following advice elsewhere on this forum, I wrote the following function:

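```python
from typing import Collection
from torch import Tensor

def aug_tfms(b:Collection[Tensor]):
    "Batch-level augmentation, called by the dataloaders on every batch."
    xb, yb = b                     # inputs and targets of one batch
    for i in range(xb.shape[0]):   # loop over the items in the batch
        pass                       # (Augmentation code) for item i goes here
    return [xb, yb]
```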

In the DataBunch definition, I added `dl_tfms=aug_tfms`, like this:

```python
.databunch(bs=bs, collate_fn=bb_pad_collate, dl_tfms=aug_tfms)
```

The function `aug_tfms` does its job and creates the augmentation. However, the augmentation is applied to batches from both the `train_dl` and `valid_dl` dataloaders, and I want it applied only to `train_dl`. How can I accomplish this?

Your help is appreciated.


Yeah, this part of fastai v1 doesn't allow different transforms for train and valid. You can manually edit the list in `data.valid_dl.tfms` after the DataBunch has been created.
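To illustrate the reply, here is a minimal, self-contained sketch. It does not import fastai: `SketchDeviceDataLoader` and `SketchDataBunch` are hypothetical stand-ins that model only the detail that matters here, namely that `dl_tfms` is handed to both loaders and each loader runs its own `tfms` list over every batch. The toy `aug_tfms` (negating the inputs) is also just a placeholder for the real augmentation.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical stand-in for fastai v1's DeviceDataLoader: it only models a
# per-loader list of batch transforms (`tfms`) applied, in order, to every batch.
class SketchDeviceDataLoader:
    def __init__(self, batches: Sequence[Tuple], tfms: Optional[List[Callable]] = None):
        self.batches = batches
        self.tfms = list(tfms or [])  # each loader keeps its own copy of the list

    def __iter__(self):
        for b in self.batches:
            for f in self.tfms:
                b = f(b)
            yield b


# Hypothetical stand-in for DataBunch(..., dl_tfms=...): the same dl_tfms
# are handed to the training and the validation loader alike.
class SketchDataBunch:
    def __init__(self, train_batches, valid_batches, dl_tfms=None):
        self.train_dl = SketchDeviceDataLoader(train_batches, dl_tfms)
        self.valid_dl = SketchDeviceDataLoader(valid_batches, dl_tfms)


def aug_tfms(b):
    "Toy batch-level augmentation: negate every input in xb, leave yb alone."
    xb, yb = b
    return [[-x for x in xb], yb]


data = SketchDataBunch(train_batches=[([1, 2], [0, 1])],
                       valid_batches=[([3, 4], [1, 0])],
                       dl_tfms=[aug_tfms])

# The fix from the reply: after creation, drop aug_tfms from the validation
# loader's list only. Reassigning (rather than mutating in place) leaves the
# training loader's list untouched even if the two lists were the same object.
data.valid_dl.tfms = [t for t in data.valid_dl.tfms if t is not aug_tfms]

print(list(data.train_dl))  # [[[-1, -2], [0, 1]]]
print(list(data.valid_dl))  # [([3, 4], [1, 0])]
```

With a real fastai v1 DataBunch, the same idea would be running `data.valid_dl.tfms = [t for t in data.valid_dl.tfms if t is not aug_tfms]` after `.databunch(...)`; alternatively, leave `dl_tfms` out of `.databunch(...)` and append `aug_tfms` to `data.train_dl.tfms` only. Both assume `tfms` behaves like the plain list modeled above.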
