Avoid duplicate execution of item_tfms with DataBlock for autoencoder

It seems the canonical way to create a DataBlock for an autoencoder is to use the same block twice, for example:

dblock = DataBlock(blocks=(ImageBlock, ImageBlock),

With this, IIUC, item_tfms will be executed twice, once per block. I couldn’t find a feasible way to share pre-processing between the blocks. The only option I see in the DataBlock API is to do the shared pre-processing in the get_items function, but that would require loading all the data into memory, which is infeasible for large datasets. I could also write my own blocks that share cached data behind the scenes, but that seems clunky.
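A minimal, framework-free sketch of the concern, not fastai code. It assumes a hypothetical `expensive_preprocess` function standing in for loading an item and running item_tfms on it. Producing each element of the (input, target) pair independently runs the step twice, while processing once and duplicating the result runs it once and yields an identical pair.

```python
# Framework-free model; `expensive_preprocess` is a hypothetical stand-in
# for loading an image and running item_tfms on it.
calls = {"count": 0}

def expensive_preprocess(item):
    calls["count"] += 1          # record every invocation
    return f"processed({item})"

def two_blocks(item):
    # Same block used twice: each element of the (x, y) pair is
    # produced and transformed independently.
    return (expensive_preprocess(item), expensive_preprocess(item))

def process_then_duplicate(item):
    # Shared pre-processing: transform once, then reuse the result
    # as both input and target.
    x = expensive_preprocess(item)
    return (x, x)

calls["count"] = 0
pair_a = two_blocks("img_001.png")
count_a = calls["count"]

calls["count"] = 0
pair_b = process_then_duplicate("img_001.png")
count_b = calls["count"]

print(count_a, count_b)          # 2 1
print(pair_a == pair_b)          # True
```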

I concluded that to share pre-processing, I’ll have to use the mid-level API (TfmdLists and friends). Is this correct, or am I missing something in the DataBlock API?

You can manage this with a transform, something like so:

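from fastai.vision.all import *

class ReadAE(ItemTransform):
    # A high order sorts this after the other batch transforms,
    # so the batch is processed once and only then duplicated
    order = 100
    def encodes(self, x):
        return (x, x)  # same data as both input and target

# A single ImageBlock; the remaining DataBlock arguments are not shown
pets = DataBlock(blocks=(ImageBlock,),
                 batch_tfms=batch_tfms + [ReadAE()],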

The following test will then pass, showing that the batch contains the same data twice:

dls = pets.dataloaders(path_im, bs=bs)
batch = next(iter(dls[0]))   # first batch from the training DataLoader
test_eq(batch[0], batch[1])  # input and target are identical

Giving it a very high order also ensures it is the last transform applied, so the actual processing only happens once.
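fastcore’s Pipeline sorts its transforms by their `order` attribute before applying them, which is why a large `order` pushes ReadAE to the end. Below is a simplified pure-Python model of that idea, not fastai’s actual Pipeline; the `Step` class, the step names, and the order values are hypothetical and chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass
class Step:
    # Hypothetical stand-in for a transform: a name, a function, and an order.
    name: str
    fn: Callable[[Any], Any]
    order: int = 0

def run_pipeline(steps: List[Step], x: Any, log: List[str]) -> Any:
    # Apply steps sorted by `order`; sorted() is stable, so ties keep
    # their original relative position.
    for step in sorted(steps, key=lambda s: s.order):
        log.append(step.name)
        x = step.fn(x)
    return x

log: List[str] = []
steps = [
    Step("duplicate", lambda x: (x, x), order=100),  # like ReadAE, listed first
    Step("resize", lambda x: f"resized({x})", order=1),
    Step("normalize", lambda x: f"normalized({x})", order=99),
]
result = run_pipeline(steps, "img", log)
print(log)     # ['resize', 'normalize', 'duplicate']
print(result)  # ('normalized(resized(img))', 'normalized(resized(img))')
```

Even though the duplicating step is listed first, it runs last, so the other steps see a single item and run once.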


This works indeed. Thanks!

I had actually tried something analogous to the following before:

def read_ae(x):
    return (x, x)

pets = DataBlock(blocks=(ImageBlock,),
                 batch_tfms=batch_tfms + [read_ae],

This resulted in one_batch() returning [(x, x)] instead of [x, x], as your version does. I need to read up on ItemTransform… Any hints on the best place to start?

I tried that as well when I was playing with it. I’m pretty sure the key is the order here.

As for transforms, not really. I should probably write an article at some point, but the thing that helps me most is knowing that a transform is not the same as data augmentation: it just applies some function to an input or output.

I have a short article at https://muellerzr.github.io/fastblog/datablock/2020/03/22/TransformFunctions.html that shows what you can make transforms do, but I don’t think it covers what you’re after.

For future reference: I think the key here is not the order but the use of ItemTransform as opposed to a normal Transform. (A plain function passed into batch_tfms ends up in a Pipeline, which wraps it in a normal Transform.)
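A plain Transform maps a function over the elements of a tuple, recursing into nested tuples. A minimal pure-Python model of that behavior (the `per_element` helper is hypothetical, not a fastai function) shows how read_ae turns a one-element batch tuple into a tuple holding a pair. This assumes the batch of a single-block DataBlock arrives as a one-element tuple `(xb,)`, which would explain the `[(x, x)]` seen with the plain function.

```python
from typing import Any, Callable

def per_element(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Hypothetical model of a plain Transform: tuples are mapped
    # element-wise (recursively), anything else is passed to f directly.
    def wrapped(x: Any) -> Any:
        if isinstance(x, tuple):
            return tuple(wrapped(x_) for x_ in x)
        return f(x)
    return wrapped

def read_ae(x):
    return (x, x)

batch = ("xb",)                   # assumed one-element batch tuple
out = per_element(read_ae)(batch)
print(out)                        # (('xb', 'xb'),)
```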

The documentation for ItemTransform says:

ItemTransform is the class to use to opt out of the default behavior of Transform.

The standard Transform calls encodes on each tuple element separately:

res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)

In contrast, ItemTransform calls encodes once, with all elements passed as a list:

y = getattr(super(), name)(list(x), **kwargs)
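To make the contrast concrete, here is a simplified pure-Python model of the two call styles. The helpers `as_transform` and `as_item_transform` are hypothetical and only mimic the two lines quoted above; they are not fastai’s code. A recording function shows how many times it is called and what it receives.

```python
from typing import Any, Callable, List

def as_transform(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Models Transform: f is called once per tuple element.
    def call(x: Any) -> Any:
        if isinstance(x, tuple):
            return tuple(call(x_) for x_ in x)
        return f(x)
    return call

def as_item_transform(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Models ItemTransform: f is called once with the whole tuple as a list;
    # a list result is converted back to a tuple.
    def call(x: Any) -> Any:
        if isinstance(x, tuple):
            y = f(list(x))
            return tuple(y) if isinstance(y, list) else y
        return f(x)
    return call

seen: List[Any] = []

def record(x: Any) -> Any:
    seen.append(x)               # remember each argument f receives
    return x

pair = ("image", "label")
as_transform(record)(pair)
per_element_args = list(seen)    # ['image', 'label']: two calls

seen.clear()
as_item_transform(record)(pair)
whole_item_args = list(seen)     # [['image', 'label']]: one call
```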