Is there a way to save checkpoints during fine-tuning so that, in case of an unwanted termination, I can restart the process close to where it left off?
I am looking into SaveModelCallback, but I am not sure that is the right approach.
Thanks @florianl. I think that solves half my problem. As I understand it, that would let me load the model, but if I want to resume training where it left off, I need more than that.
Let's say I am calling fine_tune for 10 epochs, but the process halts at epoch 6; I would like to resume from epoch 6 and continue to the end. Since epoch 6 was not finished, I would load the model from epoch 5, but to continue the process I need to know the hyperparameters for epochs 6 to 9. I am not sure if it is just a matter of calculating the learning rate for epoch 6 and calling fine_tune with that as the starting learning rate for the next 4 epochs, something like this:
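To make that concrete: the one-cycle policy warms the learning rate up over the first `pct_start` fraction of training and then cosine-anneals it back down, so the LR reached at epoch 6 can be estimated from that shape. A rough, self-contained sketch of the idea (the defaults `div=25` and `div_final=1e5` follow fit_one_cycle, not fine_tune, and the momentum schedule is ignored):

```python
import math

def one_cycle_lr(pct, lr_max, pct_start=0.3, div=25.0, div_final=1e5):
    """Approximate the one-cycle LR at training fraction pct (0..1),
    using cosine interpolation between the phase endpoints."""
    def cos_anneal(start, end, pos):
        # cosine interpolation: start at pos=0, end at pos=1
        return end + (start - end) / 2 * (1 + math.cos(math.pi * pos))
    if pct < pct_start:
        # warm-up phase: lr_max/div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # annealing phase: lr_max -> lr_max/div_final
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

# Estimated LR at the start of epoch 6 of a 10-epoch run
lr_at_epoch_6 = one_cycle_lr(6 / 10, lr_max=2e-3)
```

This is only an approximation of the schedule shape; restarting fit_one_cycle from that LR would still begin a fresh one-cycle schedule rather than continuing the old one.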
In fastai1 there was a start_epoch parameter for fit_one_cycle, but it's not there anymore in fastai2. However, I found the following thread:
If you started with fine_tune, you have to keep in mind that fine_tune runs fit_one_cycle twice with different parameters, so you need to run it with the same parameters again to continue your training.
From the fastai source:
def fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
    self.freeze()
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
    base_lr /= 2
    self.unfreeze()
    self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div=div, **kwargs)
So, to resume from e.g. epoch 10 and train 5 additional epochs, I would do this:
class SkipToEpoch(Callback):
    def __init__(self, s_epoch): self.s_epoch = s_epoch
    def begin_train(self):
        if self.epoch < self.s_epoch: raise CancelEpochException
    def begin_validate(self):
        if self.epoch < self.s_epoch: raise CancelValidException
learn = ...
learn.load('checkpoint')  # learn.load adds the .pth extension itself
start_epoch = 10
total_epochs = 15
cbs = [SkipToEpoch(s_epoch=start_epoch)]
base_lr = <your LR>
base_lr /= 2   # fine_tune halves base_lr before the unfrozen phase
lr_mult = 100  # from fine_tune
# parameters pct_start and div taken from fine_tune()
learn.unfreeze()
learn.fit_one_cycle(total_epochs, slice(base_lr/lr_mult, base_lr), pct_start=0.3, div=5.0, cbs=cbs)
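To see the control flow the callback relies on, here is a toy loop in plain Python (no fastai) mimicking what raising CancelEpochException does inside the training loop — the early epochs are entered and immediately abandoned, so only the epochs from start_epoch onward do any work. This is purely illustrative and says nothing about how fastai's schedulers advance during skipped epochs:

```python
# Toy stand-in for fastai's CancelEpochException control flow (illustrative only).
class CancelEpoch(Exception):
    pass

def run_epochs(total_epochs, start_epoch):
    trained = []
    for epoch in range(total_epochs):
        try:
            # What SkipToEpoch.begin_train does: bail out of early epochs
            if epoch < start_epoch:
                raise CancelEpoch
            trained.append(epoch)  # the real training step would happen here
        except CancelEpoch:
            pass  # fastai's loop catches the exception and moves to the next epoch
    return trained

print(run_epochs(15, 10))  # only epochs 10-14 actually train
```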