Metadata of a Tensor Subclass not available in `show_batch`

I’m having trouble accessing the metadata of a Tensor subclass inside `show_batch()`:

class T(TensorBase):
    "A `TensorBase` subclass whose instances are expected to expose a `metadata` attribute."

    @classmethod
    def create(cls, data, metadata):
        "Build a `T` from `data`, passing `metadata` through to the constructor as a keyword."
        # `@classmethod` is required: the example calls `T.create([1,2,3], metadata=1)`
        # on the class itself; as a plain function `[1,2,3]` would be bound to `cls`
        # and `data` would be missing.
        return cls(data, metadata=metadata)

@typedispatch
def show_batch(x: T, y, samples, ctxs=None, figsize=None, **kwargs):
    "Print the `metadata` of the first input of the first decoded sample."
    # fastai selects a `show_batch` implementation by the type of `x`, so the function
    # must be registered with `@typedispatch`; without it `dls.show_batch()` never calls it.
    # NOTE(review): `typedispatch` is assumed to come from the same fastai star import
    # that provides `TensorBase`/`DataLoaders`; confirm.
    first_input = samples[0][0]
    print("Metadata:", first_input.metadata)

# Build a `T` carrying metadata and check that the attribute is set.
t = T.create([1, 2, 3], metadata=1)
assert t.metadata == 1

# Identical two-item train/valid datasets, batched in pairs.
dls = DataLoaders.from_dsets([t, t], [t, t], bs=2)
# Invoke the custom `show_batch`; with the unpatched `batch_to_samples` this is where
# the AttributeError reported below appears.
dls.show_batch()

This raises `AttributeError: 'T' object has no attribute 'metadata'` inside my `show_batch`.

Before torch 1.9.0, `list(b)` returned a list of plain `Tensor`s rather than `T`s. Now that the items are already `T`s, `retain_types` behaves differently and no longer copies the metadata onto them. If I patch `batch_to_samples` to use `[b[i] for i in range(...)]` instead of `list(b)`, it seems to work:

def batch_to_samples(b, max_n=10):
    "'Transposes' a batch to (at most `max_n`) samples"
    if isinstance(b, Tensor):
        # Index item by item instead of calling `list(b)`: since torch 1.9.0, `list(b)`
        # already yields `T`s, and `retain_types` then skips copying their metadata.
        return retain_types([b[i] for i in range(min(max_n, len(b)))], [b])
    else:
        # `b` is a collection of batch components: split each component into samples,
        # then zip them so that each result groups one sample's components together.
        res = L(b).map(partial(batch_to_samples, max_n=max_n))
        return retain_types(res.zip(), [b])

@patch
def _decode_batch(self: TfmdDL, b, max_n=9, full=True):
    "Decode batch `b` into at most `max_n` samples using the local `batch_to_samples`."
    # A free function taking `self` is never called by fastai. `@patch` with
    # `self: TfmdDL` installs it as `TfmdDL._decode_batch`, so this module's patched
    # `batch_to_samples` is the one actually used.
    # NOTE(review): `patch`/`TfmdDL` are assumed to come from the same fastai star
    # import as `DataLoaders`; confirm.
    decode_before_batch = self.before_batch.decode
    decode_item = self.after_item.decode
    decode_dataset = partial(getattr(self.dataset, 'decode', noop), full=full)
    # Same decoder order as before: before_batch, then after_item, then the dataset's decode.
    decode = compose(decode_before_batch, decode_item, decode_dataset)
    return L(batch_to_samples(b, max_n=max_n)).map(decode)
# Same DataLoaders as before; with the patched `batch_to_samples` in place,
# `show_batch` should now find the metadata.
dls = DataLoaders.from_dsets([t, t], [t, t], bs=2)
dls.show_batch()


Output: `Metadata: 1`