Implementing momentum in lesson2-sgd notebook

As suggested in lesson 5, I'm trying to add momentum to the update function in lesson2-sgd.
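For completeness, here is the setup from earlier in the notebook that my code assumes (reproduced from memory, so the exact values may be slightly off):

from fastai.basics import *   # the notebook's import; brings in torch, nn, tensor, plt

n = 100
x = torch.ones(n, 2)                     # inputs in the first column, ones in the second for the bias
x[:,0].uniform_(-1., 1)                  # random inputs in [-1, 1]
y = x @ tensor(3., 2.) + torch.rand(n)   # noisy linear data to fit
def mse(y_hat, y): return ((y_hat - y)**2).mean()

This is my code: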

a = nn.Parameter(tensor(-1., 1))  # initial guess for the parameters
prev_grad = 0                     # running momentum term, starts at zero

def update(mom=0.9):
  y_hat = x@a
  loss = mse(y, y_hat)
  if t % 10 == 0: print(loss)     # t is the global loop counter, as in the notebook
  loss.backward()
  with torch.no_grad():
    global prev_grad
    # note that 1e-1 multiplies only the gradient term here, not the momentum term
    a.sub_(1e-1 * (1-mom)*a.grad + mom*prev_grad)
    prev_grad = (1-mom)*a.grad + mom*prev_grad
    a.grad.zero_()

for t in range(100): update(mom=0.1)
plt.scatter(x[:,0], y)
plt.scatter(x[:,0], (x@a).detach());

Is this implementation correct?
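For reference, the baseline I'm comparing against in my second question is the notebook's original update without momentum, which as I recall looks like this:

lr = 1e-1
def update():
  y_hat = x@a
  loss = mse(y, y_hat)
  if t % 10 == 0: print(loss)
  loss.backward()
  with torch.no_grad():
    a.sub_(lr * a.grad)
    a.grad.zero_()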

My second question: when I refactor the momentum calculation like this:

  with torch.no_grad():
    global prev_grad
    p = (1-mom)*a.grad + mom*prev_grad  # combined momentum step
    a.sub_(lr * p)                      # lr (1e-1, as above) now scales the whole step
    prev_grad = p
    a.grad.zero_()

then training converges about as slowly as it does without momentum at all. Why is that?
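In case it helps pin down what I'm asking: if I've traced the operator precedence correctly, the two versions take different-sized steps, because in the first one 1e-1 multiplies only the gradient term while the momentum term is added unscaled. A toy comparison with made-up numbers (the grad and prev values here are hypothetical, just to show the scaling):

mom, lr = 0.1, 1e-1
grad, prev = 1.0, 1.0                    # made-up values for illustration
step_v1 = lr * (1-mom)*grad + mom*prev   # first version: 0.09 + 0.1 = 0.19
step_v2 = lr * ((1-mom)*grad + mom*prev) # refactored version: 0.1 * 1.0 = 0.10
print(step_v1, step_v2)

Is that scaling difference the whole story, or am I missing something else?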