Lesson 5: Implementing Momentum in the update function

Hi, I am trying to implement momentum as explained in lesson 5. However, my update function does not seem to perform better than regular SGD.

Does anyone have an idea what is going wrong?

def update2(x, y, lr, grads):
    wd = 1e-5
    y_hat = model(x)
    # weight decay: add the sum of squared weights to the loss
    w2 = 0.
    for p in model.parameters(): w2 += (p**2).sum()
    loss = loss_func(y_hat, y) + w2*wd
    loss.backward()

    with torch.no_grad():
        for idx, p in enumerate(model.parameters()):
            # exponentially weighted moving average of the gradients
            gradient = grads[idx]*0.9 + p.grad*0.1
            p.sub_(lr * gradient)
            grads[idx] = gradient.clone()
            p.grad.zero_()

    return loss.item(), grads
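
One thing I wonder about: because the update uses an exponentially weighted average (grads[idx]*0.9 + p.grad*0.1), its long-run step is about the same as plain SGD's lr * p.grad, just smoothed, so maybe that is why it does not look any better. Classic momentum instead accumulates the full gradient, which can amplify consistent gradients by up to 1/(1-0.9) = 10x. Here is a sketch of that version for comparison (update_momentum is just a name I made up, weight decay is left out for brevity, and it assumes the same model and loss_func as above):

import torch

def update_momentum(x, y, lr, grads, beta=0.9):
    loss = loss_func(model(x), y)  # weight decay omitted in this sketch
    loss.backward()
    with torch.no_grad():
        for idx, p in enumerate(model.parameters()):
            # classic momentum: v = beta*v + grad (full gradient, not an average)
            v = grads[idx]*beta + p.grad
            p.sub_(lr * v)
            grads[idx] = v.clone()
            p.grad.zero_()
    return loss.item(), grads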

Here is my training loop:

count = 0
grads = [0, 0]  # relies on the model having exactly two parameter tensors
losses = [None]*number_of_batches
for x, y in data.train_dl:
    losses[count], grads = update2(x, y, lr, grads)
    count += 1

model.apply(weight_reset)
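
Also, grads=[0,0] only works because this model happens to have exactly two parameter tensors (one weight, one bias). A model-independent initialization might look like this sketch, assuming model is already built:

import torch

# one zero buffer per parameter tensor, matching each parameter's shape and device
grads = [torch.zeros_like(p) for p in model.parameters()]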