Text LM with large corpus - memory management and parallelization

For some time now I have been trying to find a way to train a text language model on a large corpus for a project of mine. I would like to manage memory more effectively than loading all of the data into a databunch at once: prepping the databunch takes ~3 days and 150 GB of RAM, and the process then crashed when trying to save the result as a pickle file. I could really use some suggestions!

The PyTorch IterableDataset looks like a wonderful solution: https://medium.com/swlh/how-to-use-pytorch-dataloaders-to-work-with-enormously-large-text-files-bbd672e955a0

but the way text is prepped in fastai is to construct batches of shape bs x bptt. When loading text one line at a time via an IterableDataset, there is no obvious way to handle the bptt dimension. Does anyone see a solution here?
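One possible workaround, sketched below with plain Python so the shapes are easy to see: buffer tokens from the line stream and emit a (bs, bptt) chunk whenever bs * bptt tokens have accumulated. Note this is an assumption on my part, not fastai's LanguageModelLoader behavior; in particular, each row is only continuous within a batch, not across successive batches, so carrying hidden state between batches would need more work.

```python
def batch_stream(token_iter, bs, bptt):
    """Yield batches of shape (bs, bptt) from a flat token stream.

    Sketch only: unlike fastai's language-model batching, rows are NOT
    continuous across successive batches, so RNN hidden state carried
    over between batches would not line up without further bookkeeping.
    Leftover tokens that do not fill a full batch are dropped.
    """
    buf = []
    need = bs * bptt
    for tok in token_iter:
        buf.append(tok)
        if len(buf) == need:
            # reshape the flat buffer row-major into bs rows of bptt tokens
            yield [buf[i * bptt:(i + 1) * bptt] for i in range(bs)]
            buf = []


def tokens_from_lines(lines):
    """Flatten a line-at-a-time source (e.g. a file handle) into tokens.

    Whitespace split is a stand-in for whatever tokenizer you use.
    """
    for line in lines:
        yield from line.split()
```

Inside an IterableDataset, `__iter__` would open the file and delegate to `batch_stream(tokens_from_lines(f), bs, bptt)`, so only one buffer of bs * bptt tokens is ever held in memory.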

I have also considered breaking my full dataset into 8 sub-datasets, constructing a databunch for each, and then somehow iterating over them during training via a callback. However, this seems a bit hacky and I would like to hear others' thoughts before I try it.
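For what it's worth, the rotation part of that idea is simple to isolate. Below is a minimal sketch of the epoch-to-shard mapping; the fastai wiring (a Callback whose on_epoch_begin loads the chosen shard and reassigns learn.data) is my assumption about how it would be hooked up, not tested code.

```python
class ShardCycler:
    """Round-robin over pre-built databunch shards, one per epoch.

    Sketch: in fastai v1 this could sit inside a Callback, with
    on_epoch_begin loading shard_for_epoch(epoch) from disk and
    assigning it to learn.data (loading/unpickling is assumed to be
    handled by whatever saved the shards).
    """

    def __init__(self, shard_paths):
        self.shard_paths = list(shard_paths)

    def shard_for_epoch(self, epoch):
        # epoch 0 -> shard 0, epoch 8 -> shard 0 again, and so on
        return self.shard_paths[epoch % len(self.shard_paths)]
```

One caveat with this scheme: each "epoch" only sees one eighth of the data, so the learning-rate schedule and epoch count need rethinking accordingly.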

In the end, I will need to train this model over multiple GPUs. My experience so far suggests parallel jobs in fastai resemble embarrassingly parallel simulations up until the fit command, so each process will need to load the data itself in order to construct the learner. Even if I could get the databunch to fit on one node, I suspect moving to multi-GPU training would quickly run out of memory because each process loads the full databunch.
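If the data is already split into shards, one way to avoid every process holding the full databunch is to have each rank load only its own shard. A minimal sketch, assuming one shard per process and the torch.distributed convention of a RANK environment variable (the sharding-by-rank idea is my suggestion, not an existing fastai feature):

```python
import os


def shard_for_rank(shard_paths, rank=None):
    """Pick the shard this worker process should load.

    Sketch under two assumptions: shards were pre-built one per
    GPU/process, and the launcher exposes the process index via the
    RANK environment variable (as torch.distributed.launch does).
    Each process then constructs its learner from only this shard,
    so no process ever materialises the full databunch.
    """
    if rank is None:
        rank = int(os.environ.get("RANK", "0"))
    return shard_paths[rank % len(shard_paths)]
```

The modulo makes the mapping safe if there are more processes than shards, though ideally the shard count matches the GPU count so no data is duplicated or skipped.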

Any suggestions and tips would be greatly appreciated.