torch.utils.checkpoint - implementation in FastAI

Hi,

I have tried to implement gradient checkpointing, as it was not yet available in FastAI. The main changes I made are:

  • a new CheckpointModule class - wraps a module for checkpointing
  • a layer_config function - returns the per-stage block counts for a resnet
  • a modified version of the base create_body function

Everything else remains the same.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class CheckpointModule(nn.Module):
    def __init__(self, module, num_segments=1):
        super(CheckpointModule, self).__init__()
        # checkpoint_sequential can only split an nn.Sequential into >1 segments
        assert num_segments == 1 or isinstance(module, nn.Sequential)
        self.module = module
        self.num_segments = num_segments

    def forward(self, *inputs):
        # Trade compute for memory: activations inside `module` are recomputed during backward
        if self.num_segments > 1:
            return checkpoint_sequential(self.module, self.num_segments, *inputs)
        else:
            return checkpoint(self.module, *inputs)
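
For reference, a minimal usage sketch (not part of my changes; `block` and `x` are just illustrative names) of wrapping a small nn.Sequential:

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
ckpt_block = CheckpointModule(block, num_segments=2)

x = torch.randn(4, 8, requires_grad=True)  # requires_grad so gradients flow through the recomputation
out = ckpt_block(x)
out.sum().backward()                       # activations inside `block` are recomputed here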

# Number of residual blocks in each sequential stage of the resnet variants
def layer_config(arch):
    "Get the per-stage block counts associated with `arch`."
    return model_layers.get(arch)

model_layers = {
    models.resnet18 :[2, 2, 2, 2], models.resnet34: [3, 4, 6, 3],
    models.resnet50 :[3, 4, 6, 3], models.resnet101:[3, 4, 23, 3],
    models.resnet152:[3, 8, 36, 3]}
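
As an illustration (assuming fastai's `models` namespace used above), each entry is the number of residual blocks in the four sequential stages, so `min(checkpoint_segments, layers[i])` never requests more segments than a stage has blocks:

layers = layer_config(models.resnet34)
print(layers)   # [3, 4, 6, 3] -> block counts for layer1..layer4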

## Send sequential layers in custom_body to Checkpoint
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None,
                checkpoint_segments:int=2):  # max checkpoint segments per resnet stage
    "Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function)."
    model = arch(pretrained)
    cut = ifnone(cut, cnn_config(arch)['cut'])
    dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)  # placeholder tensor (currently unused)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int):
        # Checkpoint - Changes Start
        if arch.__name__.startswith("resnet"):
            # The first 4 layers (conv, bn, relu, maxpool) are not nn.Sequential
            # and are therefore not wrapped for checkpointing
            n = 4
            layers = layer_config(arch)
            out = nn.Sequential(*list(model.children())[:cut][:n],
                                *[CheckpointModule(x, min(checkpoint_segments, layers[i]))
                                  for i, x in enumerate(list(model.children())[:cut][n:])])
        else:
            out = nn.Sequential(*list(model.children())[:cut])
        return out
        # Checkpoint - Changes End
    elif isinstance(cut, Callable): return cut(model)
    else:                           raise NameError("cut must be either integer or a function")
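
A rough usage sketch, assuming the fastai v1 environment this targets (so `models`, `ifnone`, `cnn_config`, and `has_pool_type` are already in scope); the batch shape is only for illustration:

body = create_body(models.resnet34, pretrained=True, checkpoint_segments=2)

x = torch.randn(2, 3, 224, 224, requires_grad=True)
features = body(x)            # forward pass stores fewer activations
features.sum().backward()     # checkpointed stages are recomputed during backward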

Please let me know your thoughts. Evidence of the execution:


Could we revive this for v2?
I put a notebook about memory usage profiling here
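
For anyone who wants a quick comparison in the meantime, a minimal peak-memory check (assuming a CUDA device and the `body` created above; run it once with the checkpointed body and once with a plain `create_body`) could look like this:

import torch

torch.cuda.reset_peak_memory_stats()
x = torch.randn(8, 3, 224, 224, device="cuda", requires_grad=True)
body.cuda()(x).sum().backward()
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1e6:.1f} MB")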
