Fastai2 tabular for out of memory datasets

Hi, what is the right way to feed large dataframes into the training loop?
I am frustrated by the number of different options, which include:

  • data block API
  • callbacks
  • own dataset class or dataloader

The naive way I could do this is to loop through the data like:

for chunk in chunks:
    to = TabularPandas(chunk,....)
    dls = to.dataloaders(bs=bs)
    learn = tabular_learner(...)
    learn.fit_one_cycle(...)

But I understand this will not be identical to a full fit_one_cycle over all the data.

What I want to do exactly is merge the data into a wide df on the fly and feed it to the training loop, so that all the data are fed during a single epoch.

Please advise on how best to approach this, and sorry if this looks like a noob question.

That’s not much info to go off of, but you can pass inplace=True to TabularPandas and that should help some. You need to set the pandas chained_assignment option to None first, though.

(pd.options.mode.chained_assignment=None)

Hi, answering my own question: I seem to have found the right way of feeding tabular data that does not fit into memory as one df. It is done with a Callback that resets learn.dls with a new training chunk every epoch.

creating initial dls

# start_df  contains validation rows as well as the first chunk of training set

to = TabularPandas(start_df, procs, cat_names, cont_names, y_names="salary", y_block = CategoryBlock(), 
               splits=splits, do_setup=True)
trn_dl = TabDataLoader(to.train)
val_dl = TabDataLoader(to.valid)
dls = DataLoaders(trn_dl, val_dl).cuda()

Callback

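# next_chunk is a generator that yields the next chunk of training data as a DataFrame

class ReloadCallback(Callback):
    # Note: later fastai v2 releases renamed this event to before_epoch
    def begin_epoch(self):
        df = next(next_chunk)             # fetch the next training chunk
        to_new = to.new(df)               # reuse the procs set up on start_df
        to_new.process()                  # apply them to the new chunk
        trn_dl = TabDataLoader(to_new.train)
        val_dl = TabDataLoader(to.valid)  # the validation set stays the same
        self.learn.dls = DataLoaders(trn_dl, val_dl).cuda()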
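
The callback calls next(next_chunk) once per epoch, so the generator must not run out before training ends. How next_chunk is built is not shown in the post; the following is only a minimal sketch of one way to do it. It assumes a hypothetical helper cycle_chunks that restarts the source whenever it is exhausted, and the file name train.csv in the comment is also made up for illustration.

```python
def cycle_chunks(make_chunks):
    """Yield chunks forever, restarting the source each time it is exhausted.

    make_chunks is a zero-argument callable returning a fresh iterable of
    chunks (hypothetical helper, not part of fastai or pandas).
    """
    while True:
        produced_any = False
        for chunk in make_chunks():
            produced_any = True
            yield chunk
        if not produced_any:
            # Avoid spinning forever on an empty source
            raise ValueError("make_chunks() produced no chunks")


# Possible use with pandas (assumed file name and chunk size):
#   import pandas as pd
#   next_chunk = cycle_chunks(lambda: pd.read_csv("train.csv", chunksize=100_000))
#
# Stand-in demo with plain lists instead of DataFrames:
demo = cycle_chunks(lambda: iter([[1, 2], [3, 4]]))
first_five = [next(demo) for _ in range(5)]
```

With a generator like this, the number of epochs passed to fit_one_cycle decides how many chunks are seen: to cover the whole dataset once, the epoch count should be at least the number of chunks.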

learn object and applying the callback

learn = tabular_learner(dls, loss_func=CrossEntropyLossFlat(), metrics=[accuracy])
learn.add_cb(ReloadCallback())

Now it works!

learn.fit_one_cycle(10)