Multiplication of gradient

(Edgar Aroutiounian) #1

In the gradient descent intro we have:

def upd():
    global a_guess, b_guess
    # make a prediction using the current weights
    y_pred = lin(a_guess, b_guess, x)
    # calculate the derivative of the loss
    dydb = 2 * (y_pred - y)
    dyda = x*dydb
    # update our weights by moving in direction of steepest descent
    a_guess -= lr*dyda.mean()
    b_guess -= lr*dydb.mean()

  1. But I don’t fully understand dyda: why does the partial derivative with respect to a multiply x by dydb? We aren’t calculating the full derivative, and I don’t think this is a case of the chain rule.

  2. In the case of more than 2 dimensions, would we still use .mean() to find the direction of steepest descent?
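The snippet above relies on lin, x, y and lr from the lesson, which aren’t shown. Here is a self-contained sketch to experiment with. It assumes lin(a, b, x) = a*x + b and uses made-up data (30 noisy points near y = 3x + 2), an arbitrary lr = 0.1 and arbitrary starting guesses; none of these values come from the lesson itself.

```python
import numpy as np

# Hypothetical data (not from the lesson): 30 noisy points near y = 3x + 2
rng = np.random.default_rng(0)
n_samples = 30
x = rng.uniform(-1, 1, n_samples)
y = 3.0 * x + 2.0 + rng.normal(0, 0.1, n_samples)

def lin(a, b, x):
    # Assumed model: a line with slope a and intercept b
    return a * x + b

# Arbitrary starting guesses and learning rate
a_guess, b_guess = -1.0, 1.0
lr = 0.1

def upd():
    global a_guess, b_guess
    # make a prediction using the current weights
    y_pred = lin(a_guess, b_guess, x)
    # per-sample derivatives of the squared error (y_pred - y)**2
    dydb = 2 * (y_pred - y)
    dyda = x * dydb
    # step against the averaged gradient (the gradient of the mean squared error)
    a_guess -= lr * dyda.mean()
    b_guess -= lr * dydb.mean()

for _ in range(1000):
    upd()

print(a_guess, b_guess)  # should end up close to 3 and 2
```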

(Lorenzo Fabbri) #2

For the first question: the loss for each sample is the squared difference between the model’s output and the label, (y_pred - y)^2, and its partial derivative with respect to b is 2*(y_pred - y), or equivalently 2*(b + a*x - y). If you take the partial derivative with respect to a instead, you get 2*x*(b + a*x - y), which is indeed x * dy/db.
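The same result falls out of the chain rule applied through the prediction. A short derivation, assuming lin(a, b, x) = a*x + b and a per-sample squared-error loss L = (y_pred - y)^2, so that the code’s dydb and dyda are really ∂L/∂b and ∂L/∂a for each sample:

```latex
\begin{aligned}
L &= (y_{pred} - y)^2, \qquad y_{pred} = a x + b \\
\frac{\partial L}{\partial b} &= \frac{\partial L}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial b} = 2 (y_{pred} - y) \cdot 1 \\
\frac{\partial L}{\partial a} &= \frac{\partial L}{\partial y_{pred}} \cdot \frac{\partial y_{pred}}{\partial a} = 2 (y_{pred} - y) \cdot x = x \, \frac{\partial L}{\partial b}
\end{aligned}
```

So dyda = x*dydb is the chain rule: the factor x is ∂y_pred/∂a, while ∂y_pred/∂b is just 1. The .mean() calls in upd() then average these per-sample derivatives, which gives the gradient of the mean squared error.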

(Nathan Hubens) #3
  1. When you take the derivative of your cost function with respect to b, you have \frac{\partial{L}}{\partial{b}} = 2 \times (y_{pred} - y). With respect to a, you have \frac{\partial{L}}{\partial{a}} = 2 \times x \times (y_{pred} - y).

  2. I’ve usually seen dydb and dyda computed with a dot product, with something like dyda = (2/n_samples)*np.dot(x.T, (y_pred - y)) and dydb = (2/n_samples)*np.dot(np.ones((1,n_samples)), (y_pred - y))
    That way the shapes of the gradients follow from the shapes of your parameters. With this dot product you can generalize to more parameters without using mean().
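A minimal NumPy sketch comparing the two forms, assuming the dot-product expressions above and made-up data. It checks that (2/n_samples)*np.dot(x.T, r) equals the mean of the per-sample gradients 2*x*r, then generalizes to k weights with a feature matrix X of shape (n_samples, k), where X.T @ r yields one gradient entry per weight. The names r, X, w and bias are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples = 30

# --- The two-parameter line (hypothetical data) ---
x = rng.uniform(-1, 1, n_samples)
y = 3.0 * x + 2.0
a, b = 0.5, -0.5
r = (a * x + b) - y  # residuals y_pred - y, shape (30,)

# Mean of per-sample gradients, as in upd()
dyda_mean = (2 * x * r).mean()
dydb_mean = (2 * r).mean()

# Dot-product form: the sum over samples happens inside np.dot
dyda_dot = (2 / n_samples) * np.dot(x.T, r)                      # scalar
dydb_dot = (2 / n_samples) * np.dot(np.ones((1, n_samples)), r)  # shape (1,)

# --- Generalization: k weights plus one bias ---
k = 4
X = rng.normal(size=(n_samples, k))  # one row per sample, one column per feature
w = rng.normal(size=k)
bias = 0.1
y_multi = X @ np.arange(1.0, k + 1) + 2.0
r_multi = (X @ w + bias) - y_multi   # shape (30,)

grad_w = (2 / n_samples) * (X.T @ r_multi)  # shape (k,): one entry per weight
grad_bias = (2 / n_samples) * r_multi.sum()

# Same result via per-sample gradients and .mean() along the sample axis
grad_w_mean = (2 * X * r_multi[:, None]).mean(axis=0)

print(np.isclose(dyda_mean, dyda_dot), np.allclose(dydb_mean, dydb_dot))
print(grad_w.shape, np.allclose(grad_w, grad_w_mean))
```

The dot product folds the sum over samples into the matrix product, so (2/n_samples) * (X.T @ r) and taking .mean() of the per-sample gradients are the same computation; the matrix form just scales to any number of weights without changing the code.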

(Edgar Aroutiounian) #4

Re 2: How is the number of samples relevant here at all? Why does it come into the picture?

(Nathan Hubens) #5

The 1/n_samples is there to replicate the results of the lesson (I assume the loss function there is Mean Squared Error rather than Sum of Squared Errors). And the n_samples-sized vector in dydb is there to get the correct shape for the bias: if you look closely at the lin() function, even though b is a single number, it is added to all 30 a*x values. This is called broadcasting.
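To see both points concretely, here is a small sketch, assuming lin(a, b, x) = a*x + b and 30 made-up points. The scalar b is broadcast across all 30 values of a*x, and central finite differences confirm that the 1/n_samples factor belongs to the gradient of the mean squared error, while the sum of squared errors has no such factor. The helpers mse and sse are hypothetical names for illustration.

```python
import numpy as np

def lin(a, b, x):
    # Assumed form of the lesson's lin(): slope a, intercept b
    return a * x + b

# Hypothetical data: 30 points on the line y = 3x + 2
n_samples = 30
x = np.linspace(-1, 1, n_samples)
y = 3.0 * x + 2.0
a, b = 0.5, -0.5  # b is a single Python float, not an array

# Broadcasting: the scalar b is added to every one of the 30 a*x values
y_pred = lin(a, b, x)
r = y_pred - y  # residuals, shape (30,)

def mse(a, b):
    # Mean squared error: divides by n_samples
    return ((lin(a, b, x) - y) ** 2).mean()

def sse(a, b):
    # Sum of squared errors: no division
    return ((lin(a, b, x) - y) ** 2).sum()

# Analytic gradients
dmse_db = (2 / n_samples) * r.sum()       # the 1/n_samples comes from the mean
dsse_db = 2 * r.sum()                     # no 1/n_samples factor for the sum
dmse_da = (2 / n_samples) * np.dot(x, r)  # same as (x * 2 * r).mean()

# Central finite differences as a numerical check
eps = 1e-6
fd_mse_b = (mse(a, b + eps) - mse(a, b - eps)) / (2 * eps)
fd_sse_b = (sse(a, b + eps) - sse(a, b - eps)) / (2 * eps)
fd_mse_a = (mse(a + eps, b) - mse(a - eps, b)) / (2 * eps)

print(y_pred.shape)
print(np.isclose(dmse_db, fd_mse_b), np.isclose(dsse_db, fd_sse_b), np.isclose(dmse_da, fd_mse_a))
```

Because b was broadcast to every sample, its gradient collects a contribution from every sample; the ones vector (or r.sum()) adds those contributions back into a single number, and dividing by n_samples turns that sum into the MSE gradient.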