I’m having trouble accessing metadata of a Tensor subclass in show_batch():
class T(TensorBase):
    "Tensor subclass that carries a `metadata` attribute alongside its data."

    @classmethod
    def create(cls, data, metadata):
        "Alternate constructor: wrap `data` in a `T` and attach `metadata`."
        inst = cls(data, metadata=metadata)
        return inst
@typedispatch
def show_batch(x: T, y, samples, ctxs=None, figsize=None, **kwargs):
    "Type-dispatched `show_batch` for `T` batches: print the first sample's metadata."
    first_sample = samples[0]
    print("Metadata:", first_sample[0].metadata)
# Build a `T` and confirm the metadata attribute survives construction.
t = T.create([1,2,3], metadata=1)
assert t.metadata == 1
# Two-item train/valid datasets; bs=2 collates both copies of `t` into one batch.
dls = DataLoaders.from_dsets([t, t], [t, t], bs=2)
dls.show_batch()
This gives an AttributeError: 'T' object has no attribute 'metadata' inside my show_batch definition.
Prior to torch 1.9.0, list(b) would return a list of Tensors, not Ts. That’s why retain_types behaves differently and does not copy metadata. If I patch batch_to_samples so that it does not use list(b) but [b[i] for i in range(...)], it seems to work:
def batch_to_samples(b, max_n=10):
    "'Transposes' a batch to (at most `max_n`) samples"
    if not isinstance(b, Tensor):
        # Recurse into each element of the collection, then zip back into samples.
        parts = L(b).map(partial(batch_to_samples, max_n=max_n))
        return retain_types(parts.zip(), [b])
    # Index element-wise instead of `list(b)` so the Tensor subclass (and its
    # attached attributes) is preserved, then restore types from the batch.
    n_samples = min(max_n, len(b))
    return retain_types([b[i] for i in range(n_samples)], [b])
@patch_to(TfmdDL)
def _decode_batch(self, b, max_n=9, full=True):
    "Split batch `b` into (at most `max_n`) samples and decode each one."
    dec_dataset = partial(getattr(self.dataset, 'decode', noop), full=full)
    # Same decode chain as before: before_batch, after_item, then the dataset.
    dec = compose(self.before_batch.decode, self.after_item.decode, dec_dataset)
    samples = batch_to_samples(b, max_n=max_n)
    return L(samples).map(dec)
# Rebuild the loaders and show a batch; with the patch in place the subclass
# (and hence `metadata`) survives batch_to_samples.
dls = DataLoaders.from_dsets([t, t], [t, t], bs=2)
dls.show_batch()
Output:
Metadata: 1