I have a line of code that looks like this (where my_dataset is of type …):

dls = my_dataset.dataloaders(batch_size=64, after_batch=[NormalizeBatch])
(Note: I wrote NormalizeBatch myself. It is just a transform that normalizes all elements according to a pre-determined mean and std.)
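For reference, here is a minimal sketch of the kind of thing NormalizeBatch does (the function form and the mean/std values are placeholders of mine, not the actual implementation):

```python
# Hypothetical stand-in for NormalizeBatch: the names and the specific
# mean/std values are my assumptions, not the real code.
MEAN, STD = 0.5, 0.25

def normalize_batch(x):
    # Normalizes whatever it is handed using a pre-determined mean and std.
    # Note it does nothing to distinguish inputs from labels.
    return (x - MEAN) / STD
```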
This is currently causing problems: NormalizeBatch is getting applied not only to the training inputs, but also to the labels. I think this is because transforms are automatically applied to every element of the batch tuple, which in this case contains both inputs and labels.
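To illustrate the behaviour I mean, here is a self-contained sketch; the apply_to_tuple helper is my own stand-in for what I believe the library does, not actual fastai code:

```python
def apply_to_tuple(tfm, batch):
    # My understanding of the problem: the transform is mapped over every
    # element of the (inputs, labels) tuple, not just the inputs.
    return tuple(tfm(item) for item in batch)

# Placeholder transform and batch (plain numbers standing in for tensors).
normalize = lambda x: (x - 0.5) / 0.25
inputs, labels = 0.5, 1
result = apply_to_tuple(normalize, (inputs, labels))  # labels get normalized too
```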
Is there a way to prevent a transform from being applied to labels?
I currently have a hacky workaround: inside the normalization I check the shape of the argument, and only proceed if it matches the shape of the inputs. However, I am hoping there is a better way.
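This is roughly what the workaround looks like (the shape constant and function names are illustrative assumptions, not my exact code):

```python
MEAN, STD = 0.5, 0.25
INPUT_SHAPE = (3, 224, 224)  # hypothetical known shape of a single input

def normalize_batch(x):
    # Only normalize if the trailing dimensions match the input shape;
    # labels (which have a different shape, or none at all) pass through.
    shape = getattr(x, "shape", None)
    if shape is not None and tuple(shape[-3:]) == INPUT_SHAPE:
        return (x - MEAN) / STD
    return x
```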