The forward pass of a fully-connected layer computes the following:
f[1] = w[1,1]*x[1] + w[2,1]*x[2] + ... + w[n,1]*x[n] + b[1]
f[2] = w[1,2]*x[1] + w[2,2]*x[2] + ... + w[n,2]*x[n] + b[2]
...
f[m] = w[1,m]*x[1] + w[2,m]*x[2] + ... + w[n,m]*x[n] + b[m]
where:
- m is the number of neurons in that layer,
- n is the number of inputs,
- x is a vector of inputs (n in total),
- w are the weights (a matrix of size n * m),
- b are the biases (m in total),
- and f is the output of the layer (also a vector of m elements).
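To make the shapes concrete, here is a minimal sketch of this forward pass in PyTorch (the sizes n and m are made up for illustration):

import torch

n, m = 4, 3                 # n inputs, m neurons (made-up sizes)
x = torch.randn(n)          # input vector
w = torch.randn(n, m)       # weight matrix of size n * m
b = torch.randn(m)          # bias vector

f = x @ w + b               # the same m sums that are written out above
print(f.shape)              # torch.Size([3]): one output per neuron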
(Note: I started counting at 1 here, not at 0. It doesn’t really matter.)
The backward pass takes an incoming gradient and computes three other gradients from that.
The incoming gradient is that of the loss L with respect to f, or dL/df. This is a vector of m numbers, because this layer has m neurons and therefore m outputs.
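As a quick illustration (with a made-up mean-squared-error loss), this is the kind of vector that autograd would hand to our layer:

import torch

m = 3
f = torch.randn(m, requires_grad=True)   # pretend these are the layer's m outputs
target = torch.randn(m)
loss = ((f - target) ** 2).mean()        # some made-up loss
loss.backward()
print(f.grad)                            # dL/df: a vector of m numbers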
The three gradients computed are:
- dL/dx – the loss w.r.t. the inputs of the layer
- dL/dw – the loss w.r.t. the weights of the layer
- dL/db – the loss w.r.t. the biases of the layer
Since there are n inputs, dL/dx is really a vector of n elements. dL/dw is a matrix of n * m elements, and dL/db is a vector of m elements.
To understand how to compute these vectors, it is useful to look at a single element at a time. For example, let's look at dL/dx[1]. In the forward pass, the only terms that have x[1] in them are:
w[1,1]*x[1] in f[1]
w[1,2]*x[1] in f[2]
...
w[1,m]*x[1] in f[m]
Since none of the other terms in the formula for f have x[1] in them, those terms become 0 when we compute the partial derivative w.r.t. x[1] (they are treated as constants, and the derivative of a constant is 0).
Thanks to the chain rule, we know that dL/dx[1] = dL/df * df/dx[1], where the product is a dot product: a sum over all m outputs. We already know dL/df, which is a vector of m numbers (this is the incoming gradient). So now we have to compute df/dx[1]. This describes by how much the value of f changes if x[1] becomes larger. Since f is a vector, let's look at how x[1] affects each element of f:
df[1]/dx[1] = w[1,1]
df[2]/dx[1] = w[1,2]
...
df[m]/dx[1] = w[1,m]
I hope this makes sense, because f[1] = w[1,1]*x[1] + ..., where ... is other stuff that is treated as constant when we take the derivative w.r.t. x[1].
In other words, if x[1] is incremented by 1, then the output of f[1] changes by the coefficient of x[1], which is w[1,1]. And so the derivative of f[1] w.r.t. x[1] is w[1,1]. Likewise for the other elements of f.
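You can check this numerically by nudging x[1] a little and watching how every element of f moves (a small sketch with made-up sizes; note that indexing is 0-based in code, so the text's x[1] is x[0] here):

import torch

n, m = 4, 3
x = torch.randn(n)
w = torch.randn(n, m)
b = torch.randn(m)

def forward(x):
    return x @ w + b

eps = 1e-3
x_bumped = x.clone()
x_bumped[0] += eps                       # nudge x[1] by a tiny amount

approx = (forward(x_bumped) - forward(x)) / eps
print(approx)                            # approximately w[1,1], w[1,2], ..., w[1,m]
print(w[0])                              # i.e. the first row of the weight matrix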
The same goes for df/dx[2] and the other elements of x:
df[1]/dx[2] = w[2,1]
df[2]/dx[2] = w[2,2]
...
df[m]/dx[2] = w[2,m]
And so on… you can see the pattern here.
Now we know the derivative of every output element of f w.r.t. every input element of x. We can put these into an m * n matrix. And if you hadn't guessed it yet, that is exactly the weight matrix again (just transposed, since we defined w as n * m).
So df/dx is really just the weight matrix w (up to that transpose). That makes sense, because the influence that a given input x[i] has on the output f is described by the weights between the inputs and the hidden neurons.
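If you don't want to take my word for it, you can ask autograd for the full Jacobian and compare it to the weight matrix (again a sketch with made-up sizes; it comes out as w.t() here because w is stored as n * m, which is the same transpose business discussed below):

import torch

n, m = 4, 3
x = torch.randn(n)
w = torch.randn(n, m)
b = torch.randn(m)

# the m * n matrix of all the df[j]/dx[i] values
jac = torch.autograd.functional.jacobian(lambda x: x @ w + b, x)
print(torch.allclose(jac, w.t()))        # True: df/dx is the weight matrix (transposed)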
Finally, to get dL/dx, we multiply dL/df – which, if you recall, is the incoming gradient vector of m values – by df/dx, which is really just the weight matrix. Multiplying a vector of size m with a matrix of size m * n gives a new vector of size n. Which makes sense, because x has n elements, and so dL/dx should also have n elements.
Phew, that was a long explanation of this single line in the above code:
inp.g = out.g @ w.t()
Here, inp.g is dL/dx, out.g is dL/df, and w.t() is the weight matrix or df/dx. (It was explained in lesson 9 why you're taking the transpose here. In short, it's because the weights are not stored as an m * n matrix in PyTorch but as n * m, for historical reasons.)
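A quick sanity check (a sketch with a made-up loss): compute dL/dx this way by hand and compare it with what autograd puts into x.grad:

import torch

n, m = 4, 3
x = torch.randn(n, requires_grad=True)
w = torch.randn(n, m)
b = torch.randn(m)

f = x @ w + b
f.sum().backward()                       # made-up loss, so dL/df is a vector of ones

dL_df = torch.ones(m)                    # the incoming gradient
dL_dx = dL_df @ w.t()                    # the manual backward step from the text
print(torch.allclose(dL_dx, x.grad))     # True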
So now that you know how to compute dL/dx, can you figure out the math to compute dL/dw and dL/db?
It's the exact same method: for dL/dw you figure out how much a change in each weight value w[i,j] contributes to f. And for dL/db you figure out how much a change in each bias value b[j] contributes to f.