How to train on tabular data in mini batches?

I have tabular data that is far too large to fit in memory (240+ million rows). The data is in numerous parquet files of ~400k rows each. Is there a way to set up a data loader to, say, train on the data from one parquet file in each step?

I’m coming from PyTorch, so I’ll describe what I would do there. I would give a DataLoader a list of parquet files, and the collate function would read one file with pandas, convert the values and target to tensors, and return them. Is there analogous functionality in FastAI? Everything I could find in the docs indicates that my data would have to be either a single in-memory dataframe or a single CSV.
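For reference, here is a minimal sketch of that PyTorch pattern. It assumes every column is numeric and that the target column name is known. The names `make_file_collate`, `make_file_loader`, `target_col`, and `read_file` are hypothetical helpers made up for illustration, not part of PyTorch.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader


def make_file_collate(target_col, read_file=pd.read_parquet):
    """Build a collate_fn that turns one file path into an (x, y) tensor pair."""
    def collate(paths):
        # With batch_size=1 the DataLoader hands us a list holding one path.
        df = read_file(paths[0])
        # Assumption: all feature columns are numeric and fit in float32.
        x = torch.tensor(df.drop(columns=[target_col]).to_numpy(dtype="float32"))
        y = torch.tensor(df[target_col].to_numpy(dtype="float32"))
        return x, y
    return collate


def make_file_loader(paths, target_col, read_file=pd.read_parquet):
    # The "dataset" is just the list of paths; each step loads one whole file.
    return DataLoader(list(paths), batch_size=1, shuffle=True,
                      collate_fn=make_file_collate(target_col, read_file))
```

Iterating over `make_file_loader(parquet_files, "target")` would then yield one `(x, y)` pair per file, so only a single file's rows are in memory at a time.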

This topic here may provide a few pointers.

I’ve found a method that may work. I set up a callback that re-initializes the learner’s dls after every epoch, and consider an epoch to be a single parquet file. Here is an example:

def parquet_dir_gen(): 
    <Generator that yields a parquet file's directory>

def load_batch(parquet):
    <Read and preprocess your batch of data>

def create_tabular_pandas(df):
    <Returns a TabularPandas object>


class ReinitializeTabularPandas(Callback):
    def __init__(self, learn, generator, load_batch, create_dataset):
        super().__init__()
        self.learn = learn
        self.generator = generator
        self.load_batch = load_batch
        self.create_dataset = create_dataset
               

    def after_epoch(self):
        # Runs once per epoch (not per batch), so each epoch trains on a new file.
        # Load the new data, create a TabularPandas using it, and get the dataloader
        parquet = next(self.generator)
        df = self.load_batch(parquet)
        data = self.create_dataset(df)
        self.learn.dls = data.dataloaders(bs=len(data))

# Initial TabularPandas initialization
df = load_batch(initial_parquet)
data = create_tabular_pandas(df)
dls = data.dataloaders(bs=len(data))

learn = Learner(dls, model, BCEWithLogitsLossFlat(), opt_func=Adam, lr=3e-2, metrics=[accuracy])
# Call parquet_dir_gen() so the callback receives a generator object that next() can advance
learn.add_cb(ReinitializeTabularPandas(learn, parquet_dir_gen(), load_batch, create_tabular_pandas))
learn.fit(n_epoch = num_parquet*n_epoch_desired)
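One detail to watch: the callback above calls `next(...)` on the generator many times over `num_parquet*n_epoch_desired` epochs, so the generator must not run out after a single pass through the files. A minimal sketch of a cycling generator, using only the standard library, could look like the following. The name `parquet_path_gen` and its `data_dir` and `pattern` parameters are hypothetical, and it assumes all files sit in one directory.

```python
from itertools import cycle
from pathlib import Path


def parquet_path_gen(data_dir, pattern="*.parquet"):
    """Yield parquet file paths from data_dir in sorted order, forever."""
    paths = sorted(Path(data_dir).glob(pattern))
    if not paths:
        # Raised on the first next() call, since this is a generator function.
        raise FileNotFoundError(f"no files matching {pattern!r} in {data_dir}")
    # cycle() restarts from the first file once the last one has been yielded.
    yield from cycle(paths)
```

Because it never raises `StopIteration`, the same generator object can be passed to the callback once and reused for as many passes over the data as needed.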