Implementing Meta Pseudo Labels

I want to build a Meta Pseudo Labels model, but I'm not sure of the best way to do it in fastai (I know you could write the training loop in pure PyTorch, but I'd rather not!)

Meta Pseudo Labels paper

I’m not too sure how to:

  1. Take the gradient of the teacher's loss given the student's validation loss

(I think you would need to use the `higher` package, but there might be another way)

  2. Manage training two models at the same time (training two models sequentially, batch by batch, but in the same trainer)
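For point 1, `higher` wraps this pattern, but the core idea can be done in plain PyTorch: take one differentiable SGD step on the student (a functional, out-of-place weight update with `create_graph=True`), then backpropagate the updated student's validation loss into the teacher. A minimal sketch with toy tensors (all shapes and names here are illustrative, not from the paper's code):

```python
import torch

# Toy "models": single weight matrices (hypothetical shapes).
teacher_w = torch.randn(3, 1, requires_grad=True)
student_w = torch.randn(3, 1, requires_grad=True)
x_unlab = torch.randn(8, 3)                          # unlabelled batch
x_val, y_val = torch.randn(4, 3), torch.randn(4, 1)  # labelled validation batch
lr = 0.1

# Teacher produces pseudo labels; keep the graph back to teacher_w.
pseudo = x_unlab @ teacher_w

# Student takes one SGD step on the pseudo-label loss, kept differentiable
# via create_graph=True and an out-of-place weight update.
s_loss = ((x_unlab @ student_w - pseudo) ** 2).mean()
g, = torch.autograd.grad(s_loss, student_w, create_graph=True)
student_w_new = student_w - lr * g

# The updated student's validation loss backpropagates to the teacher,
# because student_w_new still depends on teacher_w through `pseudo`.
val_loss = ((x_val @ student_w_new - y_val) ** 2).mean()
teacher_grad, = torch.autograd.grad(val_loss, teacher_w)
```

`higher`'s `innerloop_ctx` does essentially this bookkeeping for whole `nn.Module`s and multi-step inner loops.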

Any thoughts/advice appreciated!


Did you ever figure it out? I’m also interested in doing this.

Sort of…
There's now a PyTorch implementation, https://github.com/kekmodel/MPL-pytorch, which I have been using.

edited: Having given it some thought (and inspired by the excellent @muellerzr callbacks livestream), this is what I have come up with. It's really a stripped-down GAN module with the switcher set to train both models equally:

    # Pseudocode!!
    class MetaPseudoLabels(Callback):

        def before_batch(self):
            if not self.teacher_mode:
                # -- student step --
                x1, x2 = next(self.dl_unlabelled_iter)  # unlabelled batch
                self.pseudo_labels = self.teacher(x1)   # targets for the student
                self.old_teacher_loss = ...             # stash for the teacher step
            else:
                # -- teacher step: alternate fit, no student forward pass --
                new_student_loss = ...                  # student loss after its update
                dot_product = ...                       # MPL reward term
                mpl_loss = ...                          # teacher MPL loss
                self.learn.loss = mpl_loss + self.old_teacher_loss
                raise CancelBatchException()            # bypass the usual `one_batch`

        def after_step(self):
            if not self.teacher_mode:
                raise CancelBatchException()            # don't zero the student's grads

        def switch(self):
            self.teacher_mode = not self.teacher_mode
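For the "dot product" / MPL loss step, the kekmodel repo avoids `higher` with a first-order (Taylor) approximation: the change in the student's labelled loss across its update stands in for the teacher-student gradient dot product, and that scalar scales a cross-entropy between the teacher's logits and its own hard pseudo labels. A rough sketch with made-up values (not that repo's exact code):

```python
import torch
import torch.nn.functional as F

t_logits_u = torch.randn(8, 10, requires_grad=True)  # teacher logits, unlabelled batch
s_loss_old = torch.tensor(2.3)  # student labelled loss before its update (made up)
s_loss_new = torch.tensor(2.1)  # student labelled loss after its update (made up)

# Taylor-approximation trick: the improvement in the student's labelled
# loss stands in for the teacher-student gradient dot product.
dot_product = s_loss_old - s_loss_new

# Teacher's MPL loss: the reward scales a cross-entropy against the
# teacher's own hard pseudo labels.
hard_pseudo = t_logits_u.argmax(dim=-1)
mpl_loss = dot_product * F.cross_entropy(t_logits_u, hard_pseudo)
mpl_loss.backward()  # gradients flow only through the teacher's logits
```

In the full method this MPL term is added to the teacher's ordinary supervised (and UDA) loss before stepping the teacher, which is what "total teacher loss = MPL loss + old teacher loss" is doing above.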

There are definitely things I've missed in the loop, but the major things causing problems are:

  1. How to manage 2 models and 2 data loaders. I've looked at the GAN wrapper modules but don't think I've understood them well enough to get it working properly; setting individual model loss functions has been hard.

  2. Adding a second, unlabelled data loader in a callback. I haven't got as far as coding it up, but I have this terrible feeling it's all going to go wrong! (Also, I'm just not sure if it's the best way to handle the problem.)

  3. Is training one model and then the other the best way of going about things? I can't think of a better one, but there might be something that I've missed.
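Stripped of the fastai callback machinery, the shape of all three questions can be sketched as one plain PyTorch loop: two models with separate optimizers stepped batch by batch in the same trainer, plus a second unlabelled loader wrapped in an infinite iterator so it can't run dry mid-epoch. All models, data, and names below are toy stand-ins (and the teacher step only shows its supervised loss; the MPL term would be added on top):

```python
import torch
from torch import nn
from itertools import cycle

teacher, student = nn.Linear(3, 1), nn.Linear(3, 1)
opt_t = torch.optim.SGD(teacher.parameters(), lr=0.1)
opt_s = torch.optim.SGD(student.parameters(), lr=0.1)

# Stand-ins for the two DataLoaders; cycle() makes the unlabelled one
# inexhaustible, so the loaders never fall out of sync.
dl_labelled = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(4)]
dl_unlabelled = cycle([torch.randn(4, 3) for _ in range(3)])

for xb, yb in dl_labelled:
    x_unlab = next(dl_unlabelled)        # pull one unlabelled batch per step

    # -- student step: fit the teacher's pseudo labels --
    pseudo = teacher(x_unlab).detach()   # detach: don't touch the teacher here
    s_loss = ((student(x_unlab) - pseudo) ** 2).mean()
    opt_s.zero_grad()
    s_loss.backward()
    opt_s.step()

    # -- teacher step: supervised loss only in this sketch --
    t_loss = ((teacher(xb) - yb) ** 2).mean()
    opt_t.zero_grad()
    t_loss.backward()
    opt_t.step()
```

This is also (as I understand it) roughly what the GAN wrapper's switcher does for you: decide per batch which model/optimizer/loss triple is live.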


What are x1/x2 in the unlabelled batch iterator, @lukemshepherd?