I want to build a Meta Pseudo Labels model, but I’m not sure what the best way to do it in fastai would be. (I know you could write the training loop in plain PyTorch, but I’d rather not!)
Edited: Having given it some thought (and inspired by the excellent @muellerzr callbacks livestream), this is what I have come up with. It is really a stripped-down GAN module, with the switcher set to alternate equally between the two models:
# pseudocode!!
class MetaPseudoLabels(Callback):
    # Student model
    def before_batch(self):
        if self.teach_mode:
            x1, x2 = dls_unlabeled.batch
            pseudo_labels = teacher(xb)
            old teacher loss
    def after_step(self):
        if not self.teach_mode:
            raise CancelTrainException()  # don't zero grads
    # Teacher
    def before_batch(self):
        if not self.teach_mode:
            raise CancelBatchException()  # bypass `_do_one_batch`
        New student loss  # alternate fit, no forward pass
        Move dot product
        teacher MPL loss
        total teacher loss = MPL loss + old teacher loss
    def switch(self):
        self.teach_mode = not self.teach_mode
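For comparison, here is a minimal plain-PyTorch sketch of the single update that the callback above is trying to reproduce. It is only my reading of the steps listed in the pseudocode (old teacher loss, new student loss, dot product, MPL loss, total teacher loss), not a fastai API and not a full implementation of the method. The function name `mpl_step`, the argument names, and the toy `nn.Linear` models are all made up for illustration, and the sign of `dot_product` is an assumption, since implementations differ on it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mpl_step(teacher, student, t_opt, s_opt, xb_l, yb_l, xb_u):
    """One alternating student/teacher update; returns (student_loss, teacher_loss)."""
    # Teacher forward passes; the graph is kept for the teacher update below.
    t_logits_l = teacher(xb_l)
    t_logits_u = teacher(xb_u)
    t_loss_sup = F.cross_entropy(t_logits_l, yb_l)  # "old teacher loss"
    pseudo = t_logits_u.detach().argmax(dim=1)      # hard pseudo labels

    # Student loss on labeled data before its update (measurement only).
    with torch.no_grad():
        s_loss_old = F.cross_entropy(student(xb_l), yb_l)

    # Student update on the pseudo-labeled batch.
    s_loss = F.cross_entropy(student(xb_u), pseudo)
    s_opt.zero_grad()
    s_loss.backward()
    s_opt.step()

    # Student loss on labeled data after its update ("new student loss").
    with torch.no_grad():
        s_loss_new = F.cross_entropy(student(xb_l), yb_l)

    # Scalar feedback signal; sign convention is an assumption.
    dot_product = s_loss_old - s_loss_new
    t_loss_mpl = dot_product * F.cross_entropy(t_logits_u, pseudo)
    t_loss = t_loss_sup + t_loss_mpl  # total teacher loss
    t_opt.zero_grad()
    t_loss.backward()
    t_opt.step()
    return s_loss.item(), t_loss.item()

# Toy usage with made-up shapes: 4 features, 3 classes.
torch.manual_seed(0)
teacher, student = nn.Linear(4, 3), nn.Linear(4, 3)
t_opt = torch.optim.SGD(teacher.parameters(), lr=0.1)
s_opt = torch.optim.SGD(student.parameters(), lr=0.1)
xb_l, yb_l = torch.randn(8, 4), torch.randint(0, 3, (8,))
xb_u = torch.randn(16, 4)
student_w_before = student.weight.detach().clone()
teacher_w_before = teacher.weight.detach().clone()
s_loss, t_loss = mpl_step(teacher, student, t_opt, s_opt, xb_l, yb_l, xb_u)
```

Note that each model gets its own optimizer and its own loss here, which is the part that needs mapping onto the callback's two modes.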
There are definitely things I’ve missed in the loop, but the major things causing problems are:
How to manage 2 models and 2 data loaders. I’ve looked at the GAN wrapper modules but don’t think I’ve understood them well enough to get this working properly; setting an individual loss function for each model has been hard.
Adding a second, unlabeled data loader in a callback. I haven’t got to coding it up yet, but I have a terrible feeling it’s all going to go wrong! (I’m also just not sure if it’s the best way to handle the problem.)
Is training one model and then the other the best way of going about things? I can’t think of a better one, but there might be something that I’ve missed.
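On the second data loader, one possible approach is to keep an iterator over the unlabeled loader and pull one batch from it for every labeled batch, restarting it when it runs out. The sketch below is plain Python and assumes only that both loaders can be iterated more than once (true of PyTorch and fastai data loaders, and of the lists used here as stand-ins). The helper name `paired_batches` is hypothetical.

```python
def paired_batches(labeled_dl, unlabeled_dl):
    """Yield (labeled_batch, unlabeled_batch) for every labeled batch."""
    unlabeled_iter = iter(unlabeled_dl)
    for labeled_batch in labeled_dl:
        try:
            unlabeled_batch = next(unlabeled_iter)
        except StopIteration:
            # Restart the unlabeled loader instead of caching its batches,
            # so a shuffling loader can reshuffle on each pass.
            unlabeled_iter = iter(unlabeled_dl)
            unlabeled_batch = next(unlabeled_iter)
        yield labeled_batch, unlabeled_batch

# Stand-in "loaders": lists of batches (an assumption for illustration).
labeled = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]
unlabeled = ["u0", "u1"]
pairs = list(paired_batches(labeled, unlabeled))
```

Inside a callback, the same idea would presumably mean creating the iterator once before training and calling `next` on it at the start of each batch, but I haven’t tested that in fastai.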