I want to better understand how the dataloaders actually pass data to the model… To do that, I thought I would simply create a batch from a tabular dataloader, pass it to a TabularModel, and set some breakpoints in there. But I am running into a problem.
results_tab = TabularPandas(results, [Categorify, FillMissing, Normalize], results_cat_names, results_cont_names, y_names='NextScore')
rdls = results_tab.dataloaders(bs=16)
results_model = TabularModel(results_emb_szs, len(results_cont_names), 100, [1000, 250], ps=[0.01, 0.1], embed_p=0.04).cuda()
results_model(rdls.one_batch())
The code crashes when calling the model:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-88-3c075dd5c894> in <module>
3 results_model = TabularModel(results_emb_szs, len(results_cont_names), 100, [1000, 250], ps=[0.01, 0.1], embed_p=0.04).cuda()
4
----> 5 results_model(rdls.one_batch())
~\Anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
c:\work\ml\fastai2\fastai2\tabular\model.py in forward(self, x_cat, x_cont)
46 def forward(self, x_cat, x_cont=None):
47 if self.n_emb != 0:
---> 48 x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
49 x = torch.cat(x, 1)
50 x = self.emb_drop(x)
c:\work\ml\fastai2\fastai2\tabular\model.py in <listcomp>(.0)
46 def forward(self, x_cat, x_cont=None):
47 if self.n_emb != 0:
---> 48 x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
49 x = torch.cat(x, 1)
50 x = self.emb_drop(x)
TypeError: tuple indices must be integers or slices, not tuple
Looking at the value returned by one_batch(), I can see that it is a tuple of three items: the categorical variables, the continuous variables, and the targets. The model's forward(x_cat, x_cont) expects only the categorical and the continuous variables, so the whole tuple ends up bound to x_cat and x_cat[:,i] fails…
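To make the failure concrete, here is a minimal sketch in plain Python, with no fastai or torch required. The names `batch` and `toy_forward` are hypothetical stand-ins, and nested lists replace the real tensors. It reproduces the same TypeError by indexing the tuple the way line 48 of the traceback does, then shows the batch being unpacked before the call.

```python
# Stand-in for the batch: (categoricals, continuous, targets).
# Nested lists replace tensors so this runs without fastai or torch.
batch = (
    [[1, 2], [3, 1]],   # x_cat: 2 rows, 2 categorical columns
    [[0.5], [-1.2]],    # x_cont: 2 rows, 1 continuous column
    [[10.0], [7.0]],    # y: targets
)

# Passing the whole tuple as x_cat means the model ends up indexing a tuple
# with [:, i], which is exactly what raises the error in the traceback.
try:
    batch[:, 0]
except TypeError as err:
    message = str(err)


def toy_forward(x_cat, x_cont=None):
    """Hypothetical stand-in with the same signature as forward(x_cat, x_cont=None)."""
    n_cont = len(x_cont) if x_cont is not None else 0
    return len(x_cat), n_cont


# Option 1: unpack the three items and pass only the inputs.
x_cat, x_cont, y = batch
shapes = toy_forward(x_cat, x_cont)

# Option 2: slice off the targets and splat the remaining inputs.
shapes_splat = toy_forward(*batch[:2])
```

With the real objects from the post, the equivalent would presumably be `x_cat, x_cont, y = rdls.one_batch()` followed by `results_model(x_cat, x_cont)`, assuming the batch and the model sit on the same device.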
What am I doing wrong?
Thanks,