Hello,

I am training a segmentation model on ultrasound medical images, and I am using this set of augmentations:
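```python
import albumentations as A

def simple_seg_aug():
    """Simple segmentation augmentations: random flips/rotations for training, resize only for validation."""
    train = A.Compose([
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.VerticalFlip(p=0.5),
    ])
    valid = A.Compose([A.Resize(256, 256)], p=1.)
    return train, valid
```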
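The reason `image` and `mask` are passed to the same albumentations call is that each random geometric choice (flip or not, how many 90° turns) has to be sampled once and applied to both, otherwise the mask no longer lines up with the image. The sketch below illustrates that idea with plain NumPy; `joint_flip_rot90` is a hypothetical, simplified stand-in for `A.HorizontalFlip` / `A.RandomRotate90` / `A.VerticalFlip`, not albumentations' actual implementation.

```python
import numpy as np

def joint_flip_rot90(image, mask, rng, p=0.5):
    """Apply the SAME random h-flip, 90-degree rotation and v-flip to image and mask.

    Simplified illustration only (hypothetical helper, not albumentations code).
    `image` is H x W x C, `mask` is H x W.
    """
    if rng.random() < p:  # horizontal flip: reverse the column axis of both
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < p:  # rotate both by the same random multiple of 90 degrees
        k = int(rng.integers(0, 4))
        image, mask = np.rot90(image, k), np.rot90(mask, k)  # rotates axes 0 and 1
    if rng.random() < p:  # vertical flip: reverse the row axis of both
        image, mask = image[::-1, :], mask[::-1, :]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

rng = np.random.default_rng(0)
img = np.arange(4 * 4 * 3).reshape(4, 4, 3)  # toy H x W x C image
msk = img[..., 0].copy()                      # toy mask built from channel 0
aug_img, aug_msk = joint_flip_rot90(img, msk, rng)
assert (aug_img[..., 0] == aug_msk).all()     # image and mask are still aligned
```

Because the mask is derived from channel 0, the final assertion holds for any random outcome: that is exactly the alignment guarantee a joint image/mask call provides.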
I apply them through these fastai transforms:
```python
#| export
import numpy as np
from fastai.vision.all import *  # ItemTransform, PILImage, PILMask

class TrainMaskTransform(ItemTransform):
    split_idx = 0  # intended to run on the training set only
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        # augment image and mask together so they stay aligned
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

class ValidMaskTransform(ItemTransform):
    split_idx = 1  # intended to run on the validation set only
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img, mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
```
(I'm not sure this is the best way to handle it, so any pointers are welcome.)

The idea is that `TrainMaskTransform` (with `split_idx = 0`) is applied only to the training set, and `ValidMaskTransform` (with `split_idx = 1`) only to the validation set.
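To make that intent concrete, here is a toy model of gating transforms on `split_idx`. The classes and function below (`GatedTransform`, `TrainOnly`, `ValidOnly`, `run_pipeline`) are hypothetical and not fastai code; they only mimic, as I understand it, the rule fastai transforms follow: a transform whose `split_idx` is set is skipped unless the pipeline runs with that same `split_idx`, while `split_idx = None` means "run on every split".

```python
class GatedTransform:
    """Hypothetical stand-in for a fastai Transform with a split_idx gate."""
    split_idx = None  # None: apply on every split

    def __init__(self, name):
        self.name = name

    def __call__(self, log, split_idx=None):
        # Skip when this transform is pinned to a different split.
        if self.split_idx is not None and split_idx != self.split_idx:
            return log
        return log + [self.name]

class TrainOnly(GatedTransform):
    split_idx = 0

class ValidOnly(GatedTransform):
    split_idx = 1

def run_pipeline(tfms, split_idx):
    """Run each transform in order, recording which ones actually fired."""
    log = []
    for t in tfms:
        log = t(log, split_idx=split_idx)
    return log

pipeline = [TrainOnly("train_aug"), ValidOnly("valid_aug")]
print(run_pipeline(pipeline, 0))  # ['train_aug']
print(run_pipeline(pipeline, 1))  # ['valid_aug']
```

Under this model, whichever `split_idx` a DataLoader runs with decides which of the two mask transforms fires, so a test DataLoader running with `split_idx = 1` would hit `ValidMaskTransform`.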
At inference time, I run this code:
```python
inference = Path('inference')  # I have some images here
fns_tst = get_image_files(inference)
dl = learn.dls.test_dl(fns_tst)
```
But when I call:

```python
dl.show_batch()
```
I get this error (trimmed to the last frame, which is the relevant one):

```
File ~/repos/work/auto-ob-torch/auto_ob_torch/dataloaders.py:176, in ValidMaskTransform.encodes(self, x)
    175 def encodes(self, x):
--> 176     img,mask = x
    177     aug = self.aug(image=np.array(img), mask=np.array(mask))
    178     return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

ValueError: not enough values to unpack (expected 2, got 1)
```
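I think I understand why: the new `test_dl` contains only images and no masks, so `img, mask = x` receives a single element and the unpacking fails. Note that the failing transform is `ValidMaskTransform`, so the test DataLoader appears to be applying the validation transforms.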
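One possible approach, sketched under the assumption that a test-time item is a length-1 sequence holding just the image (which is what "expected 2, got 1" suggests), is to let the shared `encodes` logic tolerate a missing mask. `apply_seg_aug` and `flip_lr` below are hypothetical helpers; `aug` is assumed to follow the albumentations calling convention used above (`image=` plus an optional `mask=`, returning a dict).

```python
import numpy as np

def _identity(a):
    return a

def apply_seg_aug(aug, x, wrap_image=_identity, wrap_mask=_identity):
    """Apply an albumentations-style `aug` to an (image,) or (image, mask) item.

    Hypothetical helper. Assumes `aug(image=..., mask=...)` returns a dict with
    the same keys, and that `mask=` may be omitted when there is no mask.
    """
    if len(x) == 1:
        # Test/inference item: only the image is present, so augment it alone.
        (img,) = x
        out = aug(image=np.array(img))
        return (wrap_image(out["image"]),)
    # Training/validation item: augment image and mask together so they stay aligned.
    img, mask = x
    out = aug(image=np.array(img), mask=np.array(mask))
    return wrap_image(out["image"]), wrap_mask(out["mask"])

def flip_lr(image, mask=None):
    """Hypothetical stand-in for an albumentations Compose: a horizontal flip."""
    out = {"image": image[:, ::-1]}
    if mask is not None:
        out["mask"] = mask[:, ::-1]
    return out

img = np.arange(6).reshape(2, 3)
print(apply_seg_aug(flip_lr, (img,))[0].tolist())  # [[2, 1, 0], [5, 4, 3]]
```

Inside the transforms above, `encodes` could then become `return apply_seg_aug(self.aug, x, PILImage.create, PILMask.create)`, so the same class handles items with and without masks. This is an untested sketch against fastai itself.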
Questions:
- Which set of augmentations is applied to images at test/inference time: the training ones or the validation ones?
- Does the answer differ for `test_time_augmentations`?
- Any pointers on how to solve the problem above?