Understanding gradient calculation

The forward pass of a fully-connected layer computes the following:

 f[1] = w[1,1]*x[1] + w[2,1]*x[2] + ... + w[n,1]*x[n] + b[1]
 f[2] = w[1,2]*x[1] + w[2,2]*x[2] + ... + w[n,2]*x[n] + b[2]
 ...
 f[m] = w[1,m]*x[1] + w[2,m]*x[2] + ... + w[n,m]*x[n] + b[m]

where:

  • m is the number of neurons in that layer,
  • n is the number of inputs,
  • x is a vector of inputs (n in total),
  • w are the weights (a matrix of size n * m),
  • b are the biases (m in total),
  • and f is the output of the layer (also a vector of m elements).

(Note: I started counting at 1 here, not at 0. It doesn’t really matter.)
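To make the formulas concrete, here is a minimal pure-Python sketch of the forward pass. It uses 0-based indices, so x[0] here is x[1] above. It stores w as a list of n rows with m entries each, so w[i][j] is the weight from input i to neuron j. The function name `linear_forward` and the example numbers are made up for illustration.

```python
def linear_forward(x, w, b):
    """Compute f[j] = sum_i w[i][j] * x[i] + b[j] for every neuron j.

    x: list of n inputs
    w: n x m weight matrix, stored as n rows with m entries each
    b: list of m biases
    """
    n, m = len(x), len(b)
    return [sum(w[i][j] * x[i] for i in range(n)) + b[j] for j in range(m)]


# Example with n = 3 inputs and m = 2 neurons.
x = [1.0, 2.0, 3.0]
w = [[1.0, 4.0],
     [2.0, 5.0],
     [3.0, 6.0]]
b = [10.0, 20.0]
print(linear_forward(x, w, b))  # -> [24.0, 52.0]
```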

The backward pass takes an incoming gradient and computes three other gradients from that.

The incoming gradient is that of the loss L with respect to f, or dL/df. This is a vector of m numbers because this layer has m neurons and therefore m outputs.

The three gradients computed are:

  1. dL/dx – the loss w.r.t. the inputs of the layer
  2. dL/dw – the loss w.r.t. the weights of the layer
  3. dL/db – the loss w.r.t. the biases of the layer

Since there are n inputs, dL/dx is really a vector of n elements. dL/dw is a matrix of n * m elements (the same shape as w), and dL/db is a vector of m elements.

To understand how to compute these gradients, it is useful to look at a single element at a time. For example, let’s look at dL/dx[1]. In the forward pass, the only terms that have x[1] in them are:

w[1,1]*x[1]   in f[1]
w[1,2]*x[1]   in f[2]
...
w[1,m]*x[1]   in f[m]

Since none of the other terms in the formula for f have x[1] in them, those terms become 0 when we compute the partial derivative w.r.t. x[1] (they are treated as constants and the derivative of a constant is 0).

Thanks to the chain rule, we know that dL/dx[1] = dL/df * df/dx[1], where the multiplication is a dot product that sums over the m elements of f. We already know dL/df, which is a vector of m numbers (this is the incoming gradient).
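Spelled out with the indices used above, the chain rule sums over the m outputs, because x[1] reaches the loss through every one of them. Substituting df[j]/dx[1] = w[1,j], which is worked out next, turns it into a weighted sum of the incoming gradient:

```latex
\frac{\partial L}{\partial x_1}
  = \sum_{j=1}^{m} \frac{\partial L}{\partial f_j}\,\frac{\partial f_j}{\partial x_1}
  = \sum_{j=1}^{m} \frac{\partial L}{\partial f_j}\, w_{1,j}
```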

So now we have to compute df/dx[1]. This describes by how much the value of f changes if x[1] becomes larger. Since f is a vector, let’s look at how x[1] affects each element of f:

df[1]/dx[1] = w[1,1]
df[2]/dx[1] = w[1,2]
...
df[m]/dx[1] = w[1,m]

I hope this makes sense, because f[1] = w[1,1]*x[1] + ..., where the ... stands for other terms that are treated as constants when we take the derivative w.r.t. x[1].

In other words, if x[1] is incremented by 1, then the output of f[1] changes by the coefficient of x[1], which is w[1,1]. And so the derivative of f[1] w.r.t. x[1] is w[1,1]. Likewise for the other elements of f.
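The "increment x[1] and watch f change" argument can be checked numerically. This sketch uses 0-based indices (so x[0] here is x[1] above) and stores w as n rows of m entries. The helper names `linear_forward` and `output_sensitivity` are hypothetical. It nudges one input by a small h and measures how much each output moves per unit of h. The result should match row i of w.

```python
def linear_forward(x, w, b):
    # f[j] = sum_i w[i][j] * x[i] + b[j]; w is stored as n rows of m entries
    return [sum(w[i][j] * x[i] for i in range(len(x))) + b[j] for j in range(len(b))]


def output_sensitivity(x, w, b, i, h=1e-6):
    """Estimate df[j]/dx[i] for every output j with a forward difference."""
    f0 = linear_forward(x, w, b)
    x_nudged = list(x)
    x_nudged[i] += h
    f1 = linear_forward(x_nudged, w, b)
    return [(f1[j] - f0[j]) / h for j in range(len(b))]


x = [1.0, 2.0, 3.0]
w = [[1.0, 4.0],
     [2.0, 5.0],
     [3.0, 6.0]]
b = [10.0, 20.0]
print(output_sensitivity(x, w, b, i=0))  # approximately w[0], i.e. [1.0, 4.0]
```

Because f is linear in x, the forward difference is exact apart from floating-point rounding.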

And also for df/dx[2] and the other elements of x:

df[1]/dx[2] = w[2,1]
df[2]/dx[2] = w[2,2]
...
df[m]/dx[2] = w[2,m]

And so on. You can see the pattern here: df[j]/dx[i] = w[i,j].

Now we know the derivative of every output element of f w.r.t. every input element of x. We can put these into an m * n matrix, with one row per output f[j] and one column per input x[i]. And if you hadn’t guessed yet, that is exactly the weight matrix again, only transposed: the entry in row j, column i is w[i,j].

So df/dx is really the (transposed) weight matrix w. That makes sense, because the influence that a given input x[i] has on the output f is described by the weights between the inputs and the neurons of this layer.
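Here is a quick numerical check of this claim. The sketch uses 0-based indices and stores w as n rows of m entries. The helper names are made up for the example. It estimates every df[j]/dx[i] by nudging one input at a time, then compares the resulting m * n matrix with w transposed.

```python
def linear_forward(x, w, b):
    # f[j] = sum_i w[i][j] * x[i] + b[j]; w is stored as n rows of m entries
    return [sum(w[i][j] * x[i] for i in range(len(x))) + b[j] for j in range(len(b))]


def numerical_jacobian(x, w, b, h=1e-6):
    """Return the m x n matrix J with J[j][i] approximately df[j]/dx[i]."""
    n, m = len(x), len(b)
    f0 = linear_forward(x, w, b)
    jac = [[0.0] * n for _ in range(m)]
    for i in range(n):
        x_nudged = list(x)
        x_nudged[i] += h
        f1 = linear_forward(x_nudged, w, b)
        for j in range(m):
            jac[j][i] = (f1[j] - f0[j]) / h
    return jac


def transpose(a):
    # Swap rows and columns of a list-of-lists matrix
    return [list(col) for col in zip(*a)]


x = [1.0, 2.0, 3.0]
w = [[1.0, 4.0],
     [2.0, 5.0],
     [3.0, 6.0]]
b = [10.0, 20.0]
jac = numerical_jacobian(x, w, b)
w_t = transpose(w)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(all(abs(jac[j][i] - w_t[j][i]) < 1e-4 for j in range(2) for i in range(3)))  # True
```

The matrix does not depend on the particular x chosen, because the layer is linear in its inputs.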

Finally, to get dL/dx, we multiply dL/df (which, if you recall, is the incoming gradient vector of m values) by df/dx, which is really just the weight matrix (transposed). Multiplying a vector of size m by a matrix of size m * n gives a new vector of size n, which makes sense because x has n elements, and so dL/dx should also have n elements.

Phew, that was a long explanation of this single line in the above code:

inp.g = out.g @ w.t()

Here, inp.g is dL/dx, out.g is dL/df, and w.t() is df/dx, the transposed weight matrix. (It was explained in lesson 9 why you’re taking the transpose here. In short, this code stores w as an n * m matrix, one row per input, so that the forward pass can multiply the inputs by w directly. df/dx arranged as an m * n matrix has w[i,j] in row j, column i, which is w.t().)
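To see that line work end to end, here is a plain-Python sketch with a batch dimension: one row of m gradient values per sample, the way out.g is laid out. The `transpose`, `matmul`, `forward` and `loss` helpers are written for this example and are not PyTorch functions. The sum-of-squares loss is an arbitrary choice so that dL/df is simply the output. w is assumed to be stored as n * m, as described above. A finite-difference loop at the end asserts that every entry of the result is correct.

```python
def transpose(a):
    # Swap rows and columns of a list-of-lists matrix
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    # (p x q) @ (q x r) -> (p x r)
    return [[sum(row[k] * b[k][c] for k in range(len(b))) for c in range(len(b[0]))]
            for row in a]


def forward(inp, w, b):
    # Batched forward pass: inp is batch x n, w is n x m, result is batch x m
    return [[v + b[j] for j, v in enumerate(row)] for row in matmul(inp, w)]


def loss(out):
    # Toy loss for the demo: 0.5 * sum of squares, so dL/dout = out
    return 0.5 * sum(v * v for row in out for v in row)


inp = [[1.0, 2.0, 3.0],
       [0.5, -1.0, 2.0]]
w = [[1.0, 4.0],
     [2.0, 5.0],
     [3.0, 6.0]]
b = [10.0, 20.0]

out = forward(inp, w, b)
out_g = out                             # dL/dout for the toy loss (batch x m)
inp_g = matmul(out_g, transpose(w))     # batch x n, mirrors inp.g = out.g @ w.t()
print(inp_g[0])  # -> [232.0, 308.0, 384.0]

# Finite-difference check: raises AssertionError if any entry of inp_g is off
h = 1e-6
for s in range(len(inp)):
    for i in range(len(inp[0])):
        nudged = [row[:] for row in inp]
        nudged[s][i] += h
        estimate = (loss(forward(nudged, w, b)) - loss(out)) / h
        assert abs(estimate - inp_g[s][i]) < 1e-3
```

Each sample's row of inp_g depends only on that sample's row of out_g. That is why a single matrix product handles the whole batch.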

So now that you know how to compute dL/dx, can you figure out the math to compute dL/dw and dL/db?

It’s the exact same method: for dL/dw, you figure out how much a change in each weight value w[i,j] contributes to f, and for dL/db, how much a change in each bias value b[j] contributes to f.
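If you want to check your answer afterwards, here is one way to write it down using the same argument. In f[j], the coefficient of w[i,j] is x[i], so dL/dw[i,j] = x[i] * dL/df[j]. The coefficient of b[j] is 1, so dL/db[j] = dL/df[j]. With a batch, the per-sample contributions add up. This plain-Python sketch uses hypothetical helper names, 0-based indices and a toy sum-of-squares loss (so dL/df is f). It checks one weight gradient with a finite difference.

```python
def linear_forward(x, w, b):
    # f[j] = sum_i w[i][j] * x[i] + b[j]; w is stored as n rows of m entries
    return [sum(w[i][j] * x[i] for i in range(len(x))) + b[j] for j in range(len(b))]


def loss(f):
    # Toy loss for the demo: L = 0.5 * sum of squares, so dL/df[j] = f[j]
    return 0.5 * sum(v * v for v in f)


def grad_weights_and_bias(xs, out_gs):
    """Sum per-sample gradients over a batch.

    xs:     list of input vectors (batch x n)
    out_gs: list of dL/df vectors (batch x m)
    """
    n, m = len(xs[0]), len(out_gs[0])
    w_g = [[sum(x[i] * g[j] for x, g in zip(xs, out_gs)) for j in range(m)] for i in range(n)]
    b_g = [sum(g[j] for g in out_gs) for j in range(m)]
    return w_g, b_g


x = [1.0, 2.0, 3.0]
w = [[1.0, 4.0],
     [2.0, 5.0],
     [3.0, 6.0]]
b = [10.0, 20.0]

f = linear_forward(x, w, b)
w_g, b_g = grad_weights_and_bias([x], [f])  # batch of one; dL/df = f for the toy loss
print(w_g)  # -> [[24.0, 52.0], [48.0, 104.0], [72.0, 156.0]]
print(b_g)  # -> [24.0, 52.0]

# Finite-difference check of one weight gradient, dL/dw[1][0]
h = 1e-6
w_nudged = [row[:] for row in w]
w_nudged[1][0] += h
print((loss(linear_forward(x, w_nudged, b)) - loss(f)) / h)  # close to 48.0
```

In matrix form, dL/dw is the outer product of x and dL/df summed over the batch, and dL/db is dL/df summed over the batch.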
