I’m trying to create a knowledge distillation/self-training process with a teacher and a student model. I would like to create a fit_distill function, based on the fit function, that takes the outputs of the teacher and student models and calculates the distillation loss, which also requires changing the loss_batch function. I have an idea of how to create the fit_distill function, but once that’s done, how can I call learn.fit_distill? Is there a way to add a function to the learner class as a method?
Look at the source code for fit_one_cycle. You’ll see that it actually takes in a learner as its first argument, so you could do fit_one_cycle(learn, 5, 1e-3).
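If you do want the learn.fit_distill spelling, plain Python lets you attach a function to a class after the fact, and it then behaves like any other method. Here is a minimal sketch of that idea. The Learner class below is only a stand-in I wrote for illustration, not fastai's, and the body of fit_distill is a placeholder for your own training loop:

```python
class Learner:
    """Stand-in for the real learner class (hypothetical, illustration only)."""
    def __init__(self, name):
        self.name = name
        self.log = []  # records calls so we can see what happened


def fit_distill(learn, epochs, lr):
    """Hypothetical fit function that takes the learner as its first argument."""
    # A real version would run the teacher/student loop and the
    # distillation loss here; this placeholder just records the call.
    learn.log.append((epochs, lr))
    return learn


# Attach the function to the class; every instance now has it as a method.
Learner.fit_distill = fit_distill

learn = Learner("student")
learn.fit_distill(5, 1e-3)  # equivalent to fit_distill(learn, 5, 1e-3)
```

Because the learner is the first parameter, it plays the role of self once the function is attached, so both call styles end up running the same code.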
An easier one to read and understand may be the flat_cos_anneal fit function: