GPU Memory profiling tools - recommendations?

I’m working with EfficientDet in TF, but it’s consuming large amounts of GPU memory for anything above D1.
Are there any recommendations for tools or a process to debug/profile this and figure out what is using so much memory on the GPU?

I run into similar issues, actually. I usually use pdb and step through the code in a sort of binary-search fashion to see how the memory usage changes. This may be impractical in TF or for big projects.
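A lighter-weight version of that bisection, assuming PyTorch (the commented-out `model`/`batch` names are hypothetical): print the allocated GPU memory at a few checkpoints and narrow down where the jump happens.

```python
import torch

def gpu_mem_mb():
    """Currently allocated GPU memory in MB (0.0 on a CPU-only machine)."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.memory_allocated() / 1024**2

# Sprinkle these between suspect sections and bisect toward the culprit,
# instead of stepping through every line in pdb.
print(f"before forward: {gpu_mem_mb():.1f} MB")
# out = model(batch)   # hypothetical forward pass
print(f"after forward:  {gpu_mem_mb():.1f} MB")
```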

I do like using Weights & Biases to log maximum memory usage throughout the training loop to determine whether there is a memory leak.
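A minimal sketch of that logging, assuming PyTorch; the `wandb.log` call is left commented out so the snippet stands alone (a real run would need `import wandb` and `wandb.init()`, and `train_step`/`batch` are hypothetical):

```python
import torch

def peak_gpu_mem_mb():
    """Peak allocated GPU memory since the last reset, in MB."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 1024**2

for step in range(3):                       # stand-in for the training loop
    # loss = train_step(batch)              # hypothetical training step
    peak = peak_gpu_mem_mb()
    # wandb.log({"peak_gpu_mem_mb": peak})  # assumes wandb is set up
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()  # a leak then shows as a rising curve
```

If the logged peak climbs steadily across steps instead of plateauing after the first one, something is accumulating.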

If there is no memory leak, I then determine whether it is the base model (single layers, or the output of layers), the backward pass (calculating the gradients), or the optimizer step (Adam statistics) that is eating the majority of the memory.
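One way to get a rough breakdown, sketched in PyTorch with a toy model: after one training step, the parameter, gradient, and optimizer-state tensors can each be sized up directly. Activation memory from the forward pass is not captured this way; on a GPU you would watch `torch.cuda.memory_allocated()` deltas around the forward call for that.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
opt = torch.optim.Adam(model.parameters())

loss = model(torch.randn(32, 256)).sum()
loss.backward()
opt.step()

def mb(tensors):
    """Total size of a list of tensors in MB."""
    return sum(t.numel() * t.element_size() for t in tensors) / 1024**2

params = list(model.parameters())
grads = [p.grad for p in params if p.grad is not None]
adam_state = [t for s in opt.state.values() for t in s.values() if torch.is_tensor(t)]

print(f"parameters: {mb(params):.2f} MB")
print(f"gradients:  {mb(grads):.2f} MB")       # same size as the parameters
print(f"Adam state: {mb(adam_state):.2f} MB")  # exp_avg + exp_avg_sq ~ 2x parameters
```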

For the base model in PyTorch, I find that authors can leave variables sitting around in their model without giving the garbage collector a chance to grab them. Splitting up a layer’s forward pass into separate functions lets the gc clean things up as each function returns.
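A small demonstration of that effect, assuming PyTorch (the module is a toy, and the `weakref` probe is only there to show when the intermediate dies). One caveat: tensors that autograd needs for backward are kept alive by the graph regardless, so this mainly helps temporaries that are not needed for backward, or inference code.

```python
import weakref
import torch

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.probe = None  # weakref to an intermediate, for demonstration only

    def _attention_part(self, x):
        # `big` dies as soon as this helper returns, instead of lingering
        # as a live local until the end of one long monolithic forward().
        big = x @ x.transpose(-1, -2)
        self.probe = weakref.ref(big)
        return big.softmax(dim=-1) @ x

    def forward(self, x):
        out = self._attention_part(x)
        # `big` is already unreachable at this point.
        return out

layer = Layer()
with torch.no_grad():
    layer(torch.randn(8, 8))
print(layer.probe())  # None: the intermediate was freed promptly
```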

If you are using most of your memory in the backward pass, then it might be a good idea to look into gradient checkpointing.
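In PyTorch that looks something like the sketch below: `torch.utils.checkpoint.checkpoint_sequential` drops most activations during the forward pass and recomputes them during backward, trading compute for memory. (The toy model and sizes are made up.)

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

model = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(16, 64, requires_grad=True)

# Run the 8 blocks as 2 segments: only segment-boundary activations are
# kept; everything in between is recomputed during the backward pass.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([16, 64])
```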

I am not familiar with TF, but from my understanding you will generally not run out of memory in the optimizer step. In PyTorch you can forcibly combine the optimizer step with clearing the gradients: go from the smallest to the largest parameter, doing the gradient update and zeroing that parameter’s gradient at the same time.
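A minimal sketch of that fused update in PyTorch, using a plain SGD rule for illustration (a real setup might instead hook this into the backward pass, or use one optimizer per parameter):

```python
import torch

model = torch.nn.Linear(32, 32)
lr = 0.1

model(torch.randn(4, 32)).sum().backward()

# Walk the parameters from smallest to largest, updating each one and
# freeing its gradient immediately, so all gradients plus all update
# buffers never coexist at peak.
with torch.no_grad():
    for p in sorted(model.parameters(), key=lambda t: t.numel()):
        if p.grad is not None:
            p -= lr * p.grad  # plain SGD step, standing in for the real optimizer
            p.grad = None     # free this gradient right away
```

Setting `p.grad = None` (rather than zeroing it in place) actually releases the memory, which is also what `optimizer.zero_grad(set_to_none=True)` does.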

Another issue I have seen is very large linear/dense layers. You can split such a layer into N smaller layers, feed the input into each of them, and concatenate their outputs before passing the result to the next layer. You should also make sure the smaller layers are initialized appropriately. This can greatly decrease memory usage “spikes”.
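A sketch of that split in PyTorch: a Linear with 1024 outputs becomes four Linears with 256 outputs each, and concatenating their outputs reproduces the original exactly. Weights are copied here only to show the equivalence; in practice you would initialize the small layers directly (the default init distribution already matches, since the fan-in is unchanged).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
big = nn.Linear(128, 1024)

# Split along the output dimension into 4 chunks of 256 rows each.
chunks = nn.ModuleList(nn.Linear(128, 256) for _ in range(4))
with torch.no_grad():
    for i, m in enumerate(chunks):
        m.weight.copy_(big.weight[i * 256:(i + 1) * 256])
        m.bias.copy_(big.bias[i * 256:(i + 1) * 256])

x = torch.randn(8, 128)
y_split = torch.cat([m(x) for m in chunks], dim=-1)
print(torch.allclose(big(x), y_split, atol=1e-5))  # True
```

Each chunk’s matmul, gradient, and workspace allocation is a quarter of the size, which can flatten the transient peak you would otherwise hit when the full layer’s output and gradient are live at once.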

I don’t know if any of this will be useful, but I wanted to share my experience just in case.


Thanks @marii - this is a very helpful strategy!