# How do gradients get calculated?

One thing which still surprises me is that every operation on tensors can be differentiated. I was looking around PyTorch but couldn’t figure it out. I know that `grad_fn`s are used, but how do those actually calculate the gradient? Are all possible operations and their derivatives hard-coded in PyTorch and then picked by `autograd`? Or is there some numerical technique which is used to figure out the gradient?

@spp
Short answer to your question: yes, PyTorch knows which `grad_fn` corresponds to which operation, and each `grad_fn` knows how to compute the derivative of that operation. So if you’ve multiplied a tensor by 2, PyTorch stores a `grad_fn` on the result which tells it to multiply the incoming gradient by 2 during the backward pass.
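To make the multiply-by-2 case concrete, here is a minimal sketch. The tensor value `3.0` and the variable names are just illustrative choices, and the exact name printed for the backward node may differ between PyTorch versions.

```python
import torch

# x is a tensor we create ourselves and ask PyTorch to track
x = torch.tensor(3.0, requires_grad=True)

# Multiplying records a backward node on the result
y = x * 2
print(y.grad_fn)  # some multiplication backward node

# Backward pass: dy/dx = 2, so x.grad ends up holding 2
y.backward()
print(x.grad)
```

The value stored in `x.grad` does not depend on the value of `x` here, because the derivative of `2 * x` is the constant 2.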

In depth explanation:
There’s no need to go into great depth about how autograd works, because it is indeed quite complex. But the basic idea is that PyTorch builds a “graph” as you compute: a directed acyclic graph, which is something like a tree or a linked list (if you’ve used C++ before). All tensors have some attributes attached to them, such as `grad`, `grad_fn` and `requires_grad` (which is `False` by default). When we set `requires_grad=True` on a tensor, PyTorch starts tracking the operations applied to it. Any tensor calculated from it also has `requires_grad=True`, and its `grad_fn` points to the function that knows how to backpropagate through the operation that produced it. The tensor we created ourselves (a “leaf”) has no `grad_fn`; its `grad` stays `None` until a backward pass fills it in.
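A quick way to see these attributes is to print them. This is only a small sketch; the tensor names `a`, `w` and `out` are made up for illustration.

```python
import torch

a = torch.ones(2)                      # requires_grad defaults to False
w = torch.ones(2, requires_grad=True)  # leaf tensor that tracks gradients
out = (a * w).sum()                    # computed from w, so it is tracked too

print(a.requires_grad, w.requires_grad, out.requires_grad)
print(w.is_leaf, w.grad_fn)      # leaf tensor: grad_fn is None
print(out.is_leaf, out.grad_fn)  # result of an operation: has a grad_fn
print(w.grad)                    # None until backward() has been called
```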

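Now, when we call `.backward()`, PyTorch walks this graph backwards from the output, calls each `grad_fn` along the way, and accumulates the results into the `grad` attribute of the leaf tensors that have `requires_grad=True`. The `grad_fn`s themselves are not updated; they were recorded during the forward pass. This is how PyTorch comes to know how to calculate gradients: each `grad_fn` knows the derivative of the operation that produced its tensor, whether that was, say, a multiplication of two tensors/scalars or an addition of tensors.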
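As a rough illustration of a backward pass through a graph with more than one operation, assume two scalar inputs `x` and `y` (values picked arbitrarily) combined by a multiplication and an addition:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(5.0, requires_grad=True)

# z = x*y + x, so dz/dx = y + 1 and dz/dy = x
z = x * y + x

# The backward node of z links to the nodes of the tensors it was built from
print(z.grad_fn)
print(z.grad_fn.next_functions)

# Walk the graph backwards and fill in .grad on the leaf tensors
z.backward()
print(x.grad, y.grad)
```

With these inputs, `x.grad` should come out as `y + 1 = 6` and `y.grad` as `x = 2`.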

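The set of primitive operations is finite, and each differentiable one has its own backward formula defined. Everything else is a composition of those primitives, which the chain rule takes care of. That is the reason why autograd works across so many kinds of tensors and operations.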
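To show the idea of "one derivative rule per operation plus the chain rule" without any PyTorch internals, here is a toy sketch in plain Python. The `Value` class is hypothetical and is not how PyTorch is implemented; it supports only scalar addition and multiplication. The finite-difference line at the end is there only as a sanity check: the gradients themselves come from the analytic rules, not from a numerical approximation.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical toy class)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # inputs this value was computed from
        self.local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        # Rule for addition: d(a+b)/da = 1, d(a+b)/db = 1
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        # Rule for multiplication: d(a*b)/da = b, d(a*b)/db = a
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def backward(self):
        # Order the graph so each node comes after all of its inputs
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        # Go from the output back to the inputs, applying the chain rule
        for node in reversed(order):
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += local * node.grad


x = Value(2.0)
y = Value(5.0)
z = x * y + x
z.backward()
print(x.grad, y.grad)

# Sanity check of dz/dx with a finite difference (not used by backward)
h = 1e-6
numeric = (((2.0 + h) * 5.0 + (2.0 + h)) - (2.0 * 5.0 + 2.0)) / h
print(numeric)
```

Adding a new operation to this toy only means writing one more method with its local derivatives; `backward` itself never changes.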

Here’s a visualization of the graph that is constructed:

Cheers


Thanks @PalaashAgrawal! I think this part is what I wanted confirmation on. I couldn’t find where all the operations and their derivatives are located, but I’m sure it’s somewhere there in C++ or other non-Python code.

The video was pretty great too. I don’t think I got all the details, but essentially it seems like it’s implementing the chain rule and doing all the necessary bookkeeping.