Hi,
I finally spotted the problems with the function (YAY!). Here is a summary, in case someone would like to implement it:
- Transforms on the GPU are applied per batch, so the tensor has a shape like 4x3x244x244. In this specific case, the function should change the values in x[:,channel,None,None]. This is easy to spot if you run dblock.summary(path) and look at the last part, where the batch is built.
- The function has to take and return a TensorImage to work with the rest of the batch transforms.
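To make the first point concrete, here is a minimal sketch in plain PyTorch (no fastai needed) of what that index selects on a batch. The batch size and image size are made-up values, chosen only for illustration.

```python
import torch

# Hypothetical batch: 4 images, 3 channels, 8x8 pixels (sizes are arbitrary)
x = torch.zeros(4, 3, 8, 8)
channel = 1

# The integer index drops the channel dim; each None adds a new axis of size 1
view = x[:, channel, None, None]
view_shape = tuple(view.shape)

# Assigning through the index writes into the chosen channel of every image in the batch
x[:, channel, None, None] = torch.ones(view.shape)

changed = x[:, channel].sum().item()    # sum over the chosen channel
untouched = x[:, [0, 2]].sum().item()   # sum over the other two channels
print(view_shape, changed, untouched)
```

As far as I can tell, the two None entries only add singleton axes, so x[:, channel] addresses the same data.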
Here is the code, in case someone wants to use it:
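from fastai.vision.all import *  # provides patch, TensorImage, RandTransform, np and torch

@patch
def rgb_randomize(x:TensorImage, channel:int=None, thresh:float=0.3, p=0.5):
    "Randomize one of the channels of the input image"
    # p is not used here; RandTransform decides whether the transform is applied
    if channel is None: channel = np.random.randint(0, x.shape[1])
    # The batch is (bs, channels, h, w): overwrite the chosen channel of every image,
    # creating the noise on the same device as x
    x[:,channel,None,None] = torch.rand(x[:,channel,None,None].shape, device=x.device) * np.random.uniform(0, thresh)
    return TensorImage(x)

class rgb_transform(RandTransform):
    order=10
    def __init__(self, channel=None, thresh=0.3, p=0.5, **kwargs):
        super().__init__(p=p, **kwargs)  # RandTransform stores p
        self.channel,self.thresh = channel,thresh
    def encodes(self, x:TensorImage): return TensorImage(x).rgb_randomize(channel=self.channel, thresh=self.thresh, p=self.p)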
Then you can instantiate it in your batch transforms like:
batch_tfms=[*aug_transforms(flip_vert=True), rgb_transform(thresh=0.99, p=1)]
And run the DataBlock as usual.
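To sanity-check the randomization logic without building a DataBlock, a plain PyTorch version of the same idea can help. This is only a sketch under a few assumptions: randomize_channel is a hypothetical helper name (not part of fastai), it works on a copy instead of in place, it indexes with x[:, channel], and it draws the scale with torch.rand instead of np.random.uniform.

```python
import torch

def randomize_channel(x, channel=None, thresh=0.3):
    "Hypothetical stand-alone helper: replace one channel of a (bs, c, h, w) batch with scaled noise"
    x = x.clone()  # work on a copy so the input batch is left unchanged
    if channel is None:
        channel = torch.randint(0, x.shape[1], (1,)).item()
    scale = torch.rand(1).item() * thresh  # one scale in [0, thresh) for the whole batch
    x[:, channel] = torch.rand(x[:, channel].shape, device=x.device) * scale
    return x, channel

batch = torch.ones(4, 3, 8, 8)  # made-up batch of all-ones images
out, ch = randomize_channel(batch, thresh=0.3)
others = [c for c in range(3) if c != ch]

# The chosen channel should stay below thresh; the other channels should be unchanged
print(ch, out[:, ch].max().item() < 0.3, torch.equal(out[:, others], batch[:, others]))
```

If both checks print True, only the selected channel was overwritten and its values stay below thresh.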
I would like to thank @sgugger for the endless help on this. I will submit a PR as soon as possible.