Sort of…
There’s now a PyTorch implementation https://github.com/kekmodel/MPL-pytorch which I have been using.
edited: Having given it some thought (and inspired by the excellent @muellerzr callbacks livestream), this is what I have come up with. It is really a stripped-down GAN module, with the switcher set to train both models equally:
```python
# Pseudocode!!
class MetaPseudoLabels(Callback):
    def before_batch(self):
        if self.teach_mode:
            # -- Student model: train on teacher-generated pseudo-labels --
            x1, x2 = self.dls_unlabeled.one_batch()   # draw an unlabeled batch
            pseudo_labels = self.teacher(x1)          # teacher predicts the targets
            # old teacher loss (cached here for the teacher step below)
        else:
            # -- Teacher: alternate fit, no forward pass --
            # new student loss
            # moving dot product
            # teacher MPL loss
            # total teacher loss = MPL loss + old teacher loss
            raise CancelBatchException()              # bypass `_do_one_batch`

    def after_step(self):
        if not self.teach_mode:
            raise CancelBatchException()              # don't zero grads

    def switch(self):
        self.teach_mode = not self.teach_mode
```
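By “train equally” I mean the equivalent of fastai’s `FixedGANSwitcher(n_crit=1, n_gen=1)`: a tiny callback that just flips the mode after every batch. A minimal sketch, assuming the callback above is attached to the learner (fastai exposes attached callbacks on the learner under their snake-case class name):

```python
from fastai.callback.core import Callback

class EqualSwitcher(Callback):
    "Alternate teacher/student every batch, like `FixedGANSwitcher(1, 1)`."
    def after_batch(self):
        # `Callback.__getattr__` delegates to the learner, where the
        # MetaPseudoLabels callback lives as `meta_pseudo_labels`
        self.meta_pseudo_labels.switch()
```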
There are definitely things I’ve missed in the loop, but the major things causing problems are:
- How to manage two models and two data loaders. I’ve looked at the GAN wrapper modules but don’t think I’ve understood them well enough to get this working properly; setting individual loss functions for each model has been hard. (See the first sketch below.)
- Adding a second, unlabeled data loader in a callback. I haven’t got to coding it up yet, but I have this terrible feeling it’s all going to go wrong! (I’m also just not sure it’s the best way to handle the problem; see the second sketch below.)
- Is training one model and then the other the best way of going about things? I can’t think of a better one, but there might be something I’ve missed.
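On the first point, the shape I keep coming back to is fastai’s `GANModule`: one `nn.Module` that owns both models and dispatches `forward()` on a mode flag, so the `Learner` only ever sees a single model. A minimal sketch under that assumption (the class and parameter names here are my own placeholders, not fastai API):

```python
import torch.nn as nn

class TeacherStudent(nn.Module):
    "Hypothetical wrapper in the style of fastai's `GANModule`."
    def __init__(self, teacher, student):
        super().__init__()
        self.teacher, self.student = teacher, student
        self.teach_mode = False  # which model the next forward pass trains

    def forward(self, x):
        # Dispatch to whichever model is currently being trained
        return self.teacher(x) if self.teach_mode else self.student(x)

    def switch(self, teach_mode=None):
        # Flip the flag (or set it explicitly), mirroring `GANModule.switch`
        self.teach_mode = (not self.teach_mode) if teach_mode is None else teach_mode
```

The per-model loss functions could then follow the `GANLoss` pattern: a loss wrapper that checks the same flag and routes to a teacher loss or a student loss.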
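On the second point, the least invasive thing I can think of is a callback that owns its own iterator over the unlabeled `DataLoader`, independent of the labeled one, and restarts it when exhausted. A sketch, assuming the unlabeled loader and the `next_batch` helper are names I’m inventing here:

```python
from fastai.callback.core import Callback

class UnlabeledLoader(Callback):
    "Hypothetical callback feeding batches from a second, unlabeled `DataLoader`."
    def __init__(self, unlabeled_dl):
        self.unlabeled_dl = unlabeled_dl

    def before_fit(self):
        self._it = iter(self.unlabeled_dl)

    def next_batch(self):
        # Cycle the unlabeled loader independently of the labeled one
        try:
            return next(self._it)
        except StopIteration:
            self._it = iter(self.unlabeled_dl)
            return next(self._it)
```

The `MetaPseudoLabels` callback would then call `next_batch()` in `before_batch` instead of reaching into `self.learn.dls`, which at least keeps the second loader’s state in one place.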