Bug in custom Transform for a self-supervised model

I am trying to create a self-supervised model where the target is generated from the input. The model does not use text data, but its structure is similar to BERT: at every batch I perform some form of random masking and then generate the target. To do this, I’m writing a custom Transform class that takes a single tensor as input and outputs two tensors, and I pass it as after_batch when creating the DataLoaders. I do it this way because I know the _split function in the Learner class needs a tuple of size (at least) 2 to generate my xb and yb. The problem is that the output gets wrapped inside another tuple, turning it into a tuple of size 1, not 2. Any help on this is massively appreciated!
Here is a simplified example that reproduces the problem:

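from fastai.data.all import *  # DisplayedTransform, store_attr, Tensor, DataLoaders

class customtransform(DisplayedTransform):
    def __init__(self,dummy=0): store_attr()
    def encodes(self,o:Tensor):
        # Placeholder for the masking logic: input and target are both the raw tensor
        xb = o
        yb = o
        return (xb,yb)

# train_ds and valid_ds are defined elsewhere; each item holds a single tensor
dls = DataLoaders.from_dsets(train_ds, valid_ds,after_batch=[customtransform()],bs=4)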

Checking len(b) and len(b[0]) for a batch b from dls gives:
(1, 2)

So b = ((x,x),) is a length-1 tuple whose only item is itself a tuple. That inner tuple has length 2 and holds the xb and yb that I want to send to the Learner class. However, because of the extra top-level tuple, the _split method in Learner cannot separate the batch into xb and yb.
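One plausible explanation, sketched in plain Python rather than with the real fastai classes: if a transform is applied to each element of the batch tuple separately, then a dataset with one tensor yields a batch (x,), and mapping encodes over it gives ((x, x),). The helper apply_elementwise below is hypothetical and only models that assumed behavior.

```python
# Simplified model (an assumption, not fastcore's actual code): a transform
# applied to a tuple is mapped over the tuple's elements one by one.
def apply_elementwise(fn, batch):
    if isinstance(batch, tuple):
        return tuple(fn(item) for item in batch)
    return fn(batch)

def encodes(o):
    # Same shape of logic as customtransform.encodes: one input, two outputs.
    return (o, o)

x = [1.0, 2.0, 3.0]                    # stand-in for a batch tensor
b = apply_elementwise(encodes, (x,))   # batch from a one-tensor dataset
print(len(b), len(b[0]))               # prints: 1 2
```

If this is what happens, a transform that receives the whole batch tuple at once (fastcore's ItemTransform is documented as taking tuples as items) might avoid the extra wrapping, though that is not verified here.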
I tried digging into how the after_batch parameter is used in DataLoaders, hoping to find a fix, but the code behind after_batch is more than my beginner knowledge of Python can handle. Can someone please help me with this problem?

I also want to call out that I’m happy to self-learn and try to solve the problem myself. If someone can show me how after_batch in the dataloaders is actually implemented in the code, that would be really helpful. I know that after_batch is a Pipeline object under the dataloader, and I can see that the __iter__ method of the DataLoader class has a yield that calls after_batch, but when I read the Pipeline code I can’t understand when and how my customtransform actually gets executed. If someone can just give me a pointer, that would be super helpful.
In particular, I am trying to figure out why I get len(b) = 2 when I create a TensorDataset with two tensors (and no custom transform), but len(b) = 1 in the example in the main post.
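As a rough mental model only, the flow described above (collate a chunk of items, run the after_batch pipeline on the result, yield it) can be sketched in plain Python. Everything here (compose, iter_batches, record) is a hypothetical stand-in, not the fastai DataLoader or Pipeline code.

```python
# Hypothetical, simplified model of a batch loop; not the fastai implementation.
def compose(*fns):
    # Run each function in order, feeding the output of one into the next.
    def run(batch):
        for fn in fns:
            batch = fn(batch)
        return batch
    return run

def iter_batches(items, bs, after_batch):
    for start in range(0, len(items), bs):
        chunk = items[start:start + bs]
        # Collate: one list per field of the items, gathered into a tuple.
        batch = tuple(list(field) for field in zip(*chunk))
        # The pipeline runs once per batch, right before the batch is yielded.
        yield after_batch(batch)

calls = []
def record(batch):
    # Stand-in transform that logs the batch size and passes the batch through.
    calls.append(len(batch[0]))
    return batch

items = [(i,) for i in range(10)]      # ten items, one field each
batches = list(iter_batches(items, bs=4, after_batch=compose(record)))
print(calls)                           # prints: [4, 4, 2]
```

In this model the transform sees the already-collated batch tuple, which is why the shape of that tuple matters for what the transform returns.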
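A likely reason for the difference, sketched without fastai: a PyTorch TensorDataset returns one tuple per item with one entry per tensor it holds, so the collated batch has as many entries as the dataset has tensors. The collate helper below is a hypothetical simplification that uses lists in place of tensors.

```python
# Hypothetical collate: gathers each field of the items into its own list.
def collate(items):
    return tuple(list(field) for field in zip(*items))

one_field = [(i,) for i in range(4)]            # like a dataset built from one tensor
two_fields = [(i, i * 10) for i in range(4)]    # like a dataset built from two tensors

b1 = collate(one_field)
b2 = collate(two_fields)
print(len(b1), len(b2))                         # prints: 1 2

# Mapping a pair-returning function over b1 keeps the outer length at 1.
wrapped = tuple((field, field) for field in b1)
print(len(wrapped), len(wrapped[0]))            # prints: 1 2
```

Under this model the batch length is fixed by the dataset, and a transform applied per element can change what each entry contains but not how many entries there are.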

I have a solution, which is to write a new subclass of Learner that overrides the _split method. I’m still interested to hear more about how this mysterious Pipeline class works, specifically with after_batch in the DataLoader.
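The subclass itself is not shown in this post. As a rough illustration of the idea only, the unwrapping step could look like the plain-Python function below; split_batch and n_inp are illustrative names, and this is not fastai's _split.

```python
# Hypothetical sketch: unwrap a singleton outer tuple, then split inputs/targets.
def split_batch(b, n_inp=1):
    # Turn ((x, y),) into (x, y) before splitting.
    if len(b) == 1 and isinstance(b[0], tuple):
        b = b[0]
    # The first n_inp entries are inputs, the rest are targets.
    return b[:n_inp], b[n_inp:]

xb, yb = split_batch((("x", "y"),))
print(xb, yb)    # prints: ('x',) ('y',)
```

An already flat batch such as ("x", "y") passes through the same function unchanged, so the unwrapping only affects the nested case.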