Update ParamScheduler to accept list of schedules

So far the ParamScheduler class does not accept a list of schedules for a given hyperparameter. Instead, we have to pass a function `f(pct)` that returns a list or tuple with one value per parameter group. Here is the class:

```python
class ParamScheduler(Callback):
    "Schedule hyper-parameters according to `scheds`"
    run_after,run_valid = TrainEvalCallback,False

    def __init__(self, scheds): self.scheds = scheds
    def before_fit(self): self.hps = {p:[] for p in self.scheds.keys()}
    def before_batch(self): self._update_val(self.pct_train)

    def _update_val(self, pct):
        for n,f in self.scheds.items(): self.opt.set_hyper(n, f(pct))

    def after_batch(self):
        for p in self.scheds.keys(): self.hps[p].append(self.opt.hypers[-1][p])

    def after_fit(self):
        if hasattr(self.learn, 'recorder') and hasattr(self, 'hps'): self.recorder.hps = self.hps

    _docs = {"before_fit": "Initialize container for hyper-parameters",
             "before_batch": "Set the proper hyper-parameters in the optimizer",
             "after_batch": "Record hyper-parameters of this batch",
             "after_fit": "Save the hyper-parameters in the recorder if there is one"}
```
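To make explicit why the function has to return one value per group: `_update_val` calls each schedule with the training percentage and hands the result straight to `set_hyper`, which broadcasts a single value to every group or expects a list with one entry per group. The sketch below runs without fastai; `FakeOpt` and `update_val` are hypothetical stand-ins that mimic only that broadcasting behaviour, not the full library code:

```python
# Hypothetical stand-in for fastai's Optimizer: only the per-group broadcasting of set_hyper is mimicked
class FakeOpt:
    def __init__(self, n_groups):
        # One dict of hyper-parameters per parameter group, like fastai's opt.hypers
        self.hypers = [{} for _ in range(n_groups)]

    def set_hyper(self, k, v):
        # A scalar is broadcast to every group; a list/tuple must hold one value per group
        vals = list(v) if isinstance(v, (list, tuple)) else [v]
        if len(vals) == 1: vals = vals * len(self.hypers)
        assert len(vals) == len(self.hypers), f"{len(vals)} values for {len(self.hypers)} groups"
        for h, x in zip(self.hypers, vals): h[k] = x

def update_val(opt, scheds, pct):
    # Same loop as ParamScheduler._update_val: each schedule is called with the training percentage
    for n, f in scheds.items(): opt.set_hyper(n, f(pct))

opt = FakeOpt(4)
update_val(opt, {'lr': lambda pct: 1e-3 * (1 - pct)}, 0.5)
print([h['lr'] for h in opt.hypers])   # one scalar, broadcast to all 4 groups
update_val(opt, {'lr': lambda pct: [1e-3, 1e-3, 1e-3, 2e-3]}, 0.5)
print([h['lr'] for h in opt.hypers])   # one value per group
```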

Suppose I have 4 parameter groups and want a different schedule for each. Currently I need to write a function that returns one value per group, put it in a dict keyed by the hyperparameter name, and pass a `ParamScheduler` built from that dict as a callback to my Learner:

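```python
from fastai.callback.schedule import ParamScheduler, SchedCos, combine_scheds

lr = 2e-3
lrs = [lr/2., lr/2., lr/2., lr]  # one base LR per parameter group

def f_sched(pct):
    # Return one value per parameter group for the current training percentage
    return [combine_scheds([0.2,0.8], [SchedCos(l/10., l), SchedCos(l, l/1e5)])(pct) for l in lrs]

sched = {'lr': f_sched}

learner.fit(1, cbs=ParamScheduler(sched))
```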
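To see what `f_sched` actually hands to the optimizer, here is a self-contained sketch with simplified pure-Python stand-ins for `SchedCos` and `combine_scheds` (the names `sched_cos` and `combine_scheds_simple` are hypothetical; they follow the same cosine formula and segment logic as fastai, but they are not the library code), so it runs without fastai or torch:

```python
import math

# Simplified stand-in for fastai's SchedCos: cosine curve from start (pos=0) to end (pos=1)
def sched_cos(start, end):
    return lambda pos: start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

# Simplified stand-in for fastai's combine_scheds
def combine_scheds_simple(pcts, scheds):
    # pcts: fraction of training covered by each schedule, assumed to sum to 1
    bounds = [0.0]
    for p in pcts: bounds.append(bounds[-1] + p)
    def _inner(pos):
        # Pick the segment pos falls into (the last one for pos == 1), then rescale pos to [0, 1]
        idx = min(max(i for i, b in enumerate(bounds) if pos >= b), len(scheds) - 1)
        return scheds[idx]((pos - bounds[idx]) / (bounds[idx + 1] - bounds[idx]))
    return _inner

lr = 2e-3
lrs = [lr/2., lr/2., lr/2., lr]

def f_sched(pct):
    # One value per parameter group, as ParamScheduler expects
    return [combine_scheds_simple([0.2, 0.8], [sched_cos(l/10., l), sched_cos(l, l/1e5)])(pct)
            for l in lrs]

for pct in (0.0, 0.2, 1.0):
    print(pct, f_sched(pct))
```

Each group warms up from `l/10` to `l` over the first 20% of training and then anneals towards `l/1e5`; the last group simply runs at twice the learning rate of the others.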

My question is: wouldn't it be easier to modify `ParamScheduler` (more specifically its `_update_val(self, pct)` method) so that it accepts a list of schedules? Then we could simply write:

```python
sched = {'lr': [combine_scheds([0.2,0.8], [SchedCos(lr/10., lr), SchedCos(lr, lr/1e5)])
                for lr in lrs]}

learner.fit(1, cbs=ParamScheduler(sched))
```
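A minimal sketch of the change I have in mind, assuming `set_hyper` keeps accepting one value per parameter group (which is what already makes `f_sched` work). `eval_sched` is a hypothetical helper name; the fastai-specific part is shown only as a comment so the snippet runs on its own:

```python
from typing import Callable, Sequence, Union

Sched = Callable[[float], float]

def eval_sched(s: Union[Sched, Sequence[Sched]], pct: float):
    # A single callable keeps the current behaviour (it may return a scalar or a list);
    # a list/tuple of callables gives one schedule per parameter group
    if callable(s): return s(pct)
    return [f(pct) for f in s]

# Sketch of the change inside ParamScheduler (fastai context assumed, not runnable here):
#     def _update_val(self, pct):
#         for n,f in self.scheds.items(): self.opt.set_hyper(n, eval_sched(f, pct))

# Toy check with plain linear schedules, one per group (hypothetical values)
lrs = [1e-3, 1e-3, 1e-3, 2e-3]
per_group = [lambda pct, l=l: l * (1 - pct) for l in lrs]  # l=l binds each lambda's own LR
print(eval_sched(per_group, 0.5))
print(eval_sched(lambda pct: 1e-3, 0.5))
```

Since a plain callable is passed through unchanged, existing code that uses `f_sched` would keep working. With bare lambdas the `l=l` default argument is needed so each lambda keeps its own learning rate; the `combine_scheds(...)` calls in the list comprehension above don't have that problem because they are evaluated while the list is built. Note also that `after_batch` only records `self.opt.hypers[-1]`, i.e. the last group's values, so the recorder would still show a single curve per hyperparameter.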

Thanks a lot!