Add_test passes valid transforms to DeviceDataLoader

Hi!

I was wondering why my test images were getting resized even though I was passing no transforms to add_test, and I noticed that it is implemented like this:

    def add_test(self, items:Iterator, label:Any=None, tfms=None, tfm_y=None)->None:
        "Add the `items` as a test set. Pass along `label` otherwise label them with `EmptyLabel`."
        self.label_list.add_test(items, label=label, tfms=tfms, tfm_y=tfm_y)
        vdl = self.valid_dl
        dl = DataLoader(self.label_list.test, vdl.batch_size, shuffle=False, drop_last=False, num_workers=vdl.num_workers)
        self.test_dl = DeviceDataLoader(dl, vdl.device, vdl.tfms, vdl.collate_fn)

The DeviceDataLoader(dl, vdl.device, vdl.tfms, vdl.collate_fn) call seems to explain my problem, so I was wondering whether it is intentional that it passes vdl.tfms instead of tfms. If not, I can make a PR to correct this (that will require some very hard work, as you can see).

vdl.tfms are the transforms attached to the validation dataloader (typically normalization), not the item-level transforms. There is no resize there, as the data is already batched and on the GPU by that point.
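
To illustrate the distinction, here is a minimal sketch in plain Python. It is not fastai's actual code: ToyDataset, ToyDeviceDataLoader, crop_to, and normalize are hypothetical names made up for this example, and "images" are just flat lists of floats. It only shows that shape-changing transforms run per item in the dataset, while the dataloader-level transforms (the role played by vdl.tfms) run on whole batches and leave the shape alone.

```python
# Toy model of the two transform stages. All names are hypothetical,
# this is not the fastai implementation.

def crop_to(n):
    """Item-level transform: keep the first n values (stand-in for a resize)."""
    return lambda item: item[:n]

def normalize(mean, std):
    """Batch-level transform: normalize every value in every item of the batch."""
    return lambda batch: [[(v - mean) / std for v in item] for item in batch]

class ToyDataset:
    def __init__(self, items, item_tfms):
        self.items, self.item_tfms = items, item_tfms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        x = self.items[i]
        for t in self.item_tfms:  # applied one item at a time
            x = t(x)
        return x

class ToyDeviceDataLoader:
    def __init__(self, dataset, batch_size, batch_tfms):
        self.dataset, self.batch_size, self.batch_tfms = dataset, batch_size, batch_tfms

    def __iter__(self):
        for start in range(0, len(self.dataset), self.batch_size):
            stop = min(start + self.batch_size, len(self.dataset))
            batch = [self.dataset[i] for i in range(start, stop)]
            for t in self.batch_tfms:  # applied to the already collated batch
                batch = t(batch)
            yield batch

items = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
ds = ToyDataset(items, item_tfms=[crop_to(2)])
loader = ToyDeviceDataLoader(ds, batch_size=2, batch_tfms=[normalize(0.5, 0.5)])
batches = list(loader)
```

In this sketch, the length of each item changes only inside ToyDataset; the loader-level normalize rescales values but never changes their count.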

Oh right, that makes sense! I’ll keep investigating as to why my test images are getting resized then.

Found it!

    def new(self, x, y, tfms=None, tfm_y=None, **kwargs)->'LabelList':
        tfms,tfm_y = ifnone(tfms, self.tfms),ifnone(tfm_y, self.tfm_y)
        if isinstance(x, ItemList):
            return self.__class__(x, y, tfms=tfms, tfm_y=tfm_y, **self.tfmargs)
        else:
            return self.new(self.x.new(x, **kwargs), self.y.new(y, **kwargs), tfms=tfms, tfm_y=tfm_y).process()

Since self.tfmargs is passed to the constructor, the size argument that I gave when creating the ItemList gets passed on to the test list. After that, even though the list of transforms is empty, apply_tfms still includes a resize step, so my test images get resized even when I don’t want them to be.
I’m not sure it would be beneficial to also pass custom tfmargs through add_test, as it is not very hard to patch when I want to (learner.data.label_list.test.tfmargs = None).
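
For anyone hitting the same thing, here is a small self-contained sketch of the mechanism. It is a toy model, not fastai's code: ToyImage and ToyLabelList are hypothetical classes that only imitate how new forwards tfmargs and how a size argument can trigger a resize even with an empty transform list.

```python
# Toy reproduction of the tfmargs propagation. All classes are hypothetical
# stand-ins, not the fastai implementation.

class ToyImage:
    def __init__(self, size):
        self.size = size  # (height, width)

    def apply_tfms(self, tfms, size=None):
        img = self
        for t in tfms:
            img = t(img)
        if size is not None:  # a resize happens whenever size is given
            img = ToyImage((size, size))
        return img

class ToyLabelList:
    def __init__(self, items, tfms=None, **tfmargs):
        self.items = items
        self.tfms = tfms or []
        self.tfmargs = tfmargs  # e.g. {"size": 224}

    def new(self, items, tfms=None):
        tfms = self.tfms if tfms is None else tfms
        # tfmargs are forwarded unchanged, as in the snippet above
        return self.__class__(items, tfms=tfms, **self.tfmargs)

    def __getitem__(self, i):
        x = self.items[i]
        if self.tfms or self.tfmargs:
            x = x.apply_tfms(self.tfms, **(self.tfmargs or {}))
        return x

train = ToyLabelList([ToyImage((500, 375))], tfms=[], size=224)
test = train.new([ToyImage((640, 480))], tfms=[])

resized = test[0].size    # size=224 was inherited from train
test.tfmargs = None       # same idea as the patch mentioned above
untouched = test[0].size  # original size is kept
```

In this toy version, clearing tfmargs on the test list is enough to stop the resize because the transform list is empty as well; whether that holds in a given fastai version is worth checking against its source.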
