Thanks @a_yasyrev! I like that very much! I think I fixed the code… @LessW2020, do you want to run it on a setup whose results you're comfortable with, to be sure? I ran it with Ranger + SSA + MX and the results were pretty much as expected, but before a PR and such I'd like a double check.
@grankin, can you check as well, since it is your brainchild?
```python
from fastai.basics import *  # fastai v1: brings in Learner, LearnerCallback, TrainingPhase, annealing_cos/linear/exp, ifnone, listify, defaults

class FlatCosAnnealScheduler(LearnerCallback):
    """
    Manage FCFit training as found in the ImageNette experiments.
    Code format is based on OneCycleScheduler.
    Based on an idea by Mikhail Grankin.
    """
    def __init__(self, learn:Learner, lr:float=4e-3, moms:Floats=(0.95,0.999),
                 start_pct:float=0.72, start_epoch:int=None, tot_epochs:int=None,
                 curve='cosine'):
        super().__init__(learn)
        n = len(learn.data.train_dl)
        # Stay flat for the first start_pct of all batches, anneal over the remainder.
        self.anneal_start = int(n * tot_epochs * start_pct)
        self.batch_finish = (n * tot_epochs - self.anneal_start)
        if curve == "cosine":
            curve_type = annealing_cos
        elif curve == "linear":
            curve_type = annealing_linear
        elif curve == "exponential":
            curve_type = annealing_exp
        else:
            raise ValueError(f"annealing type not supported: {curve}")
        # Phase 0: constant lr and mom; phase 1: anneal lr to zero with the chosen curve.
        phase0 = TrainingPhase(self.anneal_start).schedule_hp('lr', lr).schedule_hp('mom', moms[0])
        phase1 = TrainingPhase(self.batch_finish).schedule_hp('lr', lr, anneal=curve_type).schedule_hp('mom', moms[1])
        phases = [phase0, phase1]
        self.phases,self.start_epoch = phases,start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->None:
        "Initialize the schedulers for training."
        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [p.scheds for p in self.phases]
        self.opt = self.learn.opt
        for k,v in self.scheds[0].items():
            v.restart()
            self.opt.set_stat(k, v.start)
        self.idx_s = 0
        return res

    def jump_to_epoch(self, epoch:int)->None:
        # Fast-forward the schedule by replaying one step per batch already seen.
        for _ in range(len(self.learn.data.train_dl) * epoch):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->None:
        "Take a step in lr,mom sched, start next stepper when the current one is complete."
        if train:
            if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}
            sched = self.scheds[self.idx_s]
            for k,v in sched.items(): self.opt.set_stat(k, v.step())
            if list(sched.values())[0].is_done: self.idx_s += 1

def fit_fc(learn:Learner, tot_epochs:int=None, lr:float=defaults.lr, moms:Tuple[float,float]=(0.95,0.85),
           start_pct:float=0.72, wd:float=None, callbacks:Optional[CallbackList]=None, show_curve:bool=False)->None:
    "Fit a model with Flat Cosine Annealing."
    max_lr = learn.lr_range(lr)
    callbacks = listify(callbacks)
    callbacks.append(FlatCosAnnealScheduler(learn, lr, moms=moms, start_pct=start_pct, tot_epochs=tot_epochs))
    learn.fit(tot_epochs, max_lr, wd=wd, callbacks=callbacks)
```
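If anyone wants to try it, here's a minimal usage sketch, assuming fastai v1 and the Imagenette dataset; the data, model, and hyperparameter choices are just illustrative (the thread's runs used Ranger etc., but this sticks to the default optimizer):

```python
from fastai.vision import *

# Illustrative setup: small Imagenette, a plain resnet18 learner.
path = untar_data(URLs.IMAGENETTE_160)
data = ImageDataBunch.from_folder(path, valid='val', size=128, bs=64).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)

# Flat lr for the first 72% of batches, then cosine anneal to zero.
fit_fc(learn, tot_epochs=5, lr=4e-3, start_pct=0.72)
```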
Edit: I believe I have momentum scheduling working.
Edit x2: I do not… reverted to the original.