Hi, I am trying to implement momentum as explained in lesson 5.
However, my update function does not seem to perform better than the regular SGD.
Does anyone have an idea of what is going wrong?
def update2(x, y, lr, grads, count):
    """Run one SGD-with-momentum training step on a single mini-batch.

    Forward pass, L2 weight decay added to the loss, backward pass, then
    each parameter is stepped by an exponential moving average of its
    gradients.

    Args:
        x, y: mini-batch of inputs and targets.
        lr: learning rate.
        grads: per-parameter running gradient averages from the previous
            step, aligned with ``model.parameters()`` (0 on the first call).
        count: batch index (currently unused; kept for interface
            compatibility with the caller).

    Returns:
        Tuple of (scalar loss for this batch, updated running averages).
    """
    wd = 1e-5  # weight-decay coefficient

    y_hat = model(x)

    # L2 penalty over all parameters, folded into the loss (weight decay).
    w2 = 0.
    for p in model.parameters():
        w2 += (p ** 2).sum()
    loss = loss_func(y_hat, y) + w2 * wd

    loss.backward()

    with torch.no_grad():
        for idx, p in enumerate(model.parameters()):
            # NOTE(review): this is *dampened* momentum — the fresh
            # gradient only enters with weight 0.1, so early updates are
            # ~10x smaller than plain SGD with the same lr. Classic
            # momentum is grads[idx]*0.9 + p.grad (no 0.1 factor); that
            # difference is the likely reason this does not outperform
            # vanilla SGD. Confirm which form lesson 5 intends.
            avg = grads[idx] * 0.9 + p.grad * 0.1
            p.sub_(lr * avg)
            grads[idx] = avg.clone()  # detach the stored average from p.grad's storage
            p.grad.zero_()

    return loss.item(), grads
# ---- training driver -------------------------------------------------
# One running-average slot per parameter tensor. The original hard-coded
# [0, 0], which silently breaks (IndexError) for any model that does not
# have exactly two parameter tensors.
grads = [0.] * sum(1 for _ in model.parameters())
losses = [None] * number_of_batches
count = 0
for x, y in data.train_dl:
    losses[count], grads = update2(x, y, lr, grads, count)
    count += 1
# Reset weights after the run so the experiment can be repeated from scratch.
model.apply(weight_reset)