I implemented a custom Dataset for multi-class classification but get error:
ValueError: Expected input batch_size (64) to match target batch_size (640).
I noticed that the target batch size is off by the number of classes, so I’m assuming that I’m missing something obvious with in the __getitem__
function. I don’t know how to diagnose the problem or how to tweak the returns from my custom Dataset class. Anyone recommendations for how to proceed?
Here’s some other data points. In custom Dataset class:
.n_inp: 1
.c: 10
getitem return: ( torch.Size([3, 8, 8]), torch.Size([1, 10]) )
Label: torch.Size([1, 10])
Other:
batch size: 64
Code snippets:
class MyDataset(Dataset):
"""My dataset."""
def __init__(self, csv_file):
"""
Args:
csv_file (string): Path to the csv file with annotations and FEN.
"""
self.df = pd.read_csv(csv_file)
# create the labels
mlb = MultiLabelBinarizer()
res = pd.DataFrame(mlb.fit_transform(self.df.Labels.apply(lambda x: x.split(' ')).tolist()),
columns=mlb.classes_,
index=self.df.index)
self.df = pd.concat([self.df, res], axis=1)
self.c = len(mlb.classes_)
self.classes = mlb.classes_
self.n_inp = 1
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
img = createImg(self.df.iloc[idx]['inputString'])
output = transforms.ToTensor()(img).unsqueeze_(0).squeeze().float()
labels = torch.as_tensor(self.df.iloc[idx][self.classes].tolist()).unsqueeze_(0)
return (output, labels)
data = MyDataset(dtapath)
train_loader = DataLoader(data, batch_size=batch_size,
sampler=train_sampler)
validation_loader = DataLoader(data, batch_size=batch_size,
sampler=valid_sampler)
dls = DataLoaders(train_loader, validation_loader, device='cpu')
learn = cnn_learner(dls,resnet50, loss_func=CrossEntropyLossFlat(), n_out=len(validThemes), normalize=False, metrics=[accuracy_multi])
learn.lr_find()
Full stacktrace below:
ValueError Traceback (most recent call last)
<ipython-input-136-d81c6bd29d71> in <module>()
----> 1 learn.lr_find()
20 frames
/usr/local/lib/python3.6/dist-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
220 n_epoch = num_it//len(self.dls.train) + 1
221 cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222 with self.no_logging(): self.fit(n_epoch, cbs=cb)
223 if show_plot: self.recorder.plot_lr_find()
224 if suggestions:
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
209 self.opt.set_hypers(lr=self.lr if lr is None else lr)
210 self.n_epoch = n_epoch
--> 211 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
212
213 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_fit(self)
200 for epoch in range(self.n_epoch):
201 self.epoch=epoch
--> 202 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
203
204 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch(self)
194
195 def _do_epoch(self):
--> 196 self._do_epoch_train()
197 self._do_epoch_validate()
198
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch_train(self)
186 def _do_epoch_train(self):
187 self.dl = self.dls.train
--> 188 self._with_events(self.all_batches, 'train', CancelTrainException)
189
190 def _do_epoch_validate(self, ds_idx=1, dl=None):
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in all_batches(self)
164 def all_batches(self):
165 self.n_iter = len(self.dl)
--> 166 for o in enumerate(self.dl): self.one_batch(*o)
167
168 def _do_one_batch(self):
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in one_batch(self, i, b)
182 self.iter = i
183 self._split(b)
--> 184 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
185
186 def _do_epoch_train(self):
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
158
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_one_batch(self)
170 self('after_pred')
171 if len(self.yb):
--> 172 self.loss_grad = self.loss_func(self.pred, *self.yb)
173 self.loss = self.loss_grad.clone()
174 self('after_loss')
/usr/local/lib/python3.6/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
33 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
34 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
36
37 # Cell
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
961 return F.cross_entropy(input, target, weight=self.weight,
--> 962 ignore_index=self.ignore_index, reduction=self.reduction)
963
964
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2463 cross_entropy, tens_ops, input, target, weight=weight,
2464 size_average=size_average, ignore_index=ignore_index, reduce=reduce,
-> 2465 reduction=reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
/usr/local/lib/python3.6/dist-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1061 # Use `public_api` instead of `implementation` so __torch_function__
1062 # implementations can do equality/identity comparisons.
-> 1063 result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
1064
1065 if result is not NotImplemented:
/usr/local/lib/python3.6/dist-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
323 convert=False
324 if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325 res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
326 if convert: res = convert(res)
327 if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
/usr/local/lib/python3.6/dist-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
993
994 with _C.DisableTorchFunction():
--> 995 ret = func(*args, **kwargs)
996 return _convert(ret, cls)
997
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2469
2470
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2260 if input.size(0) != target.size(0):
2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (64) to match target batch_size (640).