I put together some code to profile memory usage on the forward and backward passes, and also added
torch.utils.checkpoint support. It makes extensive use of the
Hook class to access the model. It does not use
register_forward_pre_hook yet, because the
Hook class does not support it (I could open a PR for this).
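
For anyone who wants to try the idea without the notebook, here is a minimal sketch of per-layer memory logging using plain PyTorch module hooks, including register_forward_pre_hook. It is not the notebook code and does not use the fastai Hook class; MemoryProfiler, cuda_mem_mb and the toy model are names I made up for illustration. It assumes PyTorch 1.8 or newer (for register_full_backward_hook) and reports 0.0 on machines without CUDA.

```python
import torch
import torch.nn as nn


def cuda_mem_mb():
    # Currently allocated CUDA memory in MiB; 0.0 when no GPU is available.
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 2**20
    return 0.0


class MemoryProfiler:
    """Hypothetical helper: logs allocated memory around every leaf module."""

    def __init__(self, model):
        self.log = []      # entries of (module_name, phase, MiB)
        self.handles = []
        for name, module in model.named_modules():
            if len(list(module.children())) > 0:
                continue   # skip containers, hook only leaf modules
            self.handles.append(
                module.register_forward_pre_hook(self._make_hook(name, "pre_forward")))
            self.handles.append(
                module.register_forward_hook(self._make_hook(name, "forward")))
            self.handles.append(
                module.register_full_backward_hook(self._make_hook(name, "backward")))

    def _make_hook(self, name, phase):
        # The three hook types take different arguments, so accept them all.
        def hook(module, *args):
            self.log.append((name, phase, cuda_mem_mb()))
        return hook

    def remove(self):
        # Detach all hooks so the model runs normally again.
        for handle in self.handles:
            handle.remove()
        self.handles = []


# Toy model for illustration only (no in-place ops, which can clash with
# full backward hooks).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
profiler = MemoryProfiler(model)

out = model(torch.randn(4, 8))
out.sum().backward()
profiler.remove()

for name, phase, mem in profiler.log:
    print(f"layer {name:>2} {phase:<12} {mem:8.2f} MiB")
```

On a GPU you would move the model and input to CUDA first; the difference between the pre_forward and forward readings of a layer then gives a rough idea of how much that layer allocated.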
You can check the notebook here.
I don’t fully understand why it works, but I am getting lower memory usage on my ResNet using this trick.
ResNet18 memory usage on the forward/backward pass: normal vs. checkpoint_sequential.
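
As far as I can tell, the saving comes from checkpointing keeping only the activations at segment boundaries during the forward pass and recomputing the rest during the backward pass, so you trade extra compute for memory. Below is a small sketch of that with torch.utils.checkpoint.checkpoint_sequential. It uses a toy nn.Sequential rather than my ResNet, and it assumes a recent PyTorch where checkpoint_sequential accepts use_reentrant; the layer sizes and the number of segments are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Toy model for illustration only; any nn.Sequential works the same way.
model = nn.Sequential(
    nn.Linear(32, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
x = torch.randn(16, 32)

# Normal pass: every intermediate activation is kept for backward.
model.zero_grad()
model(x).sum().backward()
ref_grads = [p.grad.clone() for p in model.parameters()]

# Checkpointed pass: the model is split into 2 segments, and activations
# inside a checkpointed segment are recomputed during backward.
model.zero_grad()
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

# Both passes should produce the same gradients.
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(ref_grads, ckpt_grads))
print("gradients match:", same)
```

To see the memory difference on a GPU, you could call torch.cuda.reset_peak_memory_stats() before each pass and compare torch.cuda.max_memory_allocated() afterwards. Note that layers with state that changes on every forward call, such as BatchNorm running statistics, are run twice inside checkpointed segments.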
Most of this comes from here.
I am reposting this because I originally posted it on the V1 thread.