Implementing the DFW optimizer - need to pass in loss to opt.step - how?

I am working on expanding the testing of a new DFW (Deep Frank-Wolfe) optimizer that did really well in testing on MNIST.
Now I want to move it into the part2 latest framework and test it with XResNet and ImageNette.
However, this optimizer takes the loss as an input to its step function to determine its own adjustments.
In reviewing the optimizer class and learner class, I’m not clear how to do that short of just making a forked/customized learner directly…but wanted to see if there’s a more elegant way.
Here’s what would be needed (inside Learner):

        self.loss.backward();                           self('after_backward')
        self.opt.step(lambda: float(self.loss));        self('after_step')

Any input or advice on how to do this within the framework (i.e. by inheriting from the current optimizer class) would be appreciated. It seems I could grab the loss via a callback in the after_loss event, but it's not clear how to hand that to the optimizer…or do I just have to add a new learner param that flags that the loss should be passed through in step?
Otherwise I'll just fork it as a derived learner class, or put that flag in, but it still seems like I'm missing a way to accomplish this within the framework.
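
One possible way to avoid forking the learner is to wrap the optimizer in an adapter that remembers the latest loss and builds the closure itself. The sketch below is plain Python, not framework code: `ClosureStepAdapter`, `set_loss`, and `RecordingOptimizer` are hypothetical names made up for illustration. It assumes the learner calls `opt.step()` with no arguments and that a callback can reach the optimizer during after_loss.

```python
class ClosureStepAdapter:
    """Hypothetical adapter: stores the latest loss and hands it to the
    wrapped optimizer's step() as a closure."""

    def __init__(self, opt):
        self.opt = opt
        self._loss = None

    def set_loss(self, loss):
        # Intended to be called from an after_loss-style callback.
        self._loss = loss

    def step(self):
        if self._loss is None:
            raise RuntimeError("set_loss() must be called before step()")
        loss_value = float(self._loss)
        self._loss = None  # do not reuse a stale loss on the next step
        return self.opt.step(lambda: loss_value)

    def __getattr__(self, name):
        # Only reached when normal lookup fails; delegate everything else
        # (zero_grad, param_groups, ...) to the wrapped optimizer.
        if name == "opt":
            raise AttributeError(name)
        return getattr(self.opt, name)


class RecordingOptimizer:
    """Stand-in for a DFW-like optimizer whose step() takes a closure
    returning the loss. It only records what it was given."""

    def __init__(self):
        self.seen = []

    def step(self, closure):
        self.seen.append(closure())

    def zero_grad(self):
        pass


opt = ClosureStepAdapter(RecordingOptimizer())
opt.set_loss(0.25)   # what the after_loss callback would do
opt.step()           # what the unmodified learner would do
print(opt.seen)      # [0.25]
```

With something like this, the learner's `self.opt.step()` line could stay unchanged, and the only new piece would be a small callback that calls `set_loss` after the loss is computed. Whether the framework's own optimizer wrapper allows this kind of delegation cleanly is something I have not verified.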
