Hi @ashwin93961, welcome to the fast.ai forums!
The short answer is that in order to update a layer's weights after a forward pass, for each of the layer's weights we need to add up the gradients that were computed with respect to each of the model's inputs. We sum along dim=0 because this is the dimension that holds all the inputs.
More conceptually, in the simple case where we only pass one input through our model and then calculate the weight gradients of a layer with respect to that input, there'd be no need to sum along dim=0. We'd already have a single gradient that corresponds to each one of our layer's weights (in the notebook's example, these weights reside in a 784x50 weight matrix).
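To make the single-input case concrete, here's a toy sketch (made-up sizes, not the notebook's code): with one input, each weight gets exactly one gradient, which is just an outer product.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins: one input of length 4, hidden size 3
x = torch.randn(4)      # a single flattened input
out_g = torch.randn(3)  # stand-in for the gradient flowing back into the layer's output

# For one input, each weight w[i, j] gets exactly one gradient:
# w_g[i, j] = x[i] * out_g[j] -- no summing over inputs needed
w_g = torch.outer(x, out_g)
print(w_g.shape)  # torch.Size([4, 3])
```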
However, what if we want to calculate the cumulative gradient, for each of the weights in the layer's weight matrix, after passing several inputs through our model? Indeed, this is something we'll need to do when we use mini-batches to train a neural network.
In order to do this, we'll need to keep track of the weight gradients with respect to each model input, for all inputs that we pass through our model. Recall that the MNIST dataset used in the notebook's example has 50,000 inputs. If we pass all of those 50,000 inputs through our layer, we'll eventually have 50,000 separate sets of 784x50 weight gradients. Each set contains the gradients of the layer's weights with respect to a different input.
But given that we only have one set of 784x50 weights at that layer, how can we update these weights using the gradients found in all 50,000 different sets of weight gradients?
The way to do this is to sum up the weight gradients across all the 50,000 inputs. Since dim=0 is the dimension that stores model inputs, this is why we sum the weight gradients (or bias gradients) across dim=0.
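We can sanity-check this claim with autograd (a small sketch with toy sizes, not the notebook's code): summing the per-input weight gradients along dim=0 matches the gradient computed from the whole batch at once.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins: 4 inputs of length 3, hidden size 2
inp = torch.randn(4, 3)
w = torch.randn(3, 2, requires_grad=True)

# Gradient of the layer's weights from the whole batch at once
(inp @ w).sum().backward()
batch_grad = w.grad.clone()

# One set of weight gradients per input, stacked along dim=0, then summed
per_input = torch.stack([
    torch.autograd.grad((x.unsqueeze(0) @ w).sum(), w)[0]
    for x in inp
])
print(per_input.shape)                                   # torch.Size([4, 3, 2])
print(torch.allclose(per_input.sum(dim=0), batch_grad))  # True
```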
Looking more deeply at the notebook's example, let's zoom in on the calculation of weight gradients for the first layer, `l1 = inp @ w1 + b1`, as defined in the `forward_and_backward()` function in the notebook's next cell.
Recall that the layer's weights, `w1`, have a shape of `torch.Size([784, 50])`, where 50 is the hidden layer size and 784 is the length of one MNIST image's flattened vector.
Recall also that the MNIST inputs (`inp`) to the linear layer have a shape of `torch.Size([50000, 784])`. That's one row of length 784 for each of the 50,000 MNIST images.
Now, the operation `inp.unsqueeze(-1)` adds an extra final dimension to the inputs, changing their shape from `torch.Size([50000, 784])` to `torch.Size([50000, 784, 1])`.
Additionally, the operation `out.g.unsqueeze(1)` adds an extra dimension to `out.g` at the dim=1 axis. This changes the shape of the matrix containing the gradients of the layer's outputs from `torch.Size([50000, 50])` to `torch.Size([50000, 1, 50])`.
Multiplying these two tensors together (an elementwise multiplication, where the size-1 dimensions broadcast) then results in a product with shape `torch.Size([50000, 784, 50])`. Indeed, we added the extra dimensions precisely so that the two tensors would broadcast against each other. Their product contains all 50,000 sets of weight gradients, and each set is with respect to a different MNIST input.
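Here's the same shape bookkeeping in code, with small stand-in sizes (the real `50000 x 784 x 50` tensor would take several gigabytes, so this sketch shrinks the dimensions):

```python
import torch

torch.manual_seed(0)

# Small stand-ins for the notebook's sizes (50000, 784, 50)
n, d_in, d_hidden = 6, 4, 3
inp = torch.randn(n, d_in)
out_g = torch.randn(n, d_hidden)  # stand-in for out.g

a = inp.unsqueeze(-1)   # shape [n, d_in, 1]
b = out_g.unsqueeze(1)  # shape [n, 1, d_hidden]
prod = a * b            # broadcasts to [n, d_in, d_hidden]
print(prod.shape)       # torch.Size([6, 4, 3])

# Each slice prod[i] is the outer product of input i with its output gradient
print(torch.allclose(prod[0], torch.outer(inp[0], out_g[0])))  # True
```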
By summing along dim=0, we aggregate the weight gradients across all 50,000 inputs, which gives us a matrix of gradients of shape `torch.Size([784, 50])` that our optimizer can then use to update the layer's weights.
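Sketching the full accumulation with the same toy sizes: broadcasting and then summing over dim=0 is numerically identical to the single matrix multiply `inp.t() @ out_g` (a standard identity worth knowing, though not necessarily how the notebook writes it).

```python
import torch

torch.manual_seed(0)

# Toy stand-ins for (50000, 784, 50)
n, d_in, d_hidden = 6, 4, 3
inp = torch.randn(n, d_in)
out_g = torch.randn(n, d_hidden)  # stand-in for out.g

# Broadcast multiply, then sum the per-input gradient sets over dim=0
w_g = (inp.unsqueeze(-1) * out_g.unsqueeze(1)).sum(0)
print(w_g.shape)  # torch.Size([4, 3])

# The same accumulation collapses into one matrix multiply
print(torch.allclose(w_g, inp.t() @ out_g))  # True
```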