I think the problem is that the transform is type-dispatched: since both x and y are images, it gets applied to both. One alternative is to give the y images a separate type so the transform isn’t applied to them, but that may be too complicated. I think this solution is probably easier:
It works because it redefines __call__, which usually checks the type of the data and applies the transform based on type dispatch. Once you override it, the transform is applied however you define it.
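To make the mechanism concrete, here is a toy sketch in plain Python. It is not the library's actual transform class: `Image`, `ToyTransform`, `Flip` and `FlipInputOnly` are all hypothetical names, and the "dispatch" is reduced to a simple `isinstance` check. It only shows why overriding `__call__` lets you skip the per-type application.

```python
class Image:
    """Hypothetical stand-in for an image type."""
    def __init__(self, name, flipped=False):
        self.name, self.flipped = name, flipped


class ToyTransform:
    """Toy stand-in for a type-dispatched transform (an assumption, not the real class)."""
    def encodes(self, item):
        return item

    def __call__(self, sample):
        # Default behaviour: apply `encodes` to every element whose type matches.
        return tuple(self.encodes(o) if isinstance(o, Image) else o for o in sample)


class Flip(ToyTransform):
    def encodes(self, item):
        return Image(item.name, flipped=not item.flipped)


class FlipInputOnly(Flip):
    def __call__(self, sample):
        # Overriding __call__ bypasses the type check: only x is transformed.
        x, y = sample
        return (self.encodes(x), y)


x, y = Image("x"), Image("y")

both_x, both_y = Flip()((x, y))            # dispatch hits x and y alike
only_x, only_y = FlipInputOnly()((x, y))   # custom __call__ leaves y alone

print(both_x.flipped, both_y.flipped)
print(only_x.flipped, only_y.flipped)
```

With the default `__call__`, both elements come back flipped because both are `Image` instances. With the overridden version, y passes through untouched even though its type still matches.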