I want to understand what exactly happens when we call loss.backward() in the update function below: how does it compute the gradient for a?

And how could we rewrite this function without using loss.backward() and a.grad?
def update():
    y_hat = x @ a                # forward pass: predictions from the current a
    loss = mse(y, y_hat)         # scalar loss
    loss.backward()              # autograd fills a.grad with d(loss)/d(a)
    with torch.no_grad():
        a.sub_(lr * a.grad)      # gradient-descent step on a
        a.grad.zero_()           # reset, so the next backward() doesn't accumulate
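
For the second part, here is a minimal sketch of one way it might look with the gradient derived by hand instead of by autograd. The toy x, y, a, and lr below are stand-ins for whatever you defined in your notebook (they are my assumptions), and I'm assuming mse is the usual mean((y_hat - y)**2). For that loss, d(loss)/d(a) = 2/n * x.T @ (y_hat - y), which is exactly the value loss.backward() would have accumulated into a.grad.

import torch

# Stand-in data; replace with your own x, y, a, lr.
n = 100
x = torch.stack([torch.linspace(0, 1, n), torch.ones(n)], dim=1)  # (n, 2) inputs
y = 3 * x[:, 0] + 2 + 0.1 * torch.randn(n)                        # noisy targets
a = torch.tensor([1.0, 1.0])   # plain tensor: no requires_grad, no autograd needed
lr = 0.1

def update_manual():
    y_hat = x @ a                              # forward pass
    loss = ((y_hat - y) ** 2).mean()           # same MSE, kept for monitoring
    grad = 2.0 / n * (x.t() @ (y_hat - y))     # hand-derived d(loss)/d(a)
    a.sub_(lr * grad)                          # same step as a.sub_(lr * a.grad)
    return loss

for _ in range(100):
    update_manual()

Because a no longer requires gradients, there is nothing for autograd to track, so the torch.no_grad() block and a.grad.zero_() are no longer needed.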