Very interesting, thanks for sharing this info. Is your implementation publicly available?
I just pushed it to a branch. I want to clean up some of the model init sequence vs checkpoint load handling (if I can). Things get a bit complicated maintaining a separate set of weights (and being able to run validation on them) while using things like distributed training, DataParallel, AMP, etc…
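Roughly, the shadow copy looks something like the sketch below — a hypothetical `EmaShadow` helper for illustration, not the actual code in the branch. The main wrinkle is unwrapping `DataParallel`/`DistributedDataParallel` (`.module`) before copying so the EMA state_dict keys still line up at checkpoint load time:

```python
from copy import deepcopy

import torch


class EmaShadow:
    """Hypothetical sketch of an EMA weight shadow, not the code in the branch."""

    def __init__(self, model, decay=0.9998, device=None):
        # unwrap DataParallel / DistributedDataParallel so the shadow holds
        # plain module weights and the state_dict keys match on checkpoint load
        src = model.module if hasattr(model, 'module') else model
        self.ema = deepcopy(src).eval()
        self.decay = decay
        if device is not None:
            self.ema.to(device)  # e.g. device='cpu' to keep the EMA off the GPU
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # call once per optimizer step, after the weights have been updated
        src = model.module if hasattr(model, 'module') else model
        for ema_v, model_v in zip(self.ema.state_dict().values(),
                                  src.state_dict().values()):
            if ema_v.is_floating_point():
                # ema = decay * ema + (1 - decay) * current
                ema_v.mul_(self.decay).add_(model_v.detach().to(ema_v.device),
                                            alpha=1. - self.decay)
            else:
                # integer buffers (e.g. num_batches_tracked) are just copied
                ema_v.copy_(model_v.to(ema_v.device))
```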
It’ll use some GPU memory, though I experimented with a flag that keeps the EMA weights on the CPU only; with that you have to validate the results manually from the checkpoints. Also, pay attention to the decay factor and how it relates to your batch size (update count per epoch). Google trains on TPU with big batches and uses .9999; on less capable systems you’ll want to reduce that unless you feel averaging over 10 epochs is useful. The ‘N-day’ EMA equivalence formula is useful for making that adjustment sensible. I’m using .9998 right now.
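As a rough sanity check of decay vs update count, assuming the usual ‘N-day’ equivalence (decay = 1 - 2/(N+1), so N ≈ 2/(1 - decay) - 1) and hypothetical epoch/batch numbers:

```python
def ema_window_updates(decay):
    # 'N-day' EMA equivalence: a decay d weighs history roughly like a
    # simple average over N = 2 / (1 - d) - 1 updates
    return 2. / (1. - decay) - 1.


def decay_for_window(num_updates):
    # inverse: pick a decay spanning roughly num_updates optimizer steps
    return 1. - 2. / (num_updates + 1.)


# hypothetical numbers: ImageNet-1k (~1.28M images) at batch size 256
updates_per_epoch = 1281167 // 256

print(ema_window_updates(0.9999) / updates_per_epoch)  # epochs spanned by .9999
print(decay_for_window(10 * updates_per_epoch))        # decay for a ~10 epoch window
```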
https://github.com/rwightman/pytorch-image-models/blob/ema-cleanup/utils.py#L208