How do gradients get calculated?

One thing that still surprises me is that gradients can be computed for all these tensors. I was looking around the PyTorch code but couldn’t figure it out. I know that grad_fns are used, but how do they actually calculate the gradient? Are all possible operations and their derivatives hard-coded in PyTorch and then picked up by autograd? Or is some numerical technique used to figure out the gradient?

@spp
Short answer to your question: yes, PyTorch knows which grad_fn corresponds to which operation, and what to do with it while calculating derivatives. So if you’ve multiplied a tensor by 2, PyTorch records a grad_fn for the result that tells it to multiply the incoming gradient by 2 during the backward pass.
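A minimal sketch of the multiply-by-2 case, assuming a recent PyTorch install: the forward pass records a multiplication node as `y.grad_fn`, and `backward()` makes that node scale the incoming gradient by 2.

```python
import torch

# A leaf tensor that autograd should track
x = torch.tensor(3.0, requires_grad=True)

# Forward pass: the multiplication records a backward node on y
y = x * 2
print(y.grad_fn)  # e.g. <MulBackward0 object at 0x...>

# Backward pass: a gradient of 1 flows into the Mul node, which multiplies it by 2
y.backward()
print(x.grad)  # tensor(2.)
```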

In-depth explanation:
There’s no need to go into great depth about how autograd works, because it is indeed quite complex. But the basic idea is the construction of a “graph”, more precisely a directed acyclic graph, which is something like a tree or a linked list (if you’ve used C++ before). Basically, all tensors carry some attributes, such as ‘grad’, ‘grad_fn’ and ‘requires_grad’ (which is False by default). When we set ‘requires_grad’ to True, PyTorch starts tracking operations on that tensor: any tensor computed from it also has ‘requires_grad’ set to True and gets a ‘grad_fn’ that records the operation which created it. The original tensor’s own ‘grad’ is filled in once we call .backward().
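A small sketch of how these attributes propagate, assuming a recent PyTorch version. Note that a leaf tensor you create yourself keeps `grad_fn` as None even with `requires_grad=True`; only tensors produced by operations get a `grad_fn`.

```python
import torch

a = torch.tensor([1.0, 2.0])                      # requires_grad defaults to False
b = torch.tensor([3.0, 4.0], requires_grad=True)  # a leaf tensor we want gradients for

c = a * b  # c depends on b, so it also requires grad

print(a.requires_grad, b.requires_grad, c.requires_grad)  # False True True
print(b.grad_fn)             # None: user-created leaf tensors have no grad_fn
print(c.grad_fn)             # the backward node recorded by the multiplication
print(b.is_leaf, c.is_leaf)  # True False
print(b.grad)                # None until backward() has been called
```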

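Now, when we call .backward(), PyTorch walks this graph backwards from the output, and each grad_fn applies its operation’s derivative to the gradient coming in from above (the chain rule). The results are accumulated into the ‘grad’ attribute of the leaf tensors that have ‘requires_grad’ set to True. The grad_fn values themselves are not updated by this; they were recorded during the forward pass, and they are how PyTorch knows how to calculate gradients: it knows the derivative when a tensor is the result of, say, multiplication of two (or more) tensors/scalars, or addition of tensors.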
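To make the mechanism concrete, here is a toy reverse-mode autograd in plain Python. It is not PyTorch’s implementation: the `Value` class is hypothetical, written only for illustration. Each operation records its inputs plus a hard-coded local derivative rule, and `backward()` visits the recorded graph in reverse topological order, applying the chain rule and accumulating into `.grad`.

```python
class Value:
    """A scalar that records how it was computed (hypothetical toy, not PyTorch's API)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents    # nodes this value was computed from
        self._backward_fn = None   # hard-coded local derivative rule

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topological order: every node appears after the nodes it was computed from
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(self)/d(self)
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


x = Value(3.0)
y = Value(4.0)
z = x * y + x * 2  # z = xy + 2x
z.backward()
print(x.grad, y.grad)  # 6.0 3.0  (dz/dx = y + 2, dz/dy = x)
```

Because gradients are added with `+=`, a node used twice (like `x` here) collects contributions from both paths, which mirrors how PyTorch accumulates into `.grad`.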

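Each differentiable primitive operation has its derivative formula written out in PyTorch, and many more complex operations are implemented in terms of these primitives. That is why autograd works across all these types of tensors and operations. No numerical approximation (such as finite differences) is involved: the gradients are computed from these exact formulas, up to floating-point error.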
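PyTorch exposes the same pattern to users through `torch.autograd.Function`, where you write `forward` and a hand-coded `backward`. The sketch below defines a hypothetical `Square` op this way (PyTorch’s built-in ops get their formulas from its own internal definitions, not from a class like this) and compares the result with a finite-difference estimate, only to show that autograd evaluates the formula rather than approximating it numerically.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical custom op y = x**2 with a hand-written derivative."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep the input for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx, with dy/dx = 2x
        return grad_output * 2 * x


x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)
y = Square.apply(x)
y.sum().backward()
print(x.grad)  # tensor([2., 4., 6.], dtype=torch.float64)

# For comparison only: a central finite-difference estimate of the same derivative
eps = 1e-6
with torch.no_grad():
    numeric = ((x + eps) ** 2 - (x - eps) ** 2) / (2 * eps)
print(numeric)  # close to [2., 4., 6.], up to floating-point error
```

PyTorch’s `torch.autograd.gradcheck` automates this kind of comparison when testing a custom backward formula.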

Here’s a visualization of the graph that is constructed:

Cheers


Thanks @PalaashAgrawal! I think this part is what I wanted confirmation on. I couldn’t find where all the operations and their derivatives are defined, but I’m sure they are somewhere in the C++ or other non-Python code.

The video was pretty great too. I don’t think I got all the details, but essentially it seems to be implementing the chain rule and doing all the necessary bookkeeping.
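That bookkeeping can also be inspected directly. Assuming a recent PyTorch version, each `grad_fn` node exposes `next_functions`, which points at the nodes for its inputs; leaf tensors that require grad show up as `AccumulateGrad` nodes, which is where gradients get written into `.grad`. The `print_graph` helper is hypothetical, written just for this illustration.

```python
import torch


def print_graph(fn, depth=0):
    """Recursively print the backward graph rooted at a grad_fn node (hypothetical helper)."""
    if fn is None:
        return  # inputs that don't require grad (e.g. the constant 1) have no node
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        print_graph(next_fn, depth + 1)


x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(5.0, requires_grad=True)
z = (x * w + 1) ** 2

print_graph(z.grad_fn)
# Roughly (exact node names may vary between versions):
# PowBackward0
#   AddBackward0
#     MulBackward0
#       AccumulateGrad
#       AccumulateGrad

z.backward()
# Chain rule by hand: dz/dx = 2(xw + 1) * w = 110, dz/dw = 2(xw + 1) * x = 44
print(x.grad, w.grad)
```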