Hi,
I have tried to implement checkpointing, as it is not yet available in fastai. The main changes I made are:
- a new class, `CheckpointModule`, for checkpointing;
- a new function, `layer_config`, to get the layer details from ResNet;
- a modified version of the base function `create_body`.
Everything else remains the same.
class CheckpointModule(nn.Module):
    "Wrap `module` so that its forward pass is run through `checkpoint` / `checkpoint_sequential`."

    def __init__(self, module, num_segments=1):
        # `module`: the wrapped module; must be an `nn.Sequential` unless `num_segments == 1`.
        # `num_segments`: number of segments handed to `checkpoint_sequential` when > 1.
        # Raises ValueError when `num_segments != 1` and `module` is not an `nn.Sequential`.
        super().__init__()
        # Explicit raise rather than `assert`, so the check is not stripped under `python -O`.
        if num_segments != 1 and not isinstance(module, nn.Sequential):
            raise ValueError(
                f"num_segments={num_segments} requires an nn.Sequential module, "
                f"got {type(module).__name__}"
            )
        self.module = module
        self.num_segments = num_segments

    def forward(self, *inputs):
        # More than one segment: split the sequential module into `num_segments` checkpointed chunks.
        if self.num_segments > 1:
            return checkpoint_sequential(self.module, self.num_segments, *inputs)
        # Any other value: checkpoint the wrapped module as a single unit.
        return checkpoint(self.module, *inputs)
# To extract the sequential layers from resnet
def layer_config(arch):
    "Return the `model_layers` entry registered for `arch`, or `None` when there is none."
    entry = model_layers.get(arch)
    return entry
# Lookup table read by `layer_config`: one list of four integers per `models.resnet*` constructor.
# NOTE(review): presumably the number of residual blocks in each of the four ResNet stages;
# confirm against the torchvision ResNet definitions.
model_layers = {
    models.resnet18: [2, 2, 2, 2],
    models.resnet34: [3, 4, 6, 3],
    models.resnet50: [3, 4, 6, 3],
    models.resnet101: [3, 4, 23, 3],
    models.resnet152: [3, 8, 36, 3],
}
## Modified create_body: wrap the sequential layers of the body in CheckpointModule
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    "Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function)."
    # Raises NameError when `cut` is neither an int nor a callable.
    model = arch(pretrained)
    cut = ifnone(cut, cnn_config(arch)['cut'])
    if cut is None:
        # No cut configured: use the index of the last child for which `has_pool_type` is true.
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int):
        kept = list(model.children())[:cut]
        layers = layer_config(arch)
        is_resnet = getattr(arch, '__name__', '').startswith("resnet")
        # Checkpointing is only applied to archs named `resnet*` that have a `layer_config` entry;
        # everything else gets the plain, unwrapped body.
        if not is_resnet or layers is None:
            return nn.Sequential(*kept)
        # Per the original author's note, the initial 4 children are not sequential and
        # are not applicable with checkpointing, so they are kept as-is.
        n = 4
        stem, stages = kept[:n], kept[n:]
        # NOTE(review): `checkpoint_segments` is not defined in this function; it is read from
        # module scope -- confirm it is set before calling.
        # Children beyond the entries of `layers` are left unwrapped instead of raising IndexError.
        wrapped = [CheckpointModule(x, min(checkpoint_segments, layers[i])) if i < len(layers) else x
                   for i, x in enumerate(stages)]
        return nn.Sequential(*stem, *wrapped)
    if callable(cut): return cut(model)
    # The original raised the non-existent `NamedError`, which surfaced as an unrelated
    # "name 'NamedError' is not defined"; raise NameError directly with the intended message.
    raise NameError("cut must be either integer or a function")
Please let me know your thoughts. Evidence of the execution follows.