Weighted sampler in Fastai2

Hi guys,
In fastai_v1 we could change the sampler using something like this:

sampler = WeightedRandomSampler(weights=weights,num_samples=len(weights))

data_cls.train_dl = data_cls.train_dl.new(shuffle=False,sampler=sampler)

Is it possible to do the same in fastai_v2? The docs describe a WeightedDL, but I keep getting an error when using it:

dls = dsets.weighted_dataloaders(wgts=weights,bs=64,seq_len=30)
x, y = first(dls.train)

RuntimeError                              Traceback (most recent call last)
<ipython-input-16-2a77872a224f> in <module>
----> 1 x,y = first(dls.train)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/utils.py in first(x)
    220 def first(x):
    221     "First element of `x`, or None if missing"
--> 222     try: return next(iter(x))
    223     except StopIteration: return None
    224 

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py in __iter__(self)
    100         self.before_iter()
    101         self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 102         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
    103             if self.device is not None: b = to_device(b, self.device)
    104             yield self.after_batch(b)

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    987             else:
    988                 del self._task_info[idx]
--> 989                 return self._process_data(data)
    990 
    991     def _try_put_index(self):

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1012         self._try_put_index()
   1013         if isinstance(data, ExceptionWrapper):
-> 1014             data.reraise()
   1015         return data
   1016 

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    393             # (https://bugs.python.org/issue2651), so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
    data = next(self.dataset_iter)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 111, in create_batches
    yield from map(self.do_batch, self.chunkify(res))
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 132, in do_batch
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 131, in create_batch
    def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 48, in fa_collate
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 48, in <listcomp>
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/data/load.py", line 47, in fa_collate
    return (default_collate(t) if isinstance(b, _collate_types)
  File "/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [41] at entry 0 and [34] at entry 1
2 Likes