I put together some code to profile memory usage on the forward and backward passes, and also added torch.utils.checkpoint support. It makes extensive use of the Hook class to access the model. It is missing register_forward_pre_hook because the Hook class does not have it (I could PR this).
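For anyone who wants the general idea without opening the notebook, here is a minimal sketch of hook-based memory logging. It is not the notebook's code: it uses plain PyTorch hooks instead of the Hook class, covers only the forward pass, and the helper names (current_mem, profile_forward) and the toy model are made up for illustration.

```python
import torch
import torch.nn as nn

def current_mem():
    # Bytes currently allocated on the GPU; 0 when CUDA is not available.
    return torch.cuda.memory_allocated() if torch.cuda.is_available() else 0

def profile_forward(model, x):
    """Record allocated memory before and after each leaf module's forward."""
    log, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers, hook only leaf modules

        # Bind `name` as a default argument so each hook keeps its own label.
        def pre_hook(mod, inp, name=name):
            log.append((name, "pre", current_mem()))

        def post_hook(mod, inp, out, name=name):
            log.append((name, "post", current_mem()))

        handles.append(module.register_forward_pre_hook(pre_hook))
        handles.append(module.register_forward_hook(post_hook))
    try:
        model(x)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks, even if forward fails
    return log

# Hypothetical toy model, just to show the shape of the log.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
log = profile_forward(model, torch.randn(2, 8))
for name, stage, mem in log:
    print(name, stage, mem)
```

On a CPU-only machine every reading is 0, so the numbers only become interesting once the model and input are on a GPU.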
You can check the notebook here.
It is something I don’t fully understand, but I am getting lower memory usage on my ResNet using this trick.
ResNet18 memory usage on the forward/backward pass: normal vs. sequential_checkpoints.
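As far as I can tell, the saving comes from how checkpointing works: activations inside a checkpointed segment are not kept after the forward pass and are recomputed during backward, so you trade extra compute for lower peak memory. Below is a small sketch of that using torch.utils.checkpoint.checkpoint_sequential on a toy nn.Sequential (not the ResNet from the notebook). The model, sizes, and segment count are arbitrary, and the use_reentrant argument assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Hypothetical toy model; any nn.Sequential works the same way.
model = nn.Sequential(
    nn.Linear(32, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
x = torch.randn(4, 32, requires_grad=True)

# Split the model into 2 segments; intermediate activations inside a
# segment are recomputed during backward rather than stored.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()

# The result matches a normal forward pass; only memory/compute differ.
with torch.no_grad():
    reference = model(x)
print(torch.allclose(out, reference))
```

More segments generally means less stored activation memory but more recomputation, so the segment count is worth tuning per model.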
Most of this comes from here
I duplicated this post because I originally posted it on the V1 thread.