Mnist Basics: help me understand .backward()

i have been trying all sorts of stuff to understand how this function works.

i think if i can understand why PyTorch generates different gradients depending on whether .sum() or .mean() is called, the penny might drop

example 1:

x = tensor(2.,6.).requires_grad_()
y = (x**2).sum()
y.backward()

# y result: tensor(40., grad_fn=<SumBackward0>)
# x.grad result: tensor([ 4., 12.])

example 2

w = tensor(2., 6.).requires_grad_()
z = (w**2).mean()
z.backward()

# z result: tensor(20., grad_fn=<MeanBackward0>)
# w.grad result: tensor([2., 6.])

why do i need to use .sum() to get the correct gradients but in the mnist chapter they use .mean() ?

also why does this happen:
example 3

g = tensor(2.,6.).requires_grad_()
f = (g).mean()
f.backward()
# g.grad result: tensor([0.5000, 0.5000])

k = tensor(2.,6.,8,9).requires_grad_()
l = (k).mean()
l.backward()
k.grad
# k.grad result: tensor([0.2500, 0.2500, 0.2500, 0.2500])

the calculated gradients decrease in proportion to the number of values in the tensor?

How exactly are these values being calculated?

Hi @gruffjaguar,
The difference is which function you are differentiating. In example 1 you compute the gradient of y with respect to x. y = sum_i(x_i ** 2), so the derivative with respect to each element is dy/dx_i = 2 * x_i, which gives dy/dx_1 = 2 * 2 = 4 and dy/dx_2 = 2 * 6 = 12.

In example 2 you compute the gradient of z with respect to w. z = mean(w_i ** 2) = 1 / 2 * sum_i(w_i ** 2), so dz/dw_i = 1 / 2 * 2 * w_i, which gives dz/dw_1 = 1 / 2 * 2 * 2 = 2 and dz/dw_2 = 1 / 2 * 2 * 6 = 6.
For example 3: l = 1 / 4 * sum_i(k_i), so dl/dk_i = 1 / 4 for every element, whatever its value.
Hope it helps !
Charles
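
If it helps to check these numbers without trusting autograd, here is a minimal sketch in plain Python that approximates each gradient with central finite differences. It does not use PyTorch at all; `numerical_grad` and the three function names are hypothetical helpers made up for this illustration, and the results are approximations, not exact values.

```python
# Plain-Python sketch: approximate gradients with central finite differences.
# numerical_grad is a hypothetical helper, not part of PyTorch or fastai.

def numerical_grad(f, xs, eps=1e-6):
    """Approximate df/dx_i for each element of the list xs."""
    grads = []
    for i in range(len(xs)):
        up = list(xs)
        up[i] += eps
        down = list(xs)
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads

def sum_of_squares(xs):       # mirrors example 1: (x**2).sum()
    return sum(x ** 2 for x in xs)

def mean_of_squares(xs):      # mirrors example 2: (w**2).mean()
    return sum(x ** 2 for x in xs) / len(xs)

def plain_mean(xs):           # mirrors example 3: g.mean()
    return sum(xs) / len(xs)

grad_sum = numerical_grad(sum_of_squares, [2.0, 6.0])         # close to [4, 12]
grad_mean = numerical_grad(mean_of_squares, [2.0, 6.0])       # close to [2, 6]
grad_plain2 = numerical_grad(plain_mean, [2.0, 6.0])          # close to [0.5, 0.5]
grad_plain4 = numerical_grad(plain_mean, [2.0, 6.0, 8.0, 9.0])  # close to 0.25 each

for name, g in [("sum", grad_sum), ("mean", grad_mean),
                ("plain mean, 2 values", grad_plain2),
                ("plain mean, 4 values", grad_plain4)]:
    print(name, [round(v, 4) for v in g])
```

Both the .sum() and the .mean() gradients are correct; they are gradients of two different functions, and the mean version is the sum version divided by the number of elements.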

Thanks for the reply Charles. Your explanation of example 1 makes sense. I understand .sum() but .mean() is still confusing to me.

Example 3 completely ignores the values in the tensor. You can pick whatever numbers you want and it is always the same? only changing the number of values in the tensor changes the outcome?

i am still confused on example 2 also. Is it 1/2 because that’s the size of the tensor? If i had a tensor with 3 values would it be 1/3 ?

let’s pretend i have

m = tensor(4., 7., 11.)
n = (m**5).mean()

my interpretation of what you are saying is

n'(4) = 1/3*(4*5)  = 6.66
n'(7) = 1/3*(7*5)  = 11.66
n'(11) = 1/3*(11*5) = 18.33

however i checked this on google colab and im not even close. please help this blind man see haha

i am still confused on example 2 also. Is it 1/2 because that’s the size of the tensor? If i had a tensor with 3 values would it be 1/3 ?

Yes. Because the mean of anything is calculated by the sum of all values over the number of values.

For this function (m**5).mean(), the 1/3 part is correct (because there are three values in the tensor), but the derivative of m**5 is 5 * m**4, not 5 * m. Make sure you understand how the power rule works, it’s simple :slight_smile:

Yes, to compute the mean: if X is a tensor, mean(X) = 1 / len(X) * sum(elements of X).
You are not even close because your derivatives are wrong: the derivative of X^N is N * X^(N-1), so n’(4) = 1 / 3 * 5 * 4^4.
Hope it helps.
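
To make the power rule concrete for the (m**5).mean() case, here is a small plain-Python sketch that compares the formula 1/3 * 5 * m**4 against a finite-difference estimate. The helper names (`mean_of_fifth_powers`, `analytic_grad`, `numerical_grad`) are made up for this example and it does not call PyTorch, so treat it as a sanity check under those assumptions, not as what autograd does internally.

```python
# Plain-Python check of the gradient of mean(m**5) for m = [4, 7, 11].
# All helper names here are hypothetical and exist only for this example.

def mean_of_fifth_powers(ms):
    return sum(m ** 5 for m in ms) / len(ms)

def analytic_grad(ms):
    """Power rule plus the 1/len factor from the mean: 5 * m**4 / n."""
    n = len(ms)
    return [5 * m ** 4 / n for m in ms]

def numerical_grad(f, xs, eps=1e-4):
    """Central finite-difference approximation of df/dx_i."""
    grads = []
    for i in range(len(xs)):
        up = list(xs)
        up[i] += eps
        down = list(xs)
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads

m = [4.0, 7.0, 11.0]
expected = analytic_grad(m)                       # 1280/3, 12005/3, 73205/3
approx = numerical_grad(mean_of_fifth_powers, m)  # should agree closely

for value, a, b in zip(m, expected, approx):
    print(value, round(a, 2), round(b, 2))
```

The analytic values come out to roughly 426.67, 4001.67 and 24401.67, a long way from the 6.66, 11.66 and 18.33 obtained by multiplying by 5 without lowering the power.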

:grimacing:
my memory of high school maths is a little rustier than i thought it was. This makes so much more sense.

Example 3 from my original post also makes sense now. Each element is k**1, so its derivative is 1 * k**0, and k**0 is always 1, so the gradient for each element of the tensor is simply 1/len(tensor)
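
Written out as a short derivation (a sketch, with n standing for the number of elements in the tensor and k_i for its entries, as in example 3):

```latex
l = \frac{1}{n} \sum_{j=1}^{n} k_j
\qquad\Longrightarrow\qquad
\frac{\partial l}{\partial k_i}
  = \frac{1}{n} \cdot 1 \cdot k_i^{0}
  = \frac{1}{n}
```

For n = 2 this gives 0.5 per element and for n = 4 it gives 0.25, matching the outputs in example 3 regardless of the values in the tensor.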

Thank you for your patience @nn.Charles and @johannesstutz