Let’s say I want to build a denoising autoencoder in fastai v2. I want to add random noise to my input, but I don’t want it applied to my target. I thought transforms would be the best place to add it:
from fastai.vision.all import *

class AddNoiseTransform(Transform):
    "Add noise to image"
    order = 11  # run after IntToFloatTensor (order 10), so pixels are already floats
    def __init__(self, noise_factor=0.3): store_attr()  # stores self.noise_factor
    # Dispatches on TensorImage, which also matches subclasses such as TensorImageBW
    def encodes(self, o:TensorImage): return o + self.noise_factor * torch.randn_like(o)

mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  batch_tfms=[AddNoiseTransform()])
But when I call show_batch() on my DataLoaders, I see the transform applied to both the input and the target.
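One way to see why both elements get noise: fastai applies a batch transform to each element of the (input, target) tuple and picks the `encodes` that matches each element's type. Here both blocks produce the same image type, so both match. The sketch below roughly mimics that behaviour in plain Python with `functools.singledispatch`; the classes are hypothetical stand-ins for fastai's types, not its actual implementation.

```python
import random
from functools import singledispatch

class TensorImage(list):
    """Hypothetical stand-in for fastai's TensorImage: a flat list of pixel values."""

class TensorImageBW(TensorImage):
    """Hypothetical stand-in for fastai's TensorImageBW."""

@singledispatch
def add_noise(o, noise_factor=0.3):
    # Fallback: unmatched types pass through unchanged,
    # like a fastai Transform with no encodes for that type.
    return o

@add_noise.register
def _(o: TensorImage, noise_factor=0.3):
    # Registered for TensorImage, so it also matches the subclass TensorImageBW.
    return type(o)(p + noise_factor * random.gauss(0, 1) for p in o)

def apply_to_tuple(tfm, sample):
    # Apply the transform to each element of the (input, target) tuple.
    return tuple(tfm(elem) for elem in sample)

random.seed(0)
x = TensorImageBW([0.0, 0.5, 1.0])
y = TensorImageBW([0.0, 0.5, 1.0])
noisy_x, noisy_y = apply_to_tuple(add_noise, (x, y))
print(noisy_x != x, noisy_y != y)  # both elements were changed
```

Because the input and the target share one type, there is nothing for dispatch to tell apart.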
I could create a subclass of TensorImage to represent my target so that my transform is not applied to it… but surely there is a better way? How can I apply a transform only to the input and not to the target?
Give it a split_idx attribute of 0. (I believe there are some examples with this attribute in the vision augmentation file.) As you can imagine, 0 means training set only, 1 means validation set only, and None means both.
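The rule above can be sketched in plain Python. This is a simplified, hypothetical stand-in for the check fastai's `Transform` performs (skip the transform when its `split_idx` is set and differs from the split being processed), not the library code itself. Note that the flag selects a dataset split, so on its own it does not distinguish input from target within a sample.

```python
from typing import Optional

class SplitGatedTransform:
    """Hypothetical sketch of the split_idx rule (not fastai's code).

    split_idx=0 -> training set only, 1 -> validation set only, None -> both.
    """
    def __init__(self, fn, split_idx: Optional[int] = None):
        self.fn = fn
        self.split_idx = split_idx

    def __call__(self, x, split_idx: int):
        # split_idx is the subset currently being processed: 0 = train, 1 = valid.
        if self.split_idx is None or self.split_idx == split_idx:
            return self.fn(x)
        return x  # skipped on the other subset

double = SplitGatedTransform(lambda v: v * 2, split_idx=0)
print(double(3, split_idx=0))  # 6 (training set: applied)
print(double(3, split_idx=1))  # 3 (validation set: skipped)
```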
Found a solution, but it feels a bit dirty… Since I want the transform to apply only to my X and not to my Y, I created new classes so that type dispatch applies my transform only to the X:
from fastai.vision.all import *

# Marker subclasses: same behaviour, but a distinct type for dispatch
class PILImageBWNoised(PILImageBW): pass
class TensorImageBWNoised(TensorImageBW): pass
PILImageBWNoised._tensor_cls = TensorImageBWNoised  # ToTensor converts to this type

class AddNoiseTransform(Transform):
    "Add noise to image"
    order = 11  # run after IntToFloatTensor (order 10)
    def __init__(self, noise_factor=0.3): store_attr()  # stores self.noise_factor
    # Only TensorImageBWNoised matches, so the plain TensorImageBW target is untouched
    def encodes(self, o:TensorImageBWNoised): return o + self.noise_factor * torch.randn_like(o)

mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBWNoised), ImageBlock(cls=PILImageBW)),
                  get_items=get_image_files,
                  splitter=RandomSplitter(),
                  batch_tfms=[AddNoiseTransform()])
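The trick works because dispatch now has two different types to choose between. A plain-Python sketch of the same idea, again with `functools.singledispatch` and hypothetical stand-in classes rather than fastai's real ones:

```python
import random
from functools import singledispatch

class TensorImageBW(list):
    """Hypothetical stand-in for fastai's TensorImageBW (a flat list of pixels)."""

class TensorImageBWNoised(TensorImageBW):
    """Marker subclass: same data and behaviour, different type for dispatch."""

@singledispatch
def add_noise(o, noise_factor=0.3):
    # Unmatched types (here: the plain TensorImageBW target) pass through unchanged.
    return o

@add_noise.register
def _(o: TensorImageBWNoised, noise_factor=0.3):
    # Only the marker subclass receives noise.
    return TensorImageBWNoised(p + noise_factor * random.gauss(0, 1) for p in o)

random.seed(0)
x = TensorImageBWNoised([0.0, 0.5, 1.0])
y = TensorImageBW([0.0, 0.5, 1.0])
noisy_x, clean_y = (add_noise(e) for e in (x, y))
```

Registering on the subclass rather than the parent is what matters: a registration on `TensorImageBW` would be inherited by `TensorImageBWNoised` and match both elements again.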
This yields what I want: an X with noise added and a clean target.
I think you will need fastai v2’s “mid-level API” for this, which gives you more flexibility than the higher-level DataBlock API. Have a look at this chapter in fastbook, specifically the part about Datasets, which lets you specify separate x_tfms and y_tfms.
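In fastai, the mid-level version would be along the lines of `Datasets(items, tfms=[x_tfms, y_tfms])`: one transform list per tuple element, each run on the raw item. The plain-Python sketch below mimics that structure to show why putting the noise in `x_tfms` only leaves the target clean. `TinyDatasets`, `load_image` and `add_noise` are hypothetical stand-ins; the real `Datasets` also handles splits, decoding and type dispatch, and noise added this way runs per item rather than per batch.

```python
import random

def make_pipeline(*funcs):
    """Compose funcs left to right, like a (much simplified) fastai Pipeline."""
    def run(item):
        for f in funcs:
            item = f(item)
        return item
    return run

class TinyDatasets:
    """Hypothetical stand-in for Datasets(items, tfms=[x_tfms, y_tfms]):
    each item goes through every pipeline, one pipeline per tuple element."""
    def __init__(self, items, tfms):
        self.items = items
        self.pipelines = [make_pipeline(*fs) for fs in tfms]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return tuple(p(self.items[i]) for p in self.pipelines)

def load_image(item):
    # Hypothetical loader: here an "image" is just a list of pixel values.
    return list(item)

def add_noise(img, noise_factor=0.3):
    return [p + noise_factor * random.gauss(0, 1) for p in img]

random.seed(0)
x_tfms = [load_image, add_noise]  # noise only in the input pipeline
y_tfms = [load_image]             # target pipeline leaves the image clean
dsets = TinyDatasets([[0.0, 0.5, 1.0]], [x_tfms, y_tfms])
x, y = dsets[0]
```

Since the two elements are built by separate pipelines, no type trickery is needed to keep the target clean.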
But I guess if you already have something that works, you can just stick with that?