How to get PackedSequence to work with the mid-level data API?

I have a pandas dataframe with columns ‘sequence’ and ‘label’. The ith element of the ‘sequence’ column contains a list of lists of floats of shape (li, n), where li varies from row to row. When I try to create a DataLoaders from this dataframe, I run into errors with type retention. A minimal reproducible example:

from fastai.text.all import *
from torch.nn.utils.rnn import pack_sequence

df = pd.DataFrame({
    'sequence': [[[0.1, 0.4], [0.5, 0.8]],
                 [[0.9, 0.2]],
                 [[0.3, 0.7], [0.6, 0.1], [0.4, 0.5]],
                 [[1.0, 2.9], [1.0, 2.3], [0.2, 0.2]],
                 [[1.1, 1.2], [2.4, 2.5]]],
    'label': ['class1', 'class2', 'class1', 'class3', 'class2']
})
tfms = [[ColReader('sequence')], [ColReader('label'), Categorize()]]

def pack_input(batch):
    # batch is a list of (sequence, label) tuples coming from the Datasets
    seqs, labels = zip(*batch)
    seqs = [torch.tensor(seq) for seq in seqs]
    seqs = pack_sequence(seqs, enforce_sorted=False)
    return seqs, torch.stack(labels)

dsets = Datasets(df, tfms=tfms, splits=RandomSplitter()(df))
dls = dsets.dataloaders(bs=2, create_batch=pack_input)
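
For reference, as far as I can tell the items that reach pack_input come straight out of dsets as (sequence, TensorCategory) tuples, which is why I zip and stack inside the collate function. Inspecting an item works fine; the failure only appears when building the dataloaders:

# inspection only, this runs without error
print(dsets.train[0])
# e.g. ([[0.1, 0.4], [0.5, 0.8]], TensorCategory(0))  (exact item and label index depend on the random split)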

This is the full stack trace:

Could not do one pass in your dataloader, there is something wrong in it. Please see the stack trace below:
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[4], line 22
     19     return seqs, torch.stack(labels)
     21 dsets = Datasets(df, tfms=tfms, splits=RandomSplitter()(df))
---> 22 dls = dsets.dataloaders(bs=2, create_batch=pack_input)

File ~/miniforge3/lib/python3.12/site-packages/fastai/data/core.py:333, in FilteredBase.dataloaders(self, bs, shuffle_train, shuffle, val_shuffle, n, path, dl_type, dl_kwargs, device, drop_last, val_bs, **kwargs)
    331 dl = dl_type(self.subset(0), **merge(kwargs,def_kwargs, dl_kwargs[0]))
    332 def_kwargs = {'bs':bs if val_bs is None else val_bs,'shuffle':val_shuffle,'n':None,'drop_last':False}
--> 333 dls = [dl] + [dl.new(self.subset(i), **merge(kwargs,def_kwargs,val_kwargs,dl_kwargs[i]))
    334               for i in range(1, self.n_subsets)]
    335 return self._dbunch_type(*dls, path=path, device=device)

File ~/miniforge3/lib/python3.12/site-packages/fastai/data/core.py:104, in TfmdDL.new(self, dataset, cls, **kwargs)
    102 if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):
    103     try:
--> 104         self._one_pass()
    105         res._n_inp,res._types = self._n_inp,self._types
    106     except Exception as e: 

File ~/miniforge3/lib/python3.12/site-packages/fastai/data/core.py:85, in TfmdDL._one_pass(self)
     84 def _one_pass(self):
---> 85     b = self.do_batch([self.do_item(None)])
     86     if self.device is not None: b = to_device(b, self.device)
     87     its = self.after_batch(b)

File ~/miniforge3/lib/python3.12/site-packages/fastai/data/load.py:185, in DataLoader.do_batch(self, b)
--> 185 def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)

File ~/miniforge3/lib/python3.12/site-packages/fastai/data/load.py:175, in DataLoader.retain(self, res, b)
--> 175 def retain(self, res, b):  return retain_types(res, b[0] if is_listy(b) else b)

File ~/miniforge3/lib/python3.12/site-packages/fasttransform/cast.py:69, in retain_types(new, old, typs)
     67     else: t,typs = typs,None
     68 else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
---> 69 return t(L(new, old, typs).map_zip(retain_types, cycled=True))

File ~/miniforge3/lib/python3.12/site-packages/fastcore/foundation.py:191, in L.map_zip(self, f, cycled, *args, **kwargs)
--> 191 def map_zip(self, f, *args, cycled=False, **kwargs): return self.zip(cycled=cycled).starmap(f, *args, **kwargs)

File ~/miniforge3/lib/python3.12/site-packages/fastcore/foundation.py:188, in L.starmap(self, f, *args, **kwargs)
--> 188 def starmap(self, f, *args, **kwargs): return self._new(itertools.starmap(partial(f,*args,**kwargs), self))

File ~/miniforge3/lib/python3.12/site-packages/fastcore/foundation.py:113, in L._new(self, items, *args, **kwargs)
--> 113 def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)

File ~/miniforge3/lib/python3.12/site-packages/fastcore/foundation.py:100, in _L_Meta.__call__(cls, x, *args, **kwargs)
     98 def __call__(cls, x=None, *args, **kwargs):
     99     if not args and not kwargs and x is not None and isinstance(x,cls): return x
--> 100     return super().__call__(x, *args, **kwargs)

File ~/miniforge3/lib/python3.12/site-packages/fastcore/foundation.py:108, in L.__init__(self, items, use_list, match, *rest)
    106 def __init__(self, items=None, *rest, use_list=False, match=None):
    107     if (use_list is not None) or not is_array(items):
--> 108         items = listify(items, *rest, use_list=use_list, match=match)
    109     super().__init__(items)

File ~/miniforge3/lib/python3.12/site-packages/fastcore/basics.py:79, in listify(o, use_list, match, *rest)
     77 elif isinstance(o, list): res = o
     78 elif isinstance(o, str) or isinstance(o, bytes) or is_array(o): res = [o]
---> 79 elif is_iter(o): res = list(o)
     80 else: res = [o]
     81 if match is not None:

File ~/miniforge3/lib/python3.12/site-packages/fasttransform/cast.py:69, in retain_types(new, old, typs)
     67     else: t,typs = typs,None
     68 else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
---> 69 return t(L(new, old, typs).map_zip(retain_types, cycled=True))

File ~/miniforge3/lib/python3.12/site-packages/torch/nn/utils/rnn.py:67, in PackedSequence.__new__(cls, data, batch_sizes, sorted_indices, unsorted_indices)
     64 def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
     65     return super().__new__(
     66         cls,
---> 67         *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
     68                                     unsorted_indices))

File ~/miniforge3/lib/python3.12/site-packages/torch/nn/utils/rnn.py:183, in _packed_sequence_init_args(data, batch_sizes, sorted_indices, unsorted_indices)
    179     return data, batch_sizes, sorted_indices, unsorted_indices
    181 # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
    182 else:
--> 183     assert isinstance(data, (list, tuple)) and len(data) == 2
    184     return data[0], data[1], sorted_indices, unsorted_indices

AssertionError: 

I believe the failure stems from fastai trying to do type retention on the PackedSequence that pack_input returns.
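
To illustrate what I think is happening (this is just my reading of the trace, not verified against the fastai source): PackedSequence subclasses a NamedTuple, so retain_types treats it as a collection and tries to rebuild it element-wise as PackedSequence(L(...)). PackedSequence's constructor only accepts tensors for (data, batch_sizes, ...) or a 2-element list/tuple, so the assertion fires:

from torch.nn.utils.rnn import PackedSequence
from fastcore.foundation import L

# retain_types appears to end up calling PackedSequence(L(...)) on the four
# zipped fields; an L is neither a data/batch_sizes pair of tensors nor a
# 2-element list/tuple, so _packed_sequence_init_args hits its assert
PackedSequence(L([None, None, None, None]))  # AssertionError, same as in the trace

Can somebody offer any insight?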


Here is the dataframe as a table for faster comprehension:

   sequence                               label
0  [[0.1, 0.4], [0.5, 0.8]]               class1
1  [[0.9, 0.2]]                           class2
2  [[0.3, 0.7], [0.6, 0.1], [0.4, 0.5]]   class1
3  [[1.0, 2.9], [1.0, 2.3], [0.2, 0.2]]   class3
4  [[1.1, 1.2], [2.4, 2.5]]               class2

Can someone please at least clarify whether this is the correct way to pass a collate function to the dataloaders?
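
In case it clarifies what I am after: the only fallback I can think of is to pad in the collate function and pack inside the model instead, roughly like this (untested sketch, pad_input is my own name):

from torch.nn.utils.rnn import pad_sequence

def pad_input(batch):
    seqs, labels = zip(*batch)
    seqs = [torch.tensor(seq, dtype=torch.float) for seq in seqs]
    lengths = torch.tensor([len(s) for s in seqs])
    # hand the dataloader plain tensors and only call
    # pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    # inside the model's forward, so retain_types never sees a PackedSequence
    padded = pad_sequence(seqs, batch_first=True)
    return (padded, lengths), torch.stack(labels)

But I would prefer to keep the PackedSequence in the batch if there is a supported way to do it, hence the question.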