Hello!
I am having a hard time trying to get weighted_dataloaders to work. @sgugger, @muellerzr any ideas? Here is the code:
from operator import attrgetter
from fastai2.vision.all import *

# item transforms: open the image from `filename`, categorize the `target` column
tfms = [[attrgetter("filename"), PILImage.create],
        [attrgetter("target"), Categorize()]]

# split train/valid on the `is_valid` column
splits = ColSplitter('is_valid')(df)
dsets = Datasets(df, tfms, splits=splits)

# one weight per row of df, read from the `wgt` column
dls = dsets.weighted_dataloaders(wgts=df['wgt'].tolist(), bs=8, source=df,
                                 after_item=[Resize(528, method='squish'), ToTensor()],
                                 after_batch=[IntToFloatTensor(),
                                              *aug_transforms(size=528,
                                                              do_flip=True,
                                                              max_rotate=15,
                                                              max_zoom=1.1,
                                                              max_lighting=0.3,
                                                              max_warp=0.0,
                                                              p_affine=1.0,
                                                              p_lighting=1.0),
                                              Normalize.from_stats(*imagenet_stats)])
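For context, the same dsets builds and shows batches fine when I drop the weights and call dataloaders directly (a minimal sketch with the heavy augmentations trimmed, just to show the call I mean):

# this works: same Datasets and transforms, just no weights
dls_plain = dsets.dataloaders(bs=8,
                              after_item=[Resize(528, method='squish'), ToTensor()],
                              after_batch=[IntToFloatTensor(),
                                           *aug_transforms(size=528),
                                           Normalize.from_stats(*imagenet_stats)])
dls_plain.show_batch()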
So the code works with a normal dataloader, without weights, but with weighted_dataloaders I always get this error as soon as I call show_batch():
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-31-90634fcc3c9e> in <module>
----> 1 dls.show_batch()
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/data/core.py in show_batch(self, b, max_n, ctxs, show, unique, **kwargs)
91 old_get_idxs = self.get_idxs
92 self.get_idxs = lambda: Inf.zeros
---> 93 if b is None: b = self.one_batch()
94 if not show: return self._pre_show_batch(b, max_n=max_n)
95 show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/data/load.py in one_batch(self)
129 def one_batch(self):
130 if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
--> 131 with self.fake_l.no_multiproc(): res = first(self)
132 if hasattr(self, 'it'): delattr(self, 'it')
133 return res
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastcore/utils.py in first(x)
174 def first(x):
175 "First element of `x`, or None if missing"
--> 176 try: return next(iter(x))
177 except StopIteration: return None
178
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/data/load.py in __iter__(self)
95 self.randomize()
96 self.before_iter()
---> 97 for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
98 if self.device is not None: b = to_device(b, self.device)
99 yield self.after_batch(b)
/srv/conda/envs/saturn/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __init__(self, loader)
379
380 self._dataset_fetcher = _DatasetKind.create_fetcher(
--> 381 self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
382
383 def _next_data(self):
/srv/conda/envs/saturn/lib/python3.7/site-packages/torch/utils/data/dataloader.py in create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last)
39 return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
40 else:
---> 41 return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
42
43
/srv/conda/envs/saturn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in __init__(self, dataset, auto_collation, collate_fn, drop_last)
19 def __init__(self, dataset, auto_collation, collate_fn, drop_last):
20 super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
---> 21 self.dataset_iter = iter(dataset)
22
23 def fetch(self, possibly_batched_index):
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/data/load.py in __iter__(self)
25 store_attr(self, 'd,pin_memory,num_workers,timeout')
26
---> 27 def __iter__(self): return iter(self.d.create_batches(self.d.sample()))
28
29 @property
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/data/load.py in sample(self)
89
90 def sample(self):
---> 91 idxs = self.get_idxs()
92 return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
93
/srv/conda/envs/saturn/lib/python3.7/site-packages/fastai2/callback/data.py in get_idxs(self)
23 if self.n==0: return []
24 if not self.shuffle: return super().get_idxs()
---> 25 return list(np.random.choice(self.n, self.n, p=self.wgts))
26
27 # Cell
mtrand.pyx in numpy.random.mtrand.RandomState.choice()
ValueError: 'a' and 'p' must have same size
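If I am reading the last frame correctly, get_idxs calls np.random.choice(self.n, self.n, p=self.wgts), and self.n here is the size of the training split while self.wgts still has one entry per row of the whole df, so the two sizes do not match. A minimal sketch of what I think is going on (the 1000/800 numbers are made up, just to illustrate the mismatch):

import numpy as np

n_total = 1000                      # rows in df -> length of the wgts list I pass
n_train = 800                       # rows in the training split only
wgts = np.ones(n_total) / n_total   # normalized weights, one per row of df

# self.n is the split size, but p still has len(df) entries,
# so numpy raises the same ValueError: 'a' and 'p' must have same size
np.random.choice(n_train, n_train, p=wgts)

So my guess is that weighted_dataloaders wants one weight per item of the training set rather than one per row of df, but I am not sure that is the intended usage. Should I pass only the weights of the training rows, or is this a bug?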