I’m using learn.model = nn.DataParallel(learn.model), as I’ve seen suggested in the forums, to scale up training of a model to multiple GPUs.

It seems to be working with multiple GPUs, but the estimated time for 1 epoch on 8x 2080 Ti is actually much longer than on 1x 2080 Ti.
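I think this is because I haven’t been able to increase my batch size. With the same batch size, each of the 8 GPUs only gets 1/8 of every batch, so the per-step overhead of copying the model to every GPU and collecting the results outweighs the extra compute. If I try to increase the batch size, I get a CUDA out-of-memory error because GPU 0 uses disproportionately more memory than the others, as nvidia-smi shows.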
I’ve been doing some research, and it looks like this is because nn.DataParallel gathers the outputs of every replica onto a single GPU (GPU 0 by default), where the loss is computed and the gradients are reduced. There is some code in the last post of the thread I found that purports to spread this work across all of the GPUs, but I haven’t been able to get good results with it.
If I set learn.loss_func = CriterionParallel(learn.loss_func) as that post suggests (where CriterionParallel is the class from that forum post), it balances out the memory usage slightly, but not much, and the estimated time for 1 epoch nearly doubles compared to not using it.
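For context on where the imbalance comes from: nn.DataParallel runs each replica on its own GPU, but then gathers all replica outputs onto GPU 0 before the loss is computed. For a language model those outputs are batch × seq_len × vocab logits, so GPU 0 ends up holding all of them. A commonly used workaround (a different technique from the CriterionParallel code above) is to compute the loss inside the module being parallelized, so only one small loss tensor per GPU is gathered. Below is a minimal sketch under that assumption; ModelWithLoss and TinyLM are hypothetical names, the toy model stands in for a real language model, and this does not plug into fastai’s Learner directly, since fastai calls loss_func(model(x), y) itself. It would need a custom training loop or callback.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: runs the model AND the loss on each replica,
    so only a per-replica loss (not the full logits) is gathered on GPU 0."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        logits = self.model(x)  # (batch_chunk, seq_len, vocab), stays on this replica's GPU
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        return loss.unsqueeze(0)  # shape (1,) so DataParallel can concatenate the results


class TinyLM(nn.Module):
    """Toy stand-in for a language model: embedding followed by a linear decoder."""

    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.dec = nn.Linear(dim, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dec(self.emb(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.DataParallel(ModelWithLoss(TinyLM())).to(device)  # falls back to a plain call on CPU
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randint(0, 100, (32, 20), device=device)  # split across GPUs along dim 0
y = torch.randint(0, 100, (32, 20), device=device)

per_gpu_losses = net(x, y)    # one entry per replica that received a chunk
loss = per_gpu_losses.mean()  # mean of per-chunk means; exact when chunks are equal-sized
opt.zero_grad()
loss.backward()
opt.step()
```

The design choice here is simply moving the loss to the other side of the gather: the large logits tensor never leaves the GPU that produced it, so GPU 0 only needs extra memory for one scalar per replica.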
I also found this project, which tries to solve the same problem. But when I try to use encoding.parallel.DataParallelModel and encoding.parallel.DataParallelCriterion like this:
```python
import encoding
learn.model = encoding.parallel.DataParallelModel(learn.model)
learn.loss_func = encoding.parallel.DataParallelCriterion(learn.loss_func)
```
I get an error: AttributeError: 'FlattenedLoss' object has no attribute 'parameters'. I’m fairly sure this is because normal PyTorch loss functions are subclasses of nn.Module, whereas fastai’s FlattenedLoss doesn’t inherit from anything.

To get around this, I tried to make FlattenedLoss subclass nn.Module, but it was a dead end for me: nn.Module required me to call super().__init__() before setting self.func, and it still raised an error when I did that.
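One alternative to changing FlattenedLoss’s base class is composition: wrap the existing loss callable in a thin nn.Module, so the wrapper has parameters(), to(), and so on, while the original object is left untouched. A minimal sketch, assuming the loss is any plain callable taking (input, target) and returning a tensor; ModuleLoss and PlainFlattenedCE are hypothetical names, and PlainFlattenedCE only stands in for fastai’s FlattenedLoss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModuleLoss(nn.Module):
    """Hypothetical adapter: exposes any plain loss callable as an nn.Module
    without changing the callable's own class."""

    def __init__(self, func):
        super().__init__()  # must run before any attribute is assigned on an nn.Module
        self.func = func    # a non-Module callable is stored as a regular attribute

    def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.func(input, target, **kwargs)


class PlainFlattenedCE:
    """Stand-in for a non-Module loss such as fastai's FlattenedLoss:
    flattens (batch, seq_len, vocab) logits before cross-entropy."""

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(input.reshape(-1, input.size(-1)), target.reshape(-1))


loss_fn = ModuleLoss(PlainFlattenedCE())
logits = torch.randn(4, 10, 50)
targets = torch.randint(0, 50, (4, 10))
print(isinstance(loss_fn, nn.Module), list(loss_fn.parameters()))  # True []
print(loss_fn(logits, targets))
```

With fastai this would be used as encoding.parallel.DataParallelCriterion(ModuleLoss(learn.loss_func)), which should get past the missing parameters attribute. I haven’t verified that the rest of fastai’s training loop then works, because DataParallelModel (as far as I know) returns the per-GPU outputs without gathering them, and fastai’s metrics and callbacks may expect a single tensor.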
Does anyone have a solution for balancing out memory usage so I can increase the batch size and take advantage of multiple GPUs while training a language model?
Edit: I found this thread, which summarizes the problem; at the end, someone suggests using DistributedDataParallel instead of DataParallel, which I may try later.
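For reference, the usual DistributedDataParallel pattern is one process per GPU: each process holds its own model copy, computes its own loss on its own shard of the data, and only gradients are averaged across processes during backward(). Nothing is gathered onto GPU 0, which is why memory stays balanced. A minimal sketch, not fastai-specific; run, the toy model, and the random data are hypothetical, and it falls back to the gloo backend on CPU when no GPU is available.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def run(rank: int, world_size: int, init_file: str) -> float:
    """One training process per GPU (or a single CPU process when no GPU exists)."""
    use_cuda = torch.cuda.is_available()
    backend = "nccl" if use_cuda else "gloo"
    dist.init_process_group(backend, init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}") if use_cuda else torch.device("cpu")
    if use_cuda:
        torch.cuda.set_device(device)

    # Toy language-model stand-in (hypothetical): embedding + linear decoder.
    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 100)).to(device)
    model = DDP(model, device_ids=[rank] if use_cuda else None)  # syncs rank 0's weights
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    torch.manual_seed(0)  # identical toy data on every rank; the sampler shards it
    data = TensorDataset(torch.randint(0, 100, (256, 20)), torch.randint(0, 100, (256, 20)))
    sampler = DistributedSampler(data, num_replicas=world_size, rank=rank)
    loader = DataLoader(data, batch_size=32, sampler=sampler)  # batch size is per process

    for epoch in range(2):
        sampler.set_epoch(epoch)  # different shuffle each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Loss is computed locally on each rank: no logits are gathered anywhere.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            opt.zero_grad()
            loss.backward()  # DDP averages gradients across ranks during backward
            opt.step()

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    world_size = max(torch.cuda.device_count(), 1)
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")  # fresh rendezvous file
    if world_size == 1:
        run(0, 1, init_file)  # single process: no need to spawn
    else:
        mp.spawn(run, args=(world_size, init_file), nprocs=world_size)
```

Because batch_size is per process here, 8 GPUs with batch_size=32 give an effective batch of 256 per optimizer step, and each GPU only ever holds the logits for its own 32 sequences.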



