Hi Bwarner!
Thanks for your reply.
I have found that applying these augmentations inside an ItemTransform
does indeed produce the expected outcome, although the solution is not very elegant. Here is the accompanying code:
from fastai.vision.all import *
@typedispatch
def show_batch(x:Tuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """Show each pair of input images side by side, titled with its label.

    `x` is unpacked into two image batches. For every displayed item the two
    images are joined horizontally with a 5-pixel-wide zero-valued strip
    between them and drawn on one axis. `samples` and `**kwargs` are accepted
    for signature compatibility and are not used here.

    The title is looked up in `dls.vocab`, so a module-level `dls` must exist
    when this function runs.
    """
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    imgs0, imgs1 = x
    n_items = imgs0.shape[0]
    if ctxs is None:
        # Forward the caller's `nrows` rather than discarding it.
        ctxs = get_grid(min(n_items, max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    print('custom show_batch')
    # Separator strip sized from the batch itself instead of a fixed 32x5x3.
    # Assumes each batch is laid out (N, C, H, W), as implied by the
    # permute(1,2,0) applied to a single item below.
    line = imgs0.new_zeros(imgs0.shape[-2], 5, imgs0.shape[1])
    # zip with range(n_items) so externally supplied ctxs that outnumber the
    # batch cannot index past its end.
    for i, ctx in zip(range(n_items), ctxs):
        pair = torch.cat([imgs0[i].permute(1,2,0), line, imgs1[i].permute(1,2,0)], dim=1)
        pair.show(ctx=ctx, title=dls.vocab[y[i].item()])
class TwoImageAugment(ItemTransform):
    """Item-level transform that augments the two images of a sample separately.

    `tfms1` is applied to the first image and `tfms2` to the second; the label
    is passed through untouched.
    """
    def __init__(self, tfms1, tfms2):
        # NOTE(review): the base-class initialiser is not invoked, matching
        # the original code - confirm that ItemTransform does not need it.
        self.tfms1, self.tfms2 = tfms1, tfms2

    def encodes(self, x):
        # x is expected to unpack into (first image, second image, label).
        first, second, label = x
        first, second = ToTensor()(first), ToTensor()(second)

        # Both images are reshaped using the FIRST image's shape.
        item_shape = first.shape
        batched_shape = (1, *item_shape)
        first = first.view(batched_shape)    # add a leading dim of size 1
        second = second.view(batched_shape)

        # Run each image through its own transform list.
        first = Pipeline(self.tfms1)(first)
        second = Pipeline(self.tfms2)(second)

        def jitter(img):
            # Lighting changes applied as tensor methods, each with a fresh
            # random draw: saturation, then hue, then brightness.
            img = img.saturation(p=0.9, draw=torch.rand(1))
            img = img.hue(p=0.9, draw=torch.rand(1))
            return img.brightness(draw=torch.rand(1)*0.5+0.3, p=0.9)

        first, second = jitter(first), jitter(second)

        # Drop the leading dim again; this only works if the transforms left
        # the number of elements unchanged.
        first, second = first.view(item_shape), second.view(item_shape)

        # NOTE(review): the x255 presumably undoes the scaling done by an
        # IntToFloatTensor in the supplied transform lists - confirm.
        return (first*255, second*255, label)
# Transform list handed to TwoImageAugment. The same list object is used for
# both images.
aug_batch_tfms = [
    IntToFloatTensor(),
    RandomResizedCropGPU(size=32, min_scale=0.8, max_scale=1, ratio=(1, 1)),
    Zoom(p=0.7, min_zoom=0.8, max_zoom=1.2),
    Rotate(p=0.7, max_deg=60),
    Flip(p=0.7),
]

total_augment = TwoImageAugment(aug_batch_tfms, aug_batch_tfms)
# Three blocks, of which the first two are inputs (n_inp=2) and the last is
# the category target labelled via parent_label. The two-image augmentation
# is attached as the item transform.
ds = DataBlock(
    blocks=(ImageBlock, ImageBlock, CategoryBlock),
    n_inp=2,
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=total_augment,
    # Disabled experiment: batch_tfms = [Saturation(p = 1., draw = 0.1)]
)
# Obtain the CIFAR dataset path and build the dataloaders from its train folder.
train_path = untar_data(URLs.CIFAR)
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dls = ds.dataloaders(train_path/'train', bs=64, num_workers=32, device=device)
# one_batch = dls.one_batch()
# Display a few samples using the custom show_batch registered for tuple inputs.
dls.show_batch(max_n=8, figsize=(5, 6.5))
Interestingly, the "lighting transforms" such as Brightness and Hue did not function as expected when added to aug_batch_tfms.
However, they do work when called as tensor methods, for example img1.saturation(p=0.9, draw=torch.rand(1)).
Could you share any insight into the reason behind this difference? Your expertise would be greatly appreciated.