How can I create fine_tune checkpoints to resume from?


Is there a way to save checkpoints during fine-tuning so that, if training is terminated unexpectedly, I can restart the process close to where it left off?

I am looking into SaveModelCallback, but I am not sure that is the right approach.


Yes, SaveModelCallback is the way to go:

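cbs = [SaveModelCallback(every_epoch=True)]  # assumed setting: one checkpoint per epoch (model_0, model_1, ...)
learn.fine_tune(10, lr, cbs=cbs)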

You can load the checkpoint with learn.load, passing the checkpoint's file name without the .pth extension.


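To pick the checkpoint to load after a crash, a small helper can find the newest per-epoch file. This sketch assumes SaveModelCallback was created with every_epoch=True and the default fname='model', in which case (as far as I know) it writes one file per epoch named model_0.pth, model_1.pth, ... into learn.path/learn.model_dir. The helper latest_checkpoint is hypothetical, not part of fastai.

```python
import re
import tempfile
from pathlib import Path

def latest_checkpoint(models_dir, fname="model"):
    """Return (epoch, name) for the newest `{fname}_{epoch}.pth` file, or None.

    `name` is the file stem (no .pth extension), which is what learn.load() expects.
    """
    pattern = re.compile(rf"^{re.escape(fname)}_(\d+)\.pth$")
    found = []
    for path in Path(models_dir).glob("*.pth"):
        match = pattern.match(path.name)
        if match:
            found.append((int(match.group(1)), path.stem))
    # Compare epoch numbers as integers, so epoch 10 sorts after epoch 9
    return max(found) if found else None

# Demo with empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as d:
    for epoch in (0, 1, 9, 10):
        (Path(d) / f"model_{epoch}.pth").touch()
    print(latest_checkpoint(d))  # (10, 'model_10')
```

For a Learner created with path='checkpoints', models_dir would be learn.path/learn.model_dir (checkpoints/models with the default model_dir), and the returned name is what you would pass to learn.load.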


Thanks @florianl. I think that solves half my problem. As I understand it, that lets me load the model, but to resume training where it left off I need more than that.

Let's say I call fine_tune for 10 epochs, but the process halts during epoch 6. I would like to resume from epoch 6 and continue to the end. Since epoch 6 was not finished, I would load the model saved after epoch 5, but to continue the process I also need to know the hyperparameters for epochs 6 to 9. I am not sure whether it is just a matter of calculating the learning rate for epoch 6 and calling fine_tune with that as the starting learning rate for the remaining 4 epochs, something like this:

learn = cnn_learner(dls, resnet34, metrics=accuracy, path='checkpoints')
learn.fine_tune(10, lr, cbs=cbs) 

Something terminates training during epoch 6; then, to resume:

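learn = cnn_learner(dls, resnet34, metrics=accuracy, path='checkpoints')
learn.load('model_5')  # checkpoint saved after epoch 5 (name assumes SaveModelCallback's default fname)
learn.fine_tune(4, lr_for_epoch_6, cbs=cbs)  # lr_for_epoch_6 is a placeholder; pass cbs by keyword, positionally it would become freeze_epochs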
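To see what "the learning rate for epoch 6" means, here is a stdlib-only sketch of the one-cycle curve that fit_one_cycle follows: a cosine warm-up from lr_max/div to lr_max over the first pct_start of training, then cosine annealing down to lr_max/div_final (div_final defaults to 1e5). The function names are hypothetical and this re-implements the shape for illustration; it is not fastai code, and it only tracks the largest learning rate of the discriminative slice.

```python
import math

def sched_cos(start, end, pos):
    # Cosine interpolation: returns `start` at pos=0 and `end` at pos=1
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at training progress `pos` (0..1) of a one-cycle schedule."""
    if pos < pct_start:
        # warm-up phase: lr_max/div -> lr_max
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    # annealing phase: lr_max -> lr_max/div_final
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

# Second (unfrozen) stage of fine_tune(10, base_lr=2e-3): lr_max = base_lr/2 for the
# last parameter group, pct_start=0.3, div=5.0 (values from the fine_tune source)
lr_max = 2e-3 / 2
lr_resume = one_cycle_lr(6 / 10, lr_max, pct_start=0.3, div=5.0)
# A fresh one-cycle run with lr_resume as its peak starts at lr_resume/5 and warms up
# again, instead of continuing to anneal downwards from lr_resume
lr_restart = one_cycle_lr(0.0, lr_resume, pct_start=0.3, div=5.0)
print(f"resume at {lr_resume:.2e}, fresh cycle starts at {lr_restart:.2e}")
```

So even with the right value, starting a new cycle does not reproduce the remaining part of the old one; a second fine_tune call would additionally insert a frozen epoch and halve the learning rate again.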

In fastai1, fit_one_cycle had a start_epoch parameter, but it is no longer there in fastai2. However, I found the following thread:

If you started with fine_tune, keep in mind that fine_tune runs fit_one_cycle twice with different parameters, so you need to run it again with the same parameters to continue your training.

From the fastai source:

@patch
def fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
    self.freeze()
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
    base_lr /= 2
    self.unfreeze()
    self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div=div, **kwargs)
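The arithmetic of the two stages can be written out explicitly. This sketch uses a hypothetical helper, fine_tune_plan, that only mirrors the fine_tune code quoted above and returns the arguments of each fit_one_cycle call as plain dicts; it does not call fastai.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                   pct_start=0.3, div=5.0):
    """Hypothetical helper: the arguments fine_tune passes to fit_one_cycle.

    Stage 1 runs on the frozen model and leaves `div` at fit_one_cycle's
    default; stage 2 runs after unfreezing, with a halved base_lr.
    """
    frozen = dict(n_epoch=freeze_epochs, lr_max=slice(base_lr), pct_start=0.99)
    base_lr /= 2  # fine_tune halves base_lr before the second stage
    unfrozen = dict(n_epoch=epochs, lr_max=slice(base_lr / lr_mult, base_lr),
                    pct_start=pct_start, div=div)
    return frozen, unfrozen

frozen, unfrozen = fine_tune_plan(15, base_lr=2e-3)
print(frozen)
print(unfrozen)
```

With base_lr=2e-3 the second stage uses a learning-rate slice of roughly (1e-5, 1e-3) over all 15 epochs; a resumed run has to reproduce exactly these arguments (on an unfrozen model) for the schedule to match.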

So, to resume e.g. from epoch 10 of a 15-epoch run and train the remaining 5 epochs, I would do this:

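class SkipToEpoch(Callback):
    "Cancel every epoch before `s_epoch`, so training restarts at `s_epoch`"
    def __init__(self, s_epoch): self.s_epoch = s_epoch
    # these events were named begin_train / begin_validate in early fastai2 releases
    def before_train(self):
        if self.epoch < self.s_epoch: raise CancelEpochException
    def before_validate(self):
        if self.epoch < self.s_epoch: raise CancelValidException

learn = ...  # recreate the Learner exactly as in the original run
learn.load('model_9')  # hypothetical name: last checkpoint saved before the interruption
learn.unfreeze()  # fine_tune's second stage trains the unfrozen model

base_lr = 2e-3  # placeholder: use the base_lr you passed to fine_tune
base_lr /= 2
lr_mult = 100  # from fine_tune
total_epochs = 15  # must equal the `epochs` you passed to fine_tune
cbs = [SkipToEpoch(10)]  # skip the 10 epochs that already ran
# parameters pct_start and div taken from fine_tune()
learn.fit_one_cycle(total_epochs, slice(base_lr/lr_mult, base_lr), pct_start=0.3, div=5.0, cbs=cbs)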
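Why skipping epochs keeps the one-cycle schedule on track: in fastai2, TrainEvalCallback (which runs before ordinary callbacks) resets learn.pct_train to epoch/n_epoch at the start of each epoch's training phase, and the scheduler sets the hyperparameters from pct_train before every batch. The sketch below is a plain-Python toy model of that behaviour under those assumptions, not fastai code; lr_trace and its linear default schedule are hypothetical.

```python
def lr_trace(n_epoch, n_iter, start_epoch=0, lr_fn=lambda pos: 1.0 - pos):
    """Return (epoch, batch, lr) for every batch that is actually trained.

    Toy model of the training loop (an assumption, not fastai code):
    - progress pct_train is reset to epoch/n_epoch at the start of each epoch,
    - a SkipToEpoch-style check then cancels epochs before start_epoch,
    - the schedule reads pct_train before each batch, which then advances it.
    `lr_fn` stands in for any schedule driven by pct_train (e.g. one-cycle).
    """
    trace = []
    for epoch in range(n_epoch):
        pct_train = epoch / n_epoch  # reset happens before the skip check
        if epoch < start_epoch:
            continue  # epoch cancelled: no batches, no progress update
        for batch in range(n_iter):
            trace.append((epoch, batch, lr_fn(pct_train)))
            pct_train += 1 / (n_iter * n_epoch)
    return trace

full = lr_trace(n_epoch=15, n_iter=4)                      # uninterrupted run
resumed = lr_trace(n_epoch=15, n_iter=4, start_epoch=10)   # resumed at epoch 10
print(resumed == full[10 * 4:])  # True: the resumed run replays the same tail
```

Because the progress depends on epoch/n_epoch, the resumed fit_one_cycle only lines up with the original run if total_epochs is the same value that was originally passed to fine_tune.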

Hope that helps.