I am trying to train a custom regression model on tabular data.
For this, I used the following code to create a DataBunch:
# Test items: a copy of rows start_idx_test:end_idx_test of `x`; note that no `procs` are passed here.
test = (TabularList.from_df(x.iloc[start_idx_test:end_idx_test].copy(), path=path1, cont_names=cont_names))
# Build the item list from the full frame `x`, with the continuous columns and `procs`.
data = (TabularList.from_df(x, path=path1, cont_names=cont_names, procs=procs)
# NOTE(review): three split calls are chained below; presumably only one of them
# is meant to be active at a time -- confirm which one was used for the failing run.
.split_by_idx(list(range(start_idx_test,end_idx_test)))
.split_none()
.split_by_rand_pct(0.2)
# NOTE(review): no `label_cls` is passed. The traceback reports 448 outputs vs 64
# targets (7 per target), so the labels were presumably inferred as categories;
# for regression, passing `label_cls=FloatList` is likely needed -- TODO confirm.
.label_from_df(cols=pred_val)
# Attach `test` as the test set, with 0 passed as its label.
.add_test(test, label = 0)
.databunch())
When training the model I get this error:
AssertionError Traceback (most recent call last)
in
----> 1 learn.fit_one_cycle(1,slice(1e-02))
/opt/conda/lib/python3.7/site-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
21 callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
22 final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 23 learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
24
25 def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr, moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,
/opt/conda/lib/python3.7/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
198 else: self.opt.lr,self.opt.wd = lr,wd
199 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
--> 200 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
201
202 def create_opt(self, lr:Floats, wd:Floats=0.)->None:
/opt/conda/lib/python3.7/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
104 if not cb_handler.skip_validate and not learn.data.empty_val:
105 val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
--> 106 cb_handler=cb_handler, pbar=pbar)
107 else: val_loss=None
108 if cb_handler.on_epoch_end(val_loss): break
/opt/conda/lib/python3.7/site-packages/fastai/basic_train.py in validate(model, dl, loss_func, cb_handler, pbar, average, n_batch)
61 if not is_listy(yb): yb = [yb]
62 nums.append(first_el(yb).shape[0])
---> 63 if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
64 if n_batch and (len(nums)>=n_batch): break
65 nums = np.array(nums, dtype=np.float32)
/opt/conda/lib/python3.7/site-packages/fastai/callback.py in on_batch_end(self, loss)
306 "Handle end of processing one batch with `loss`."
307 self.state_dict['last_loss'] = loss
--> 308 self('batch_end', call_mets = not self.state_dict['train'])
309 if self.state_dict['train']:
310 self.state_dict['iteration'] += 1
/opt/conda/lib/python3.7/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
248 "Call through to all of the `CallbakHandler` functions."
249 if call_mets:
--> 250 for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
251 for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
252
/opt/conda/lib/python3.7/site-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
239 def _call_and_update(self, cb, cb_name, **kwargs)->None:
240 "Call `cb_name` on `cb` and update the inner state."
--> 241 new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
242 for k,v in new.items():
243 if k not in self.state_dict:
/opt/conda/lib/python3.7/site-packages/fastai/callback.py in on_batch_end(self, last_output, last_target, **kwargs)
342 if not is_listy(last_target): last_target=[last_target]
343 self.count += first_el(last_target).size(0)
--> 344 val = self.func(last_output, *last_target)
345 if self.world:
346 val = val.clone()
/opt/conda/lib/python3.7/site-packages/fastai/metrics.py in root_mean_squared_error(pred, targ)
85 def root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
86 "Root mean squared error between `pred` and `targ`."
---> 87 pred,targ = flatten_check(pred,targ)
88 return torch.sqrt(F.mse_loss(pred, targ))
89
/opt/conda/lib/python3.7/site-packages/fastai/torch_core.py in flatten_check(out, targ)
377 "Check that `out` and `targ` have the same number of elements and flatten them."
378 out,targ = out.contiguous().view(-1),targ.contiguous().view(-1)
--> 379 assert len(out) == len(targ), f"Expected output and target to have the same number of elements but got {len(out)} and {len(targ)}."
380 return out,targ
381
AssertionError: Expected output and target to have the same number of elements but got 448 and 64
I was able to narrow the problem down to how the validation set is created in the Data Block API:
when I use split_none() (i.e. no validation set), the model trains without any problem.
How do I fix this? Please help.