Error with learn.tta when passing TfmdDL

Hi everyone,

I have created a TfmdLists object with a custom Transform, built a DataLoaders object from it, and trained a learner on those DataLoaders. show_batch, training, and get_preds all work as expected. What I want to do now is perform test-time augmentation. When I run

learn.tta(dl=dls.valid)

I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[43], line 1
----> 1 learn.tta(dl=dls.valid)

File /opt/conda/lib/python3.10/site-packages/fastai/learner.py:659, in tta(self, ds_idx, dl, n, item_tfms, batch_tfms, beta, use_max)
    657 try:
    658     self(_before_epoch)
--> 659     with dl.dataset.set_split_idx(0), self.no_mbar():
    660         if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
    661         aug_preds = []

File /opt/conda/lib/python3.10/site-packages/fastcore/basics.py:496, in GetAttr.__getattr__(self, k)
    494 if self._component_attr_filter(k):
    495     attr = getattr(self,self._default,None)
--> 496     if attr is not None: return getattr(attr,k)
    497 raise AttributeError(k)

File /opt/conda/lib/python3.10/site-packages/fastcore/transform.py:212, in Pipeline.__getattr__(self, k)
--> 212 def __getattr__(self,k): return gather_attrs(self, k, 'fs')

File /opt/conda/lib/python3.10/site-packages/fastcore/transform.py:173, in gather_attrs(o, k, nm)
    171 att = getattr(o,nm)
    172 res = [t for t in att.attrgot(k) if t is not None]
--> 173 if not res: raise AttributeError(k)
    174 return res[0] if len(res)==1 else L(res)

AttributeError: set_split_idx
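
My reading of the trace is that the attribute lookup is forwarded from the dataset to its Pipeline, then to the individual transforms, and fails because none of them defines set_split_idx. Below is a minimal sketch of that chain using stand-in classes only. FakeTfmdLists, FakePipeline, and Double are hypothetical names I made up for illustration, not fastai or fastcore code, and the real implementations may differ in detail.

```python
# Stand-in classes only; none of this is fastai/fastcore code.
class FakePipeline:
    """Mimics a pipeline that looks attributes up on its transforms (`fs`)."""
    def __init__(self, fs):
        self.fs = fs

    def __getattr__(self, k):
        # Only called when normal lookup fails; guard 'fs' to avoid recursion.
        if k == "fs":
            raise AttributeError(k)
        found = [getattr(f, k) for f in self.fs if hasattr(f, k)]
        if not found:
            raise AttributeError(k)
        return found[0] if len(found) == 1 else found


class FakeTfmdLists:
    """Mimics a container that forwards unknown attributes to `self.tfms`."""
    def __init__(self, items, tfms):
        self.items = items
        self.tfms = tfms

    def __getattr__(self, k):
        if k == "tfms":
            raise AttributeError(k)
        return getattr(self.tfms, k)


class Double:
    """Hypothetical transform with no `set_split_idx` attribute."""
    def __call__(self, x):
        return 2 * x


tl = FakeTfmdLists([1, 2, 3], FakePipeline([Double()]))
try:
    tl.set_split_idx(0)
except AttributeError as e:
    error_name = str(e)
    print("AttributeError:", error_name)
```

The stand-in fails the same way as the real call: the missing name bubbles up from the innermost lookup, so the message only contains the attribute name.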

Link to a minimal reproducible example: TTA error with TfmdDL | Kaggle

Looking at the stack trace, I can see that tta calls

dl.dataset.set_split_idx(0)

In my case, dl is a TfmdDL and dl.dataset returns a TfmdLists object. TfmdLists doesn't have a set_split_idx attribute, so the lookup falls through to its Pipeline, which doesn't have one either. Is there any way to do TTA with a TfmdDL built on a TfmdLists? Any help or advice is much appreciated. Thank you!
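
In case it helps frame the question, here is a rough sketch of the behavior I think tta expects from set_split_idx: a context manager that temporarily switches the dataset to split 0 (the training split, so training-time augmentations apply) and restores the old value afterwards. It uses stand-in classes so it runs without fastai. FakeTfmdLists and FakePipeline are hypothetical, and I am assuming the real objects expose a writable split_idx in the same way. I have not verified that attaching something like this to the real TfmdLists is safe.

```python
from contextlib import contextmanager


class FakePipeline:
    """Stand-in for the transform pipeline; only tracks `split_idx`."""
    def __init__(self, split_idx=None):
        self.split_idx = split_idx


class FakeTfmdLists:
    """Stand-in for a TfmdLists-like object holding a pipeline in `tfms`."""
    def __init__(self, items, split_idx=None):
        self.items = items
        self.split_idx = split_idx
        self.tfms = FakePipeline(split_idx)


@contextmanager
def set_split_idx(tl, i):
    """Temporarily switch `tl` (and its pipeline) to split `i`, then restore."""
    old_tl, old_tfms = tl.split_idx, tl.tfms.split_idx
    tl.split_idx = tl.tfms.split_idx = i
    try:
        yield tl
    finally:
        # Restore even if the body raises, so the dataset is left unchanged.
        tl.split_idx, tl.tfms.split_idx = old_tl, old_tfms


valid = FakeTfmdLists([1, 2, 3], split_idx=1)  # 1 = validation split
with set_split_idx(valid, 0):
    inside = valid.tfms.split_idx  # train-time transforms would be active here
after = valid.tfms.split_idx
print(inside, after)
```

If this is roughly right, the open question is whether the same idea can be applied to a real TfmdLists, or whether building the data with Datasets instead is the intended route.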