Multiplication of gradient

In the gradient descent intro we have:

```python
def upd():
    global a_guess, b_guess
    # make a prediction using the current weights
    y_pred = lin(a_guess, b_guess, x)
    # calculate the derivative of the loss
    dydb = 2 * (y_pred - y)
    dyda = x*dydb
    # update our weights by moving in direction of steepest descent
    a_guess -= lr*dyda.mean()
    b_guess -= lr*dydb.mean()
```

  1. But I don’t fully understand dyda: why does the partial derivative with respect to a multiply x by dydb? We aren’t calculating the full derivative, and I don’t think this is a case of the chain rule.

  2. In the case of more than 2 dimensions, would we still be using .mean() to find the direction of steepest descent?
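For anyone who wants to run the snippet above, here is a minimal self-contained sketch. The lesson's data is not shown in this thread, so the 30 points on y = 3x + 8 (plus a little noise), the starting weights and the learning rate are all assumptions; lin(a, b, x) is assumed to be a*x + b, which matches how it is described later in the thread. The sketch returns the new weights instead of mutating globals, but the gradient lines are the same as in upd().

```python
import numpy as np

# Hypothetical data: 30 points on y = 3x + 8 plus small Gaussian noise
rng = np.random.default_rng(42)
n = 30
x = rng.uniform(-1, 1, n)
y = 3 * x + 8 + rng.normal(0, 0.1, n)

def lin(a, b, x):
    # Assumed linear model; the scalar b is broadcast over every a*x value
    return a * x + b

def upd(a_guess, b_guess, lr):
    # Same update as upd() above, but returns the weights instead of using globals
    y_pred = lin(a_guess, b_guess, x)
    # Per-sample derivatives of the squared error (y_pred - y)**2
    dydb = 2 * (y_pred - y)
    dyda = x * dydb
    # Step against the average gradient, i.e. the gradient of the mean squared error
    a_guess -= lr * dyda.mean()
    b_guess -= lr * dydb.mean()
    return a_guess, b_guess

a_guess, b_guess = -1.0, 1.0  # arbitrary starting weights
for _ in range(1000):
    a_guess, b_guess = upd(a_guess, b_guess, lr=0.1)

print(a_guess, b_guess)  # should end up close to 3 and 8
```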

For the first question: the partial derivative of the squared difference between the label and the model output with respect to b is 2*(y_pred-y), or equivalently 2*(b + a x - y). If you take the partial derivative with respect to a instead, you get 2*x*(b + a x - y), which is indeed x * dy/db.
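One way to convince yourself of this is a numerical check: compare the analytic gradients, computed exactly as in upd(), with central finite differences of the mean squared error. The data and the current weights below are made-up values chosen only for the check; the comments mark the inner derivative that each line relies on.

```python
import numpy as np

# Hypothetical data and current weights; any values work for this check
x = np.linspace(0, 1, 30)
y = 3 * x + 8
a, b = 0.5, -2.0

def mse(a, b):
    # Mean squared error of the linear model a*x + b
    return ((a * x + b - y) ** 2).mean()

# Analytic gradients, written the same way as in upd()
y_pred = a * x + b
dydb = 2 * (y_pred - y)   # inner derivative d(a*x + b)/db is 1
dyda = x * dydb           # inner derivative d(a*x + b)/da is x
grad_a, grad_b = dyda.mean(), dydb.mean()

# Central finite differences as an independent check
eps = 1e-6
num_a = (mse(a + eps, b) - mse(a - eps, b)) / (2 * eps)
num_b = (mse(a, b + eps) - mse(a, b - eps)) / (2 * eps)

print(np.isclose(grad_a, num_a), np.isclose(grad_b, num_b))  # True True
```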

  1. When you take the derivative of your cost function with respect to b, you have $\frac{\partial{L}}{\partial{b}} = 2 \times (y_{pred} - y)$. With respect to a, you have $\frac{\partial{L}}{\partial{a}} = 2 \times x \times (y_{pred} - y)$.
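Written out for a single sample, assuming the per-sample loss is the squared error and lin(a, b, x) = a x + b, both derivatives come from the chain rule; they differ only in the inner factor $\partial y_{pred} / \partial a = x$ versus $\partial y_{pred} / \partial b = 1$:

$$
\begin{aligned}
L &= (y_{pred} - y)^2, \qquad y_{pred} = a x + b \\
\frac{\partial L}{\partial b} &= 2 (y_{pred} - y) \cdot \frac{\partial y_{pred}}{\partial b} = 2 (y_{pred} - y) \cdot 1 \\
\frac{\partial L}{\partial a} &= 2 (y_{pred} - y) \cdot \frac{\partial y_{pred}}{\partial a} = 2 (y_{pred} - y) \cdot x = x \cdot \frac{\partial L}{\partial b}
\end{aligned}
$$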

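  2. Usually, I have seen dydb and dyda computed with a dot product, something like $\text{dyda} = \frac{2}{n_{samples}}\, x \cdot (y_{pred} - y)$ and $\text{dydb} = \frac{2}{n_{samples}}\, \mathbf{1}_{n_{samples}} \cdot (y_{pred} - y)$, where $\mathbf{1}_{n_{samples}}$ is a vector of n_samples ones.
    The shapes of the gradients then follow from the shapes of your parameters. If you compute them with this dot product, you can generalize to more parameters without using .mean().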
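A sketch of that equivalence in NumPy, under made-up data and weights (the exact code in the reply above did not survive, so this is one way to write it, not necessarily the poster's). The 2/n_samples factor plays the role of .mean(), so both forms give the same average gradient. For more parameters, stacking every input feature plus a column of ones (for the bias) into a design matrix X gives all gradients in one matrix-vector product.

```python
import numpy as np

# Hypothetical data and current weights
rng = np.random.default_rng(1)
n_samples = 30
x = rng.uniform(-1, 1, n_samples)
y = 3 * x + 8 + rng.normal(0, 0.1, n_samples)
a, b = 0.5, -2.0

y_pred = a * x + b
err = y_pred - y

# Per-sample derivatives averaged with .mean(), as in upd()
dyda_mean = (x * 2 * err).mean()
dydb_mean = (2 * err).mean()

# The same gradients as dot products scaled by 2/n_samples
dyda_dot = (2 / n_samples) * np.dot(x, err)
dydb_dot = (2 / n_samples) * np.dot(np.ones(n_samples), err)

# Generalization: design matrix X of shape (n_samples, n_params); the column
# of ones carries the bias. One product gives the gradient for every parameter.
X = np.column_stack([x, np.ones(n_samples)])   # shape (30, 2)
w = np.array([a, b])                           # weights [a, b]
grad = (2 / n_samples) * X.T @ (X @ w - y)     # shape (2,)

print(np.allclose([dyda_mean, dydb_mean], [dyda_dot, dydb_dot]))  # True
print(np.allclose(grad, [dyda_dot, dydb_dot]))                    # True
```

Adding a third feature only means adding a column to X and an entry to w; the gradient line stays the same.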


Re 2: How is the number of samples relevant here? Why does it come into the picture?

The 1/n_samples factor is there to replicate the results of the lesson: dividing the summed gradient by n_samples is exactly what .mean() does (so I assume the loss function is Mean Squared Error rather than Sum of Squared Errors). The n_samples-sized vector in dydb gives the bias gradient the correct shape: if you look closely at the lin() function, even though b is a single number, it is added to all 30 a*x values. This is called broadcasting.
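A small illustration of the broadcasting point, with hypothetical values for x, a and b. Because the scalar b is added to every sample, its gradient collects one contribution per sample; a dot product with a ones vector is just that sum, and dividing by n_samples turns it into the mean used in the lesson.

```python
import numpy as np

n_samples = 30
x = np.linspace(-1, 1, n_samples)
a, b = 2.0, 5          # b is a plain Python int

# Broadcasting: the scalar b is added to each of the 30 values of a*x
y_pred = a * x + b
print(y_pred.shape)    # (30,)

# Adding b explicitly as n_samples copies gives the same result
print(np.array_equal(y_pred, a * x + b * np.ones(n_samples)))  # True

# Stand-in per-sample gradients for b (hypothetical values)
upstream = np.arange(n_samples, dtype=float)
# Dot product with ones == sum over samples; dividing by n_samples == mean
print(np.dot(np.ones(n_samples), upstream) == upstream.sum())                 # True
print(np.isclose(np.dot(np.ones(n_samples), upstream) / n_samples, upstream.mean()))  # True
```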


