Apply transform only to input not target

I found a solution, but it feels a bit dirty… Basically, since I want the transform to apply only to my X and not to my Y, I created new classes so that type dispatch applies my transform only to the X:

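```python
from fastai.vision.all import *  # DataBlock, ImageBlock, Transform, PILImageBW, TensorImageBW, torch, ...

# Marker types: identical to the BW image types, but distinct for type dispatch
class PILImageBWNoised(PILImageBW): pass
class TensorImageBWNoised(TensorImageBW): pass
# ToTensor converts a PIL image to its _tensor_cls, so the input becomes a TensorImageBWNoised
PILImageBWNoised._tensor_cls = TensorImageBWNoised

class AddNoiseTransform(Transform):
    "Add Gaussian noise to an image"
    order = 11  # run after IntToFloatTensor (order 10), i.e. on float pixels in [0, 1]
    def __init__(self, noise_factor=0.3): self.noise_factor = noise_factor
    # Only TensorImageBWNoised matches, so the target (a plain TensorImageBW) passes through untouched
    def encodes(self, o:TensorImageBWNoised):
        return o + self.noise_factor * torch.randn(o.shape, device=o.device)

mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBWNoised), ImageBlock(cls=PILImageBW)),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  batch_tfms=[AddNoiseTransform()])
```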
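Why this works: when a fastai `Transform` is applied to the `(x, y)` tuple, `encodes` is called on each element separately and dispatched on that element's type, and elements with no matching `encodes` come back unchanged. The sketch below mimics that behavior with the standard library's `functools.singledispatch`. It is only an analogy, not fastai's `TypeDispatch`, and `BWImage`, `BWImageNoised`, `add_noise` and `apply_to_pair` are hypothetical names made up for illustration.

```python
import random
from dataclasses import dataclass
from functools import singledispatch

# Hypothetical stand-ins (not fastai classes): BWImage plays the role of
# TensorImageBW, BWImageNoised the role of TensorImageBWNoised.
@dataclass
class BWImage:
    pixels: list

class BWImageNoised(BWImage):
    """Marker subclass: same data, different type, so dispatch can tell them apart."""

_rng = random.Random(0)  # seeded so the sketch is reproducible

@singledispatch
def add_noise(o, noise_factor=0.3):
    # Fallback for every unregistered type: return the input unchanged,
    # like a fastai Transform with no matching encodes.
    return o

@add_noise.register
def _(o: BWImageNoised, noise_factor=0.3):
    # Only the marker type gets Gaussian noise; the result keeps the marker type.
    return BWImageNoised([p + noise_factor * _rng.gauss(0.0, 1.0) for p in o.pixels])

def apply_to_pair(fn, pair):
    # Apply fn to each element of an (x, y) pair independently,
    # so each element is dispatched on its own type.
    return tuple(fn(o) for o in pair)

x = BWImageNoised([0.0, 0.5, 1.0])
y = BWImage([0.0, 0.5, 1.0])
noisy_x, clean_y = apply_to_pair(add_noise, (x, y))
```

The direction of the subclass matters: registering the handler for the base class (`BWImage` here, `TensorImageBW` in the fastai version) would also match the marker subclass, so both X and Y would get noise.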

This yields what I want: an X with noise added to it and a clean target.
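In equation form, for one input/target pair: σ is `noise_factor` (0.3 by default) and ε is drawn elementwise from a standard normal by `torch.randn`. Since both blocks load the same file and the transform runs after `IntToFloatTensor`, the target y is the clean image x with pixel values in [0, 1]:

```latex
\tilde{x} = x + \sigma\,\varepsilon, \qquad \varepsilon_{ij} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1), \qquad y = x
```

`AddNoiseTransform` does not clamp its output, so entries of the noisy input can fall outside [0, 1].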