What is checkpoint averaging?

I saw this mentioned in the “Attention is all you need” paper but I haven’t been able to figure out…

What it is?

Whether it is implemented in fastai?

… and assuming it’s not implemented, how I should/could go about implementing it?

I’d recommend looking at the NLP course as they go over transformers (unsure if they specifically go over that paper). The notebooks are here: https://github.com/fastai/course-nlp

Otherwise, it looks like there’s a PyTorch repo, so it should be fairly easy to plug and play with fastai.

Here is a notebook with attention being used:

I understand attention. My questions above are w/r/t checkpoint averaging specifically…


I gather they mean checkpoints as in periodically saved weights, so they averaged the weights from a few saves near the end of the run. They say they saved checkpoints every 10 minutes (p. 8), presumably at the end of whichever epoch was in progress once that much time had elapsed…
So it should be easy enough to implement with fastai. There’s the SaveModelCallback, though it can only save every epoch (or only on improvement). So you could either just save every epoch, or, to reduce disk use, modify it to allow saving every n epochs (or every n minutes).
Then just take the last few saves and average them; each one is just a dict of weight tensors (plus, depending on the save options, some optimiser state). A rough sketch of that is below.

(This sense of “checkpoint” is different from torch.utils.checkpoint, which is gradient checkpointing and something else entirely.)
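Something like this, as a minimal sketch: it assumes SaveModelCallback(every_epoch=True) wrote one .pth file per epoch as a plain state dict, and the file names at the bottom are just placeholders.

```python
import torch

def average_checkpoints(paths):
    "Average the weight tensors from several saved checkpoints into one state dict."
    # Load everything onto the CPU. If the files were saved with optimiser state
    # (e.g. learn.save(..., with_opt=True)), the weights may sit under a 'model'
    # key -- here we assume each file is a plain state dict.
    state_dicts = [torch.load(p, map_location='cpu') for p in paths]
    avg = {}
    for k in state_dicts[0]:
        tensors = [sd[k] for sd in state_dicts]
        if tensors[0].is_floating_point():
            avg[k] = torch.stack(tensors).mean(dim=0)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) can't be
            # averaged meaningfully, so keep the value from the last save.
            avg[k] = tensors[-1]
    return avg

# Usage (assuming `learn` is your fastai Learner and these are the last 5 saves;
# the names are hypothetical -- use whatever your callback actually wrote out):
# paths = [f'models/model_{i}.pth' for i in range(15, 20)]
# learn.model.load_state_dict(average_checkpoints(paths))
```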


That makes sense.

From the paper …

For the base models, we used a single model obtained by averaging the last 5 checkpoints, which were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints.

Which I’m interpreting as …

“Every 10 minutes we saved all the model weights, and after all the training was done we took the last 5 saves for the base model (and last 20 saves for the big models) and averaged them as our ‘final model’. This ‘final model’ we then used against our validation set to produce the results in our paper”
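If you wanted to reproduce that last step with fastai, a hypothetical version (reusing the average_checkpoints helper sketched above, with placeholder file names and assuming `learn` is your Learner) would be:

```python
# Average the last 5 per-epoch saves, load them as the "final model", and validate.
paths = [f'models/model_{i}.pth' for i in range(15, 20)]   # placeholder names
learn.model.load_state_dict(average_checkpoints(paths))
print(learn.validate())   # loss and metrics for the averaged model
```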


Yep, that was my interpretation too.