I have some questions regarding parameter servers and the gradient aggregation performed. My main source is the Dive into Deep Learning book [1]. I assume the BSP model, i.e., we synchronize after each mini-batch. I wasn’t sure whether this belongs to the data science or AI StackExchange community, so I defaulted here…
Figure 12.7.1 in [1] suggests the following approach. Assume a batch size of 32. With one GPU and 128 training data points, each epoch has 4 mini-batches. After every mini-batch we update our model (i.e., there are 4 updates). Hence, we have to calculate four gradients (one per mini-batch).
For the multiple GPUs case, assume two GPUs and 128 training data points. We feed each GPU a mini-batch, let them calculate a gradient, sum them and update our model with the sum. Hence, there are two steps instead of four involved (from a BSP point of view).
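To make the counting above concrete, here is a small sketch of how I arrive at the number of synchronized updates per epoch. The helper `updates_per_epoch` is my own hypothetical function, not something from [1], and it assumes that under BSP one update consumes exactly one mini-batch per GPU:

```python
import math

def updates_per_epoch(num_samples, per_gpu_batch, num_gpus):
    """Number of synchronized (BSP) model updates in one epoch.

    Assumption: every update consumes one mini-batch of size
    `per_gpu_batch` on each of the `num_gpus` GPUs.
    """
    global_batch = per_gpu_batch * num_gpus
    return math.ceil(num_samples / global_batch)

# One GPU, batch size 32, 128 samples.
single_gpu = updates_per_epoch(128, 32, 1)

# Two GPUs, each keeping batch size 32 (global batch of 64).
two_gpus_same_batch = updates_per_epoch(128, 32, 2)

# Two GPUs, batch size divided by the number of GPUs (global batch of 32).
two_gpus_split_batch = updates_per_epoch(128, 16, 2)

print(single_gpu, two_gpus_same_batch, two_gpus_split_batch)
```

If the per-GPU batch size stays at 32, the number of updates halves; if it is divided by the number of GPUs, it stays at four. Which of the two the book intends is exactly what I am unsure about.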
My questions are the following:
- Is my understanding of how parameter servers work, as described above, correct? In particular, I am not sure whether we keep the same batch size of 32 per GPU or whether we have to divide the batch size by the number of GPUs.
- Why do we sum the gradients instead of averaging them? This is even more confusing to me, as, in the DistributedDataParallel documentation of PyTorch [2], there is the following statement:
> When a model is trained on M nodes with batch=N, the gradient will be M times smaller when compared to the same model trained on a single node with batch=M*N (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent training process compared to the local training counterpart.
There are two things confusing here:
2.1. It states that the gradients are averaged, which contradicts the D2L book, which states that we sum the gradients.
2.2. Until now, I always thought that in mini-batch gradient descent the loss function (the optimization goal) averages the error over the data points of the mini-batch. Hence, if M nodes each run mini-batch gradient descent with batch size N and we average their gradients, we should get a result of the same order of magnitude as one node running mini-batch gradient descent with batch size M*N: the single node averages the errors of M*N data points in its loss function, while with M nodes we just take the average of averages.
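To illustrate the "average of averages" argument in 2.2, here is a minimal pure-Python sketch. It is not how PyTorch or the D2L parameter server is implemented; the toy model (a one-parameter linear fit), the helper `mean_grad`, and the equal-sized shards are all my own assumptions:

```python
import random

def mean_grad(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error (1/n) * sum((w*x - y)^2)."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

random.seed(0)
M, N = 4, 8          # M nodes, batch size N per node (hypothetical values)
w = 0.5              # current model parameter

# Toy data: M*N points, later split into M equal shards of size N.
xs = [random.uniform(-1, 1) for _ in range(M * N)]
ys = [3 * x + random.gauss(0, 0.1) for x in xs]

# Single node with batch size M*N and a mean-reduced loss.
full = mean_grad(w, xs, ys)

# M nodes, each computing a mean-reduced gradient on its own shard.
shards = [
    mean_grad(w, xs[i * N:(i + 1) * N], ys[i * N:(i + 1) * N])
    for i in range(M)
]

averaged = sum(shards) / M   # gradients averaged across nodes
summed = sum(shards)         # gradients summed across nodes

print("single node:", full)
print("averaged   :", averaged)
print("summed     :", summed)
```

Under these assumptions (mean-reduced loss, equal shard sizes), the averaged gradient matches the single-node gradient up to floating-point error, while the summed gradient is M times larger. This is why the "M times smaller" statement in [2] confuses me.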
I am not sure whether the DistributedDataParallel class of PyTorch can be seen as a parameter server (especially because they even have a guide on how to build a parameter server in PyTorch [3]), but it maps to what is described in the book as a parameter server.
Any help in resolving my confusion is much appreciated. Thank you very much!