Profiling GPU memory usage

I wrote some code to profile memory usage during the forward and backward passes, and added support for torch.utils.checkpoint. It relies heavily on the Hook class to access the model's layers. It does not use register_forward_pre_hook, because Hook does not support it (I could open a PR for this).
You can check the notebook here.
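For reference, plain PyTorch modules already expose register_forward_pre_hook alongside register_forward_hook, so a per-layer profiler can be sketched without the fastai Hook class. This is a minimal sketch, not the notebook's code: profile_forward_memory and its read_mem parameter are hypothetical names. It assumes torch.cuda.memory_allocated as the memory reading on GPU and falls back to a stub returning 0 on CPU, where only the recorded output sizes are meaningful.

```python
import torch
import torch.nn as nn


def profile_forward_memory(model: nn.Module, x: torch.Tensor, read_mem=None):
    """Record memory around each leaf module's forward pass (hypothetical helper).

    read_mem is a zero-argument callable returning bytes in use. By default it is
    torch.cuda.memory_allocated when CUDA is available, else a stub returning 0.
    """
    if read_mem is None:
        read_mem = torch.cuda.memory_allocated if torch.cuda.is_available() else (lambda: 0)

    records, handles = [], []

    for name, module in model.named_modules():
        if any(True for _ in module.children()):
            continue  # only hook leaf modules (Conv2d, BatchNorm2d, ReLU, ...)

        def pre_hook(mod, inputs, name=name):
            # Runs just before mod.forward: snapshot memory before the layer.
            records.append({"layer": name, "before": read_mem()})

        def post_hook(mod, inputs, output, name=name):
            # Runs just after mod.forward: leaf modules don't nest, so the
            # matching pre-hook record is the last one appended.
            rec = records[-1]
            rec["after"] = read_mem()
            if isinstance(output, torch.Tensor):
                rec["out_bytes"] = output.numel() * output.element_size()

        handles.append(module.register_forward_pre_hook(pre_hook))
        handles.append(module.register_forward_hook(post_hook))

    try:
        model(x)  # grad mode stays on, so activations kept for backward stay allocated
    finally:
        for h in handles:
            h.remove()  # always detach the hooks so the model is left unchanged
    return records
```

On a GPU, pass a CUDA model and input; `after - before` then approximates how much memory each layer adds during the forward pass.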

I don't fully understand why, but I am getting lower memory usage on my ResNet with the checkpointing trick.
Figure: ResNet18 memory usage during the forward/backward pass, normal vs. sequential_checkpoints.
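A minimal sketch of the checkpointing side in plain PyTorch, using torch.utils.checkpoint.checkpoint_sequential. I am assuming this is the built-in counterpart of the sequential_checkpoints in the plot; the notebook may do it differently. checkpoint_sequential splits a Sequential into segments, saves only the input of each segment, and recomputes the activations inside a segment during backward, which trades extra compute for lower memory. The toy MLP and the names train_step and peak_cuda_memory are hypothetical, and passing use_reentrant assumes a recent (2.x) PyTorch.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


def train_step(model: nn.Sequential, x: torch.Tensor, segments: int = 0):
    """One forward/backward pass; segments > 0 enables checkpoint_sequential."""
    model.zero_grad(set_to_none=True)
    if segments > 0:
        # Only each segment's input is stored; activations inside all segments
        # except the last are recomputed during backward.
        out = checkpoint_sequential(model, segments, x, use_reentrant=False)
    else:
        out = model(x)
    out.sum().backward()
    grads = [p.grad.detach().clone() for p in model.parameters()]
    return out.detach(), grads


def peak_cuda_memory(fn):
    """Peak CUDA bytes allocated while running fn() (requires a GPU)."""
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated()


if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical toy model: 8 Linear+ReLU blocks inside one Sequential.
    model = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(8)]).to(device)
    x = torch.randn(64, 256, device=device)
    if device == "cuda":
        print("normal:", peak_cuda_memory(lambda: train_step(model, x)))
        print("checkpointed:", peak_cuda_memory(lambda: train_step(model, x, segments=4)))
```

The outputs and gradients should match the normal pass; on a GPU the checkpointed run is expected to report a lower peak, though the actual numbers depend on the model, the batch size and the number of segments.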

Most of this comes from here.

I duplicated this post because I originally posted it in the V1 thread.
