Adding additional learnable parameters to optimizers in Callbacks

What is the recommended way to add additional parameters into the optimizer from within a callback?
This is what I came up with after reading through the docs and forums:

class ArcFaceMarginCallback(LearnerCallback):
    def on_train_begin(self, **kwargs):
        # ... omitted code ...
        # self.amp extends nn.Module
        self.amp = ArcMarginProduct(arcface_in_dim, num_classes)

        # ----------------------------------
        # Attempt 1: add parameters to self.layer_groups
        # self.learn.layer_groups = [nn.Sequential(*flatten_model(self.model))] + [self.amp.parameters()]
        #
        # Re-initialize the optimizer  ## Note! This failed to work
        # self.learn.create_opt(self.opt.lr, self.opt.wd)

        # -----------------------------------
        # Attempt 2: copy the existing configuration from the first parameter group
        params = {k: v for k, v in self.opt.param_groups[0].items() if k != 'params'}
        params["params"] = self.amp.parameters()
        self.learn.opt.add_param_group(params)

    def on_loss_begin(self, **kwargs):
        outputs = kwargs['last_output']
        targets = kwargs['last_target']

        # Add the ArcFace margin to the outputs
        outputs = self.amp(outputs, targets)

        # Replace the outputs before they reach the loss function
        return {'last_output': outputs}

In attempt #1, I tried adding the additional parameters to self.learn.layer_groups and calling self.learn.create_opt to re-create the optimizer, but that failed with an error. I'm not sure how to take the additional parameters created in the callback and pass them into the layer_groups argument of OptimWrapper.create().

Attempt #2 works, since I'm dealing directly with the PyTorch optimizer API, but I'm unsure whether I've missed updating anything from the fastai library's perspective.
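For what it's worth, attempt #2 can be reproduced on a bare PyTorch optimizer, independent of fastai. This is a sketch with toy modules standing in for the encoder and for ArcMarginProduct (both hypothetical here), showing that `add_param_group` accepts a group whose hyperparameters were copied from an existing group:

```python
import torch
import torch.nn as nn

# Toy stand-ins: a small "encoder" already registered with the optimizer,
# and an extra module whose parameters we add afterwards (as in attempt #2).
encoder = nn.Linear(8, 4)
extra = nn.Linear(4, 3)  # stands in for ArcMarginProduct's weight

opt = torch.optim.SGD(encoder.parameters(), lr=0.1, weight_decay=1e-4)

# Copy hyperparameters from the first group, then swap in the new params.
group = {k: v for k, v in opt.param_groups[0].items() if k != 'params'}
group['params'] = list(extra.parameters())
opt.add_param_group(group)

# The new group inherits lr and weight_decay from the first one.
print(len(opt.param_groups))          # → 2
print(opt.param_groups[1]['lr'])      # → 0.1
```

The extra module's parameters will now receive updates on every `opt.step()`, which is why this route works even though fastai's OptimWrapper is unaware of the new group.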

Any suggestions from anyone more experienced?

Background for those who are interested:

I’m working on implementing ArcFace (link ) in fastai using purely callbacks. The tricky parts are:

  • the function requires both features and labels in the forward pass (ref) prior to the loss function, so we can’t initialize the entire model, pass it into a Learner, and expect it to work right out of the box.

  • The additional parameters and the softmax classifier are actually discarded during inference, as we only care about the output of the CNN encoder after training.

  • Thus a callback is ideal: after training we can export the model as usual, without having to discard any final layers that were added purely for training.
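To make the first point concrete, here is a minimal sketch of what an ArcFace head like ArcMarginProduct might look like (margin and scale values `m` and `s` are illustrative defaults, not taken from the post). Note that `forward` needs the labels as well as the features, which is exactly why it cannot sit inside a plain model passed to a Learner:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginProduct(nn.Module):
    """Sketch of an ArcFace head: needs features AND labels in forward."""
    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.s, self.m = s, m

    def forward(self, features, labels):
        # Cosine similarity between L2-normalized features and class weights.
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        # Add the angular margin m to the target-class angle only.
        target = F.one_hot(labels, cosine.size(1)).bool()
        logits = torch.where(target, torch.cos(theta + self.m), cosine)
        return logits * self.s
```

During inference the head is dropped entirely and only the normalized encoder features are used, which matches the second and third points above.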

In attempt #1, I tried adding the additional parameters to self.learn.layer_groups and calling self.learn.create_opt to re-create the optimizer, but that failed with an error.

What error was this? How did you add the additional parameters?

self.amp is just a class extending nn.Module that contains a fully connected layer, so it has weights.

@kelvink
self.amp.parameters() cannot work as a layer group. A layer group is a sequential module, not a parameter generator.
Hope that helps.
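Following that advice, attempt #1 should work once the new module is wrapped as a module rather than passed in as a generator. A minimal sketch with toy stand-ins (fastai itself not imported; `amp` here is a hypothetical substitute for ArcMarginProduct):

```python
import torch.nn as nn

# Toy stand-ins for the encoder layers and the ArcFace head.
encoder = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
amp = nn.Linear(16, 10)

# A layer group must be a module, so wrap amp in nn.Sequential
# instead of appending amp.parameters() as in attempt #1.
layer_groups = [encoder, nn.Sequential(amp)]

# Each group now yields trainable parameters, which is what
# fastai's create_opt / OptimWrapper.create expects to iterate over.
print([sum(p.numel() for p in g.parameters()) for g in layer_groups])  # → [528, 170]
```

With `self.learn.layer_groups` rebuilt this way, the subsequent `self.learn.create_opt(...)` call should see the extra weights as an ordinary layer group.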