Which sets of augmentations are applied during inference?

Hello,

I'm training a segmentation model on ultrasound medical images, using this set of augmentations:

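```python
import albumentations as A

def simple_seg_aug():
    """Simple segmentation augmentations."""
    # Random resize + flips/rotations, intended for the training set only
    train = A.Compose([
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.VerticalFlip(p=0.5),
    ])
    # Deterministic resize only, for the validation set
    valid = A.Compose([A.Resize(256, 256)], p=1.)
    return train, valid
```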

I wrap them in these fastai `ItemTransform`s:

```python
#| export
import numpy as np
from fastai.vision.all import *

class TrainMaskTransform(ItemTransform):
    split_idx = 0  # meant to run on the training set only
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        # albumentations applies the same spatial transform to image and mask
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

class ValidMaskTransform(ItemTransform):
    split_idx = 1  # meant to run on the validation set only
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
```

(I'm not sure this is the best way to handle it, so pointers are welcome.)
The idea is that TrainMaskTransform (split_idx = 0) is applied only to the training set, and ValidMaskTransform (split_idx = 1) only to the validation set.
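To check my understanding of split_idx, here is a stripped-down, plain-Python model of how I believe fastcore's `Transform` gates on it. The class names (`GatedTransform`, `TrainOnly`, `ValidOnly`) are made up and this is not the real fastai code: a transform with split_idx set only runs when the DataLoader's split matches, and otherwise passes the item through unchanged.

```python
class GatedTransform:
    """Hypothetical, simplified stand-in for fastcore's split_idx gating (not the real class)."""
    split_idx = None  # None means: apply on every split

    def __call__(self, x, split_idx=None):
        # Skip the transform when it is tied to a split that doesn't match
        if self.split_idx is not None and split_idx != self.split_idx:
            return x
        return self.encodes(x)

    def encodes(self, x):
        return x


class TrainOnly(GatedTransform):
    split_idx = 0  # like TrainMaskTransform
    def encodes(self, x):
        return f"train-aug({x})"


class ValidOnly(GatedTransform):
    split_idx = 1  # like ValidMaskTransform
    def encodes(self, x):
        return f"valid-aug({x})"


print(TrainOnly()("img", split_idx=0))  # train-aug(img)
print(TrainOnly()("img", split_idx=1))  # img (passed through unchanged)
print(ValidOnly()("img", split_idx=1))  # valid-aug(img)
```

If this model is right, whichever split_idx the test DataLoader carries decides which of my two transforms runs on the test images.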

At inference time, I run this code:

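```python
from fastai.vision.all import *

inference = Path('inference')  # folder with the unlabeled test images
fns_tst = get_image_files(inference)
dl = learn.dls.test_dl(fns_tst)  # test DataLoader built from the trained learner's dls
```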

But when I try:

```python
dl.show_batch()
```

I get this error (trimmed to the last frame, which is the relevant part):

```
File ~/repos/work/auto-ob-torch/auto_ob_torch/dataloaders.py:176, in ValidMaskTransform.encodes(self, x)
    175 def encodes(self, x):
--> 176     img,mask = x
    177     aug = self.aug(image=np.array(img), mask=np.array(mask))
    178     return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

ValueError: not enough values to unpack (expected 2, got 1)
```

I think I understand why: the new test_dl contains only images and no masks, so `img, mask = x` fails when unpacking. Note that the traceback points at ValidMaskTransform, which suggests the test_dl is running the validation transforms.
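One workaround I'm considering is to let the augmentation step accept either an `(img,)` or an `(img, mask)` tuple. This is an untested sketch: `apply_seg_aug` is a hypothetical helper name, and it only assumes `aug` is called like an albumentations `Compose` (`image=` plus optional `mask=`, returning a dict).

```python
import numpy as np

def apply_seg_aug(aug, x):
    """Hypothetical helper: apply an albumentations-style pipeline to (img,) or (img, mask).

    `aug` is called as aug(image=..., mask=...) and must return a dict with
    "image" (and "mask" when one was passed in).
    """
    if len(x) == 2:
        # Labeled items (train/valid): transform image and mask together
        img, mask = x
        out = aug(image=np.array(img), mask=np.array(mask))
        return out["image"], out["mask"]
    if len(x) == 1:
        # Unlabeled items (e.g. from test_dl): there is no mask to transform
        (img,) = x
        out = aug(image=np.array(img))
        return (out["image"],)
    raise ValueError(f"expected (img,) or (img, mask), got {len(x)} items")
```

Both `encodes` methods could call this and then wrap the outputs with `PILImage.create` / `PILMask.create` as before. I haven't checked whether `show_batch` is happy with the resulting one-element tuple.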

Questions

  1. Which augmentations are applied to the images at test/inference time: the training ones or the validation ones?
  2. Does the answer differ for test-time augmentation (TTA)?
  3. Any pointers on how to fix the error above?