I don’t have experience with this type of situation, but here are a couple of suggestions:
I found this Forums post, which seems similar to what you are trying to implement. They don’t share their solution, but it seems like they use callbacks.
I also prompted ChatGPT and implemented the minimal solution it provided in this Colab notebook, in which I create a custom loss function (to mimic ArcFaceLoss having its own parameters) and a custom learner (which passes both the model’s parameters and the loss function’s parameters to the optimizer). It trains successfully, and the loss function’s parameters change after training, which I think indicates they have been learned. Not sure if that’s what you are looking for, but I hope it helps!
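
For reference, here is a rough sketch of the core idea in plain PyTorch rather than a custom learner, so it is not the notebook's code. The `LearnableMarginLoss` class, the toy `nn.Linear` model, and the random data are all made up for illustration, and the loss is only a simplified stand-in for ArcFaceLoss (it has a learnable class-weight matrix but no angular margin). The part that matters is the optimizer line, which receives the parameters of both the model and the loss function.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnableMarginLoss(nn.Module):
    """Hypothetical loss with its own weights (a simplified stand-in, not real ArcFace)."""

    def __init__(self, emb_dim, n_classes, scale=10.0):
        super().__init__()
        # Class-centre matrix that should be learned alongside the model.
        self.weight = nn.Parameter(torch.randn(n_classes, emb_dim) * 0.01)
        self.scale = scale

    def forward(self, embeddings, targets):
        # Cosine similarity between embeddings and class centres, then cross-entropy.
        logits = F.linear(F.normalize(embeddings), F.normalize(self.weight))
        return F.cross_entropy(logits * self.scale, targets)


torch.manual_seed(0)
model = nn.Linear(8, 4)  # toy "embedding" model (assumption)
loss_func = LearnableMarginLoss(emb_dim=4, n_classes=3)

# Key step: give the optimizer the parameters of BOTH the model and the loss.
params = list(model.parameters()) + list(loss_func.parameters())
opt = torch.optim.SGD(params, lr=0.1)

# Random data, only to exercise the training loop.
x = torch.randn(32, 8)
y = torch.randint(0, 3, (32,))

before = loss_func.weight.detach().clone()

for _ in range(20):
    opt.zero_grad()
    loss = loss_func(model(x), y)
    loss.backward()
    opt.step()

# If the loss parameters were passed to the optimizer, they should have moved.
loss_changed = not torch.equal(before, loss_func.weight.detach())
print("loss parameters changed:", loss_changed)
```

Comparing the loss function's parameters before and after training, as done at the end, is the same check I used in the notebook. If they stay identical, the optimizer most likely never received them.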