Found a solution, but it feels a bit dirty… Basically, since I want the transform to apply only to my X and not to my Y, I created a new class so that type dispatch applies my transform only to the X:
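```python
from fastai.vision.all import *

# Subclasses exist only so type dispatch can tell the input (X) apart from the target (Y)
class PILImageBWNoised(PILImageBW): pass
class TensorImageBWNoised(TensorImageBW): pass
# ToTensor converts a PIL image to its `_tensor_cls`, so X becomes a TensorImageBWNoised
PILImageBWNoised._tensor_cls = TensorImageBWNoised

class AddNoiseTransform(Transform):
    "Add Gaussian noise to the input image only"
    order = 11  # higher than IntToFloatTensor's order (10), so noise is added after the float conversion
    def __init__(self, noise_factor=0.3): self.noise_factor = noise_factor
    # Only matches TensorImageBWNoised; the plain TensorImageBW target passes through untouched
    def encodes(self, o:TensorImageBWNoised): return o + self.noise_factor * torch.randn(o.shape, device=o.device)

mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBWNoised), ImageBlock(cls=PILImageBW)),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  batch_tfms=[AddNoiseTransform()])
```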
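The trick relies on type dispatch: `encodes` is only called on elements whose type matches its annotation, and every other element of the `(X, Y)` tuple is passed through unchanged. Here is a minimal pure-Python sketch of the same idea, assuming `functools.singledispatchmethod` as a stand-in for fastai's own dispatch. The classes `ImageBW`, `ImageBWNoised` and `AddNoise` are hypothetical, and plain lists stand in for tensors:

```python
import random
from functools import singledispatchmethod

class ImageBW(list):
    "Hypothetical stand-in for TensorImageBW (the clean target)"

class ImageBWNoised(ImageBW):
    "Hypothetical stand-in for TensorImageBWNoised (the input)"

class AddNoise:
    "Hypothetical transform: adds Gaussian noise only to ImageBWNoised elements"
    def __init__(self, noise_factor=0.3, seed=0):
        self.noise_factor = noise_factor
        self.rng = random.Random(seed)  # seeded so the sketch is reproducible

    @singledispatchmethod
    def encodes(self, o):
        return o  # no matching type: pass the element through unchanged

    @encodes.register(ImageBWNoised)
    def _(self, o):
        # Add scaled standard-normal noise to every value of the input
        return ImageBWNoised(x + self.noise_factor * self.rng.gauss(0.0, 1.0) for x in o)

    def __call__(self, batch):
        # Apply to every element of the (X, Y) tuple, like a batch transform
        return tuple(self.encodes(o) for o in batch)

x, y = ImageBWNoised([0.0] * 4), ImageBW([0.0] * 4)
noisy_x, clean_y = AddNoise()((x, y))
print(type(noisy_x).__name__, type(clean_y).__name__)  # ImageBWNoised ImageBW
print(clean_y is y)  # True
```

The target comes back as the very same object while only the subclassed input gets noise, which is what the extra `PILImageBWNoised`/`TensorImageBWNoised` classes buy in the DataBlock.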
With this DataBlock I get what I want, an X with noise added to it and a clean target: