OK, so I got the dataloader to work. Basically, my model needs information about the parent row and an array of its child rows. For the parent row I send the cat and cont tensors, and for each row in the array of child rows I send the cat and cont tensors too.
Here is the code for that, if anyone is interested. It contains some stuff specific to my domain, such as the SurveyResultID column, but it could be adapted by someone else if need be. SurveyResultID is simply the common key between my parent and child rows.
# Child table (one row per question). No `y_names`: only the parent table
# carries the target.
questions_tab = TabularPandas(
    questions,
    procs=[Categorify, FillMissing, Normalize],
    cat_names=questions_cat_names,
    cont_names=questions_cont_names,
)

# Parent table (one row per survey result), with `NextScore` as the target.
results_tab = TabularPandas(
    results,
    procs=[Categorify, FillMissing, Normalize],
    cat_names=results_cat_names,
    cont_names=results_cont_names,
    y_names='NextScore',
)
def _pad_rows(rows, n_rows, dtype):
    "Return `rows` zero-padded along dim 0 to `n_rows` rows, as `dtype`, on `rows`' device"
    padded = torch.zeros(n_rows, rows.shape[1], dtype=dtype, device=rows.device)
    padded[:len(rows)] = rows
    return padded


class ReadMultiTabBatch(ReadTabBatch):
    "Read a parent batch plus one child table per parent row, zero-padding the children"

    def encodes(self, to):
        """Encode a `(parents, children)` pair into tensors.

        `to[0]` is the parent batch and `to[1]` is an iterable holding one
        child table per parent row. Each is encoded with
        `ReadTabBatch.encodes`; the first two tensors of every child encoding
        are treated as `(cat, cont)` and zero-padded along the row dimension
        to the longest child table in the batch.

        Returns a 3-tuple: every parent tensor except the last, the list of
        padded `(cat, cont)` child pairs, and the last parent tensor
        (presumably the target, since the parent table is built with
        `y_names` -- confirm against `ReadTabBatch`).
        """
        parents = super(ReadMultiTabBatch, self).encodes(to[0])
        # Explicit two-argument `super` is required here: the zero-argument
        # form is not available inside a comprehension on older Pythons.
        children = [super(ReadMultiTabBatch, self).encodes(x) for x in to[1]]
        # `default=0` keeps a batch without any child table from raising.
        max_len = max((len(c[0]) for c in children), default=0)
        children = [
            (_pad_rows(c[0], max_len, torch.long), _pad_rows(c[1], max_len, torch.float))
            for c in children
        ]
        return parents[:-1], children, parents[-1]
@delegates()
class TabParentChildDataLoader(TfmdDL):
    """A `TfmdDL` whose batches pair parent rows with the child rows sharing their key.

    `dataset` is the parent table and `children` the child table; both must
    expose `.items` (a DataFrame) and a positional `.iloc`. `key_col` names
    the column, present in both tables, that links a child row to its parent.
    """
    do_item = noops

    def __init__(self, dataset, children, bs=16, shuffle=False, after_batch=None, num_workers=0,
                 key_col='SurveyResultID', **kwargs):
        if after_batch is None:
            after_batch = L(TransformBlock().batch_tfms) + ReadMultiTabBatch(dataset)
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch,
                         num_workers=num_workers, **kwargs)
        self.children = children
        self.key_col = key_col

    def create_batch(self, b):
        """Return `(parent_batch, child_tables)` for the parent positions in `b`.

        `child_tables` holds one entry per parent row, in parent order, each
        containing the child rows whose key equals that parent's key.
        """
        parent_keys = self.dataset.items.iloc[b][self.key_col]
        # Re-index the child keys 0..n-1 so that `.index` below yields row
        # *positions*: `children.iloc` is positional, and the frame's own
        # index labels only coincide with positions for a default RangeIndex.
        child_keys = self.children.items[self.key_col].reset_index(drop=True)
        # Narrow to this batch's keys once, so each per-parent scan is small.
        child_keys = child_keys[child_keys.isin(parent_keys)]
        res = [self.children.iloc[child_keys[child_keys == k].index] for k in parent_keys]
        return self.dataset.iloc[b], res
# Split the parent rows by position, using the indices in `valid_indexes`
# (per fastai's `IndexSplitter`, those indices become the validation set).
splitter = IndexSplitter(valid_indexes.tolist())
splits = splitter(range_of(results))

# Parent table first, child table second.
dl = TabParentChildDataLoader(results_tab, questions_tab, splits=splits)

# Wrap the loader in `Datasets` with the same splits and build the loaders.
ds = Datasets(dl, splits=splits)
dls = ds.dataloaders()