Learn.export() not working with GANs

Hello,

After following the lesson 7 notebooks from the new Part 1, I wanted to make my own GAN project and create an inference learner from my GANLearner, though I noticed I can’t use the Fastai learn.export() command. This is the error I’m getting:

Since it was a pickling error, I tried import dill as pickle, but I still get the same error as above. I’m thinking that since the GANLearner uses 2 other learners (the generator and the critic), this is why I’m having issues. I’m currently looking into a workaround, though maybe I’m missing some extra step(s) in creating a GAN inference learner.

Thanks,
Wayde

Oh, we haven’t tried exporting those. The problem is that the loss functions returned by gan_loss_from_func appear not to be picklable. Try defining them as regular functions in your notebook/script instead.
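
To illustrate the point about picklability, here is a minimal sketch using only the standard library. The names make_loss, WeightedLoss and can_pickle are hypothetical and not part of fastai; the idea is that a function defined inside another function cannot be located by name when pickling, while a module-level callable that stores the same state can.

```python
import pickle


def make_loss(weight):
    """Factory returning a nested (local) function, similar in shape to gan_loss_from_func."""
    def _loss(pred, target):
        return weight * abs(pred - target)
    return _loss


class WeightedLoss:
    """Module-level callable that stores the same state as the closure above."""
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, pred, target):
        return self.weight * abs(pred - target)


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


closure_loss = make_loss(2.0)
class_loss = WeightedLoss(2.0)

# Both callables compute the same value...
print(closure_loss(1.0, 3.0), class_loss(1.0, 3.0))
# ...but only the module-level class is expected to survive pickling.
print(can_pickle(closure_loss), can_pickle(class_loss))
```

Rebinding the name pickle in your own notebook does not change which pickle module a library uses internally, which may be why import dill as pickle made no difference.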

Where are the parameters for _loss_G and _loss_C coming from? I’m trying to move those two functions outside of gan_loss_from_func, but I’m getting caught up on assigning values to the variables fake_pred, output, target, and real_pred.

def gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C

They get passed on by GANLoss. fake_pred is the output of the critic on the batch of fake images, output is this batch of fake images, and target is a batch of real images (the last two are used when you combine the GAN loss with another loss, as in superres).

Then in the second loss, real_pred is the output of the critic on a batch of real images, and fake_pred is the output of the critic on the batch of fake images.
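
As a rough illustration of those arguments, here is a sketch with dummy tensors. It assumes PyTorch is installed; the tensor shapes and the choice of mse_loss and binary_cross_entropy_with_logits are assumptions made for the example, not what any particular notebook uses. The two functions are module-level variants of _loss_G and _loss_C that receive loss_gen and loss_crit explicitly.

```python
import torch
import torch.nn.functional as F


def loss_G(fake_pred, output, target, loss_gen, loss_crit, weights_gen=(1., 1.)):
    """Generator loss: fool the critic and stay close to the real target."""
    ones = fake_pred.new_ones(fake_pred.shape[0])
    return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)


def loss_C(real_pred, fake_pred, loss_crit):
    """Critic loss: score real images as 1 and fake images as 0."""
    ones = real_pred.new_ones(real_pred.shape[0])
    zeros = fake_pred.new_zeros(fake_pred.shape[0])
    return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2


torch.manual_seed(0)
batch = 4
output = torch.rand(batch, 3, 8, 8)   # stand-in for a batch of fake images from the generator
target = torch.rand(batch, 3, 8, 8)   # stand-in for the matching batch of real images
fake_pred = torch.randn(batch)        # stand-in for critic scores on the fake batch
real_pred = torch.randn(batch)        # stand-in for critic scores on the real batch

g_loss = loss_G(fake_pred, output, target, F.mse_loss,
                F.binary_cross_entropy_with_logits, weights_gen=(1., 50.))
c_loss = loss_C(real_pred, fake_pred, F.binary_cross_entropy_with_logits)
print(g_loss.item(), c_loss.item())
```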

Thanks for the help! I’ve gotten my GAN learner to export at this point after refactoring some of the fastai code, though I’m running into some recursion issues in Python when I try to import my model. In case anything immediately stands out, here’s the code that I think is causing the problem, along with the stack trace:

class testGANLoss(GANModule):
    "Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator)."
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule,
                 loss_gen, loss_crit, weights_gen):
        super().__init__()
        self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model
        self.loss_crit, self.weights_gen = loss_crit, weights_gen

    def generator(self, output, target):
        "Evaluate the `output` with the critic then uses `self.loss_funcG` to combine it with `target`."
        fake_pred = self.gan_model.critic(output)
        return self.loss_funcG(fake_pred, target, output, self.weights_gen, self.loss_crit)

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcC`."
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred, self.loss_crit)


def _loss_G(fake_pred, output, target, weights_gen, loss_crit):
    ones = fake_pred.new_ones(fake_pred.shape[0])
    weights_gen = ifnone(weights_gen, (1.,1.))
    # Note: loss_gen is not a parameter here, so it must be defined at module level
    return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

def _loss_C(real_pred, fake_pred, loss_crit):
    ones  = real_pred.new_ones (real_pred.shape[0])
    zeros = fake_pred.new_zeros(fake_pred.shape[0])
    return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, switcher:Callback=None, learn_gen=None, learn_crit=None, gen_first:bool=False, switch_eval:bool=True,
                 show_img:bool=True, clip:float=None, weights_gen=(1.,50.), **learn_kwargs):
        gan = GANModule(generator, critic)
        loss_func = testGANLoss(gen_loss_func, crit_loss_func, gan, learn_gen.loss_func, learn_crit.loss_func, weights_gen)
        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=5, n_gen=1))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        trainer = GANTrainer(self, clip=clip, switch_eval=switch_eval, show_img=show_img)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)
        
    @classmethod
    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher, **learn_kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = _loss_G, _loss_C
        learners = learn_gen, learn_crit
        return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher, *learners, **learn_kwargs)

I have no idea. I’ll look at exporting a Learner with GANs, but not before the end of the course. In the meantime, I suggest you look at learn.save and learn.load.

I’m actually seeing the recursion issue when you fit this model as well.

I’m not sure where this is triggered, but this method in OptimWrapper keeps getting called with k = 'opt'. When no opt attribute is found on the OptimWrapper instance, the line return getattr(self.opt, k, None) looks up self.opt again, which calls __getattr__ again, and so you get stuck in infinite recursion.

#Passthrough to the inner opt
def __getattr__(self, k:str)->Any:
    return getattr(self.opt, k, None)

I’m not familiar with GANs enough to say where this is happening and how to fix … but that is the problem.

This is always where the infinite recursion happens, and I don’t really know why, since there is an opt in OptimWrapper.
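
For anyone following along, here is a minimal, fastai-free sketch of how a passthrough __getattr__ can recurse even though __init__ sets the inner attribute. The classes Passthrough and SafePassthrough are hypothetical stand-ins, not the real OptimWrapper, and whether this is what actually happens in the GAN case is unconfirmed. The assumption is that the object gets created without __init__ running (as unpickling and copying do), so the inner attribute is not yet in the instance dictionary.

```python
class Passthrough:
    """Forwards unknown attributes to self.inner, like the snippet quoted above."""
    def __init__(self, inner):
        self.inner = inner

    def __getattr__(self, k):
        # Only called when normal lookup fails. If 'inner' itself is missing,
        # evaluating self.inner re-enters __getattr__ and never terminates.
        return getattr(self.inner, k, None)


class SafePassthrough:
    """Same passthrough, but refuses to recurse on its own inner attribute."""
    def __init__(self, inner):
        self.inner = inner

    def __getattr__(self, k):
        if k == "inner":
            raise AttributeError(k)
        return getattr(self.inner, k, None)


# Simulate an object restored without __init__ having run.
broken = Passthrough.__new__(Passthrough)
try:
    broken.anything
    recursed = False
except RecursionError:
    recursed = True

guarded = SafePassthrough.__new__(SafePassthrough)
fallback = getattr(guarded, "anything", "missing")

print(recursed, fallback)
```

With the guard in place the lookup fails with an ordinary AttributeError instead of recursing, which is usually easier to diagnose.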

It’s something in here: GANDiscriminativeLR

I removed this line:

learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

and fit ran with no problem.

Of course, removing it is probably not the right fix … but it does isolate the problem.

Ok this is weirder now … I’m wondering if some kind of “race condition” is being introduced somewhere.

If I set some break points in GANDiscriminativeLR and step through the code … it runs fine.

I take the breakpoints out and run again … and it runs fine.

Here is the code I run each time before fitting to create the learner:

switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd, path=PATH)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

I’m totally confused now :)

Ok again …

It looks like the recursion problem ONLY happens after you load a saved model.

learn.save(f'{prefix}gan-1c')
learn.load(f'{prefix}gan-1c');

When I call load I get these warnings:

/home/wgilliam/anaconda3/envs/fastai-ohmeow/lib/python3.7/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type testGANLoss. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
/home/wgilliam/development/projects/fastai/fastai/basic_train.py:324: UserWarning: Wasn't able to properly load the optimizer state again.
  except: warn("Wasn't able to properly load the optimizer state again.")

That last one about not being able to load the optimizer state seems suspect.

UPDATED:

Yah I just saw that … so very strange.

This is happening only after loading a saved GANLearner … almost like there is something there that is just removing the attribute entirely.

Neither learn.load("gan-1") nor learn.load("gan-1", purge=False) works. The former gives the error AttributeError: Can't pickle local object 'gan_loss_from_func.<locals>._loss_G', while the latter gives the error AttributeError: 'GANLearner' object has no attribute 'gen_mode' on learn.show_results() or any other inference method.

No, this isn’t supposed to work, as no one has implemented or tested it.
You should save your model and use it directly for inference.

Then how can I load it again after shutting down my notebook?

You can save your model with Learner.save or directly use the torch.save function, then load it again with Learner.load.

Hi everyone, has anybody found a temporary solution for loading weights into a new GANLearner, or any clue that might help? I am trying to export my model and write a clean inference script using it. Thanks in advance!

Since the GAN learner just switches between training the generator and the critic over and over again, for inference you can just load the most recent generator weights, as the generator is what actually gives you new generated images. Do a learn.load() on the generator learner object, not the GAN learner object.

You can verify that the weights are actually changing and being updated by checking the last layer weights of the generator before and after GAN training.
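
Here is a rough sketch of that idea in plain PyTorch, with no fastai involved. The tiny make_generator and make_critic networks, the single SGD step and the file names are all made up for illustration. With a GANLearner, the corresponding modules should be the generator and critic attributes of the GANModule, assuming the learner code quoted earlier in this thread.

```python
import os
import tempfile

import torch
from torch import nn


def make_generator():
    return nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 8))


def make_critic():
    return nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 1))


torch.manual_seed(0)
generator, critic = make_generator(), make_critic()

# Snapshot the generator's last-layer weights before "training".
before = generator[-1].weight.detach().clone()

# One stand-in generator update; a real GAN loop alternates many such steps.
opt = torch.optim.SGD(generator.parameters(), lr=0.1)
loss = critic(generator(torch.randn(4, 8))).mean()
loss.backward()
opt.step()
weights_changed = not torch.equal(before, generator[-1].weight)

with tempfile.TemporaryDirectory() as tmp:
    gen_path = os.path.join(tmp, "generator.pth")
    crit_path = os.path.join(tmp, "critic.pth")
    # Save the two networks separately as plain state dicts.
    torch.save(generator.state_dict(), gen_path)
    torch.save(critic.state_dict(), crit_path)

    # Later (e.g. in a fresh session): rebuild the architectures, then load.
    new_generator, new_critic = make_generator(), make_critic()
    new_generator.load_state_dict(torch.load(gen_path))
    new_critic.load_state_dict(torch.load(crit_path))

restored = all(
    torch.equal(a, b)
    for a, b in zip(generator.state_dict().values(), new_generator.state_dict().values())
)
print(weights_changed, restored)
```

Only the generator is needed for inference; saving the critic as well matters if you want to resume adversarial training later.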

Thank you very much for your explanation. It’s clearer to me now: I didn’t know that GAN training updates the generator and critic objects directly, in place.

Could you show an example of how to save the generator and critic models separately from a GAN learner? I can’t seem to find any examples!

I’m just trying to train my GAN for a while, save its state, then reload it another time and continue training. But when I use save and load on the GAN learner itself, the loaded model seems to have lost everything it learnt once training continues.

Are you just calling a learner.export() on the generator and the critic learner objects? How are you loading them back in?