Deep Learning - Gradient Aggregation in Parameter Servers

I have some questions about parameter servers and the gradient aggregation they perform. My main source is the Dive into Deep Learning book [1]. I assume the BSP model, i.e., we synchronize after each mini-batch. I wasn’t sure whether this belongs on the Data Science or AI Stack Exchange, so I defaulted to posting here…

Figure 12.7.1 in [1] suggests the following approach: assume a batch size of 32. If we have one GPU and 128 training data points, each epoch has 4 mini-batches. After every mini-batch, we update our model (i.e., there are 4 updates per epoch). Hence, we have to calculate four gradients (one per mini-batch).
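To make sure we are talking about the same thing, here is a toy sketch of that single-GPU case (a trivial `torch.nn.Linear` model and synthetic data, just to illustrate how I count the updates):

```
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)                     # stand-in for any network
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()                       # averages the error over the mini-batch

X = torch.randn(128, 10)                           # 128 training data points
y = torch.randn(128, 1)

for xb, yb in zip(X.split(32), y.split(32)):       # 4 mini-batches of size 32
    opt.zero_grad()
    loss_fn(model(xb), yb).backward()              # one gradient per mini-batch
    opt.step()                                     # one update per mini-batch -> 4 updates per epoch
```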

For the multi-GPU case, assume two GPUs and 128 training data points. We feed each GPU a mini-batch, let each compute a gradient, sum the gradients, and update our model with the sum. Hence, there are two synchronization steps per epoch instead of four (from a BSP point of view).
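And here is how I picture the two-GPU case, simulated on a single device by processing two worker mini-batches before one update (the two "workers" are just a stand-in for real devices; the gradient sum falls out of `.backward()` accumulating into `.grad`):

```
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

X = torch.randn(128, 10)
y = torch.randn(128, 1)

for xb, yb in zip(X.split(64), y.split(64)):        # 64 points per synchronization step
    opt.zero_grad()
    for xw, yw in zip(xb.split(32), yb.split(32)):  # two "worker" mini-batches of 32 each
        loss_fn(model(xw), yw).backward()           # .backward() accumulates, i.e. sums, the gradients
    opt.step()                                      # one update per sync step -> 2 updates per epoch
```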

My questions are the following:

  1. Is my description of how parameter servers work correct? In particular, I am not sure whether each GPU keeps the full batch size of 32 or whether we have to divide the batch size by the number of GPUs.
  2. Why do we sum the gradients instead of averaging them? This confuses me even more because the DistributedDataParallel documentation of PyTorch [2] contains the following statement:

> When a model is trained on M nodes with batch=N, the gradient will be M times smaller when compared to the same model trained on a single node with batch=M*N (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent training process compared to the local training counterpart.

Two things are confusing here:

2.1. It states that the gradients are averaged, which contradicts the D2L book, which states that we sum them.

2.2. Until now, I always thought that in mini-batch gradient descent the loss function (the optimization objective) averages the error over the mini-batch data points. Hence, if M nodes run mini-batch gradient descent with batch size N each, and we take the average of their gradients, the result should be of the same order of magnitude as if a single node ran mini-batch gradient descent with batch size M*N: the single node averages the error over M*N data points, while the M nodes just take an average of averages.
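Here is a small numerical check of what I mean, under the assumption that the M*N data points are split evenly across the nodes (toy model and random data again):

```
import torch

torch.manual_seed(0)
M, N = 4, 32
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()                        # reduction='mean' -> per-node average
X = torch.randn(M * N, 10)
y = torch.randn(M * N, 1)

# Gradient of a single node training with batch size M*N
model.zero_grad()
loss_fn(model(X), y).backward()
g_single = model.weight.grad.clone()

# Average of the M per-node gradients (each node averages its own N points)
model.zero_grad()
for xs, ys in zip(X.split(N), y.split(N)):
    loss_fn(model(xs), ys).backward()               # accumulates the sum of the M per-node gradients
g_avg = model.weight.grad / M                       # divide by M to turn the sum into the average

print(torch.allclose(g_single, g_avg, atol=1e-6))   # prints True (up to floating-point error)
```

As far as I can tell, the average of the per-node gradients matches the single-node gradient exactly for evenly split data, which is exactly why the quoted "M times smaller" statement confuses me.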

I am not sure whether PyTorch's DistributedDataParallel class can be seen as a parameter server (especially since there is a separate guide on how to build a parameter server in PyTorch [3]), but it maps to what the book describes as a parameter server.

Any help in resolving my confusion is much appreciated. Thank you very much!