Based on the OneCycleScheduler, this seems to do the trick:
class OneCycleXScheduler(LearnerCallback):
    """Anneal an arbitrary optimizer attribute `X` over two consecutive phases.

    Phase one covers `pct_start` of all training iterations and moves X from
    `X_max/div_factor` to `X_max`; phase two covers the remaining iterations and
    is configured with the range `X_max` -> `X_max/final_div`. The current value
    is written to `learn.opt.X` after every training batch and every value is
    appended to `learn.opt.Xs`.
    """

    def __init__(self, learn:Learner, X_max:float=1.0, div_factor:float=25., pct_start:float=0.75,
                 final_div:float=None, tot_epochs:int=None, start_epoch:int=None):
        """Store the schedule configuration.

        `final_div` defaults to `div_factor*1e4` when omitted. A list-like `X_max`
        is converted to a numpy array so the divisions below work element-wise.
        `tot_epochs` and `start_epoch` left as None are filled in by
        `on_train_begin`.
        """
        super().__init__(learn)
        self.X_max,self.div_factor,self.pct_start,self.final_div = X_max,div_factor,pct_start,final_div
        if self.final_div is None: self.final_div = div_factor*1e4
        if is_listy(self.X_max): self.X_max = np.array(self.X_max)
        self.start_epoch, self.tot_epochs = start_epoch, tot_epochs

    def steps(self, *steps_cfg:StartOptEnd):
        "Build anneal schedule for all of the parameters."
        # One Scheduler per phase: the i-th value range is paired with the i-th (n_iter, func) phase.
        return [Scheduler(step, n_iter, func=func)
                for (step,(n_iter,func)) in zip(steps_cfg, self.phases)]

    def on_train_begin(self, n_epochs:int, epoch:int, **kwargs:Any)->None:
        "Initialize our optimization params based on our annealing schedule."
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.tot_epochs = ifnone(self.tot_epochs, n_epochs)
        # Total number of training iterations, split between the two phases.
        n = len(self.learn.data.train_dl) * self.tot_epochs
        a1 = int(n * self.pct_start)
        a2 = n-a1
        # CHANGE HERE FOR FUNCTION! annealing_cos, annealing_linear
        # NOTE(review): with annealing_no in phase two, X presumably stays at X_max and
        # `final_div` has no effect - confirm this is intended.
        self.phases = ((a1, annealing_cos), (a2, annealing_no))
        low_X = self.X_max/self.div_factor
        self.X_scheds = self.steps((low_X, self.X_max), (self.X_max, self.X_max/self.final_div))
        self.opt = self.learn.opt
        self.opt.X = self.X_scheds[0].start
        self.idx_s = 0      # index of the phase currently being stepped
        self.opt.Xs = []    # history of every value assigned in on_batch_end

    def jump_to_epoch(self, epoch:int)->None:
        "Advance the schedule as if `epoch` full epochs of training batches had already run."
        for _ in range(len(self.learn.data.train_dl) * epoch):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->None:
        "Take one step forward on the annealing schedule for the optim params."
        if not train: return
        # All phases consumed (e.g. extra batches, or jump_to_epoch past the end):
        # keep the last X instead of indexing past the final scheduler.
        if self.idx_s >= len(self.X_scheds): return
        sched = self.X_scheds[self.idx_s]
        # Apply the scheduled value as-is, with the same sign as the start value
        # assigned in on_train_begin.
        self.opt.X = sched.step()
        self.opt.Xs.append(self.opt.X)
        # When the current phase is complete, move on to the next one.
        if sched.is_done:
            self.idx_s += 1
However, I am not sure if this is the best and most elegant way.
If somebody has suggestions, tips, or tricks, I am happy to hear them!