I am trying to add torchvision's ColorJitter (https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html) to fastai's aug_transforms, but I am running into errors.
Here is what I tried:
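```python
from fastai.vision.all import *
from torchvision import transforms

custom = aug_transforms(size=128, min_scale=0.75)
# Prepend torchvision's ColorJitter to the fastai augmentations
custom = [transforms.ColorJitter(brightness=.5, hue=.3)] + custom

def get_dls(bs, size):
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       get_items=get_image_files,
                       get_y=parent_label,
                       item_tfms=Resize(460),
                       batch_tfms=[*custom,
                                   Normalize.from_stats(*imagenet_stats)])
    return dblock.dataloaders(path, bs=bs)  # path: the image folder, defined elsewhere

dls = get_dls(64, 224)
```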
I get this error:
```
TypeError: Tensor is not a torch image.
```
How can I add PyTorch (torchvision) transforms to aug_transforms and still create the DataLoaders?