What does torch.no_grad mean?

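```python
import torch

def mse(y_hat, y):
    # mean squared error (assumed definition; the post does not show it)
    return ((y_hat - y) ** 2).mean()

# x, y, a (created with requires_grad=True), lr and the loop counter t are globals
def update():
    y_hat = x @ a
    loss = mse(y_hat, y)
    loss.backward()          # autograd fills in a.grad
    if t % 10 == 0:
        print(loss)
    with torch.no_grad():    # don't record the parameter update in the graph
        a.sub_(lr * a.grad)  # gradient descent step, done in place
        a.grad.zero_()       # reset the gradient, since backward() accumulates
```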
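To try the loop end to end, here is a minimal self-contained sketch. The synthetic data (points around the line y = 3*x + 2), the starting values of a, the learning rate and the number of steps are all assumptions for illustration, not from the original post. The loop counter t is also passed in as an argument here rather than read from a global.

```python
import torch

torch.manual_seed(42)

# Assumed synthetic data: column 0 is x in [-1, 1], column 1 is all ones (bias term)
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1.0, 1.0)
y = x @ torch.tensor([3.0, 2.0]) + 0.1 * torch.randn(n)

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

a = torch.tensor([-1.0, 1.0], requires_grad=True)  # parameters to learn (slope, bias)
lr = 0.1

def update(t):
    y_hat = x @ a
    loss = mse(y_hat, y)
    loss.backward()              # populates a.grad
    if t % 10 == 0:
        print(loss.item())
    with torch.no_grad():        # the update itself is not recorded by autograd
        a.sub_(lr * a.grad)      # in-place step: a = a - lr * grad
        a.grad.zero_()           # backward() accumulates, so reset for the next step
    return loss.item()

losses = [update(t) for t in range(100)]
print(a)
```

With these settings the loss shrinks steadily and a ends up near the true slope and bias (3 and 2). Because the update runs under no_grad, a remains a leaf tensor with requires_grad=True, so the next backward() again fills a.grad directly.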

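PyTorch keeps track of all the operations you carry out on tensors that require gradients. That's why you don't need to write your own derivatives: backward() takes care of that. But while updating your parameters, you don't want those steps (i.e., param = param - lr*grad) to be tracked. Otherwise they become part of the computation graph, and the next time you call .backward(), autograd would also differentiate through the update steps. (In fact, PyTorch refuses to run an in-place update like a.sub_() on a leaf tensor that requires grad outside of no_grad, and raises a RuntimeError.) So to temporarily disable tracking of the operations, we use torch.no_grad. The gradient is also reset with a.grad.zero_() inside the same block, because backward() adds to .grad instead of overwriting it. Hope this helps.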
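A small sketch of what "keeping track" means in practice. The tensors a and grad and the value of lr below are made up for illustration. It shows that the same expression gets a grad_fn (is recorded) outside torch.no_grad but not inside it, and that an in-place update of a parameter is only allowed inside the no_grad block.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)  # a leaf "parameter"
grad = torch.tensor([0.5, 0.5])                   # stand-in for a.grad
lr = 0.1

# Outside no_grad, the update is recorded: the result has a grad_fn
tracked = a - lr * grad
print(tracked.requires_grad, tracked.grad_fn is not None)  # True True

# Inside no_grad, nothing is recorded
with torch.no_grad():
    untracked = a - lr * grad
print(untracked.requires_grad, untracked.grad_fn)  # False None

# An in-place update of a leaf that requires grad is rejected outside no_grad
error_message = ""
try:
    a.sub_(lr * grad)
except RuntimeError as e:
    error_message = str(e)
    print("RuntimeError:", e)

# The same update is allowed inside no_grad, and a stays a trainable leaf
with torch.no_grad():
    a.sub_(lr * grad)
print(a.is_leaf, a.requires_grad)  # True True
```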


Thank you so much, Palaash! That cleared my doubt.