Is there a way to prevent transforms from being applied to the labels?

I have a line of code that looks like this (where my_dataset is of type Datasets):

dls = my_dataset.dataloaders(batch_size=64, after_batch=[NormalizeBatch])

(note: I wrote NormalizeBatch myself. It is just a transform that normalizes all elements according to a pre-determined mean and std)

This is currently causing problems, because NormalizeBatch is getting applied not only to the training data, but also to the labels. I think this is because transforms automatically get applied to every element of the batch tuple, which in this case means both the inputs and the labels.
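A toy illustration of that behaviour (plain Python, not fastai; the names and values are made up): a batch-level transform is mapped over every element of the (inputs, labels) tuple, so the labels get normalized too.

```python
# A normalization applied blindly to each element of the batch tuple.
def normalize(t, mean=0.5, std=0.25):
    return [(v - mean) / std for v in t]

batch = ([0.5, 1.0], [1, 0])               # (inputs, labels)
out = tuple(normalize(item) for item in batch)
# out[1] is now [2.0, -2.0]: the labels were normalized as well
```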

Is there a way to prevent a transform from being applied to labels?

I currently have a hacky solution, where I insert code in the normalization that checks the shape of its argument and only proceeds if it matches the shape of the inputs. However, I am hoping that there is a better way.
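For reference, a minimal stdlib sketch of that shape-check workaround (the Tensor stand-in, shape, and statistics are all illustrative, not from the original code):

```python
class Tensor:
    """Toy stand-in for a tensor: just data plus a shape."""
    def __init__(self, data, shape):
        self.data, self.shape = data, shape

INPUT_SHAPE = (3, 32, 32)          # assumed shape of the inputs

def normalize_batch(x, mean=0.5, std=0.25):
    # Only normalize elements whose shape matches the inputs;
    # labels (which have a different shape) pass through untouched.
    if getattr(x, "shape", None) != INPUT_SHAPE:
        return x
    x.data = [(v - mean) / std for v in x.data]
    return x

img = Tensor([0.5, 1.0], INPUT_SHAPE)
lbl = Tensor([1, 0], (2,))
out = tuple(normalize_batch(t) for t in (img, lbl))
```

This works, but it silently breaks as soon as a label happens to share the input shape, which is why dispatching on type (as suggested below in the thread) is more robust.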

Can you share the code of NormalizeBatch? Transforms use TypeDispatch to choose which data a transformation is applied to. If your image and label are of the same type (like TensorImage), this can lead to problems.
Try creating a new type for your labels.


As BresNet suggested, you could create a separate class for your inputs:

class PILImageInput(PILImage): pass

Set your input’s block to this new class when creating your DataBlock:

blocks = (ImageBlock(PILImageInput), ImageBlock)

And tell your transform to be only applied to instances of that type:

def NormalizeBatch(img):
    if isinstance(img, PILImageInput):
        pass  # do something to your image here
    return img

Hope this helps!

@BobMcDear, this is a good suggestion.
You can further streamline the transform by using a type annotation, which fastai's transform machinery dispatches on.

def NormalizeBatch(img:PILImageInput):
    # Do something to your image
    return img

This does the same as your function, but with typedispatch instead of isinstance. It might not seem very useful in this example, but it becomes useful when you have multiple classes you want to process differently.
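The fastai annotation-based dispatch above needs fastai itself; as an analogous stdlib sketch, `functools.singledispatch` shows how type-based dispatch scales to several classes handled differently (the class names and transformations here are hypothetical):

```python
from functools import singledispatch

class ImageInput(list): pass      # hypothetical input type
class MaskInput(list): pass       # hypothetical second input type

@singledispatch
def process(x):
    return x                      # default: leave labels untouched

@process.register
def _(x: ImageInput):
    return ImageInput(v / 255 for v in x)   # scale images to [0, 1]

@process.register
def _(x: MaskInput):
    return MaskInput(v > 0 for v in x)      # binarize masks

batch = (ImageInput([0, 255]), MaskInput([0, 7]), 3)  # (image, mask, label)
out = tuple(process(item) for item in batch)
# the image and mask are each transformed their own way; the label is untouched
```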


Thanks, everyone! Indeed, Borna’s method is the one that seems to work. I would also expect your solution to work, Keno, but for some reason it does not :frowning: