As suggested in lesson 5, I'm trying to add momentum to the update function in the lesson2-sgd notebook. This is my code:
```python
a = nn.Parameter(tensor(-1.,1));
prev_grad = 0

def update(mom=0.9):
    y_hat = x@a
    loss = mse(y, y_hat)
    if t % 10 == 0: print(loss)
    loss.backward()
    with torch.no_grad():
        global prev_grad
        a.sub_(1e-1 * (1-mom)*a.grad + mom*prev_grad)
        prev_grad = (1-mom)*a.grad + mom*prev_grad
        a.grad.zero_()

for t in range(100): update(mom=0.1)

plt.scatter(x[:,0], y)
plt.scatter(x[:,0], (x@a).detach());
```
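For reference, the update I think I should be computing each step (if I've understood lesson 5 correctly) is the exponentially weighted average of the gradients:

$$v_t = (1-\text{mom})\, g_t + \text{mom}\, v_{t-1}, \qquad a \leftarrow a - \text{lr}\, v_t$$

where $g_t$ is `a.grad` and $v$ is what I store in `prev_grad`.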
Is this implementation correct?
My second question: when I try to refactor the momentum calculation like this (same `mom=0.1`, with `lr` set to the same `1e-1` as before):
```python
with torch.no_grad():
    global prev_grad
    p = (1-mom)*a.grad + mom*prev_grad
    a.sub_(lr * p)
    prev_grad = p
    a.grad.zero_()
```
then training converges as slowly as plain SGD without any momentum at all. Why is that?
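In case it helps anyone reproduce this, here is a minimal self-contained sketch of how I'm comparing the two variants. The synthetic data setup mimics the lesson2-sgd notebook (I use plain `torch.tensor` instead of fastai's `tensor` helper so it runs with just PyTorch), and `run` / `scale_momentum_by_lr` are just names I made up for this post:

```python
import torch
from torch import nn

# Re-create the lesson's synthetic data so the comparison is self-contained
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1.)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y, y_hat): return ((y - y_hat) ** 2).mean()

lr, mom = 1e-1, 0.1

def run(scale_momentum_by_lr):
    """100 SGD steps; False = my original update, True = the refactored one."""
    a = nn.Parameter(torch.tensor([-1., 1.]))
    prev_grad = torch.zeros(2)
    for t in range(100):
        loss = mse(y, x @ a)
        loss.backward()
        with torch.no_grad():
            p = (1 - mom) * a.grad + mom * prev_grad
            if scale_momentum_by_lr:
                a.sub_(lr * p)                                     # refactored version
            else:
                a.sub_(lr * (1 - mom) * a.grad + mom * prev_grad)  # original version
            prev_grad = p.clone()
            a.grad.zero_()
    return loss.item()

print('original  :', run(False))
print('refactored:', run(True))
```

With this, the original variant drops the loss noticeably faster, while the refactored one behaves just like plain SGD with `lr = 1e-1`.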