Lesson 7 - WGAN - can't predict on saved model

Hi there,
I’ve been trying out the WGAN in lesson 7:

I wanted to save the model at a certain point and then load it again, so I saved a checkpoint after training:

```python
learn.save(dest/'gan-model')
```

Then I loaded it into a new learner that I defined with the same architecture:

```python
learn2 = GANLearner.wgan(data, generator, critic, switch_eval=False,
                         opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
learn2.load('gan-model')
```

This all works fine so far.

Then, when I try to switch to generator mode to predict/generate images:

```python
learn2.gan_trainer.switch(gen_mode=True)
```

I get the following errors:


```
AttributeError                            Traceback (most recent call last)
in <module>()
----> 1 learn2.gan_trainer.switch(gen_mode=True)

1 frames
/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in __getattr__(self, k)
    441         setattr(self.learn, self.cb_name, self)
    442 
--> 443     def __getattr__(self,k): return getattr(self.learn, k)
    444     def __setstate__(self,data:Any): self.__dict__.update(data)
    445 

AttributeError: 'GANLearner' object has no attribute 'opt_gen'
```


I then saw this posting in the forum:

I tried the method it suggested:

```python
learn2.load('gan-model', purge=False)
```

This made no difference and I still got the same error:

```
AttributeError: 'GANLearner' object has no attribute 'opt_gen'
```

Does anyone know how I should be saving and loading GAN models properly so that I can predict?

Thanks!

Hi Pranath,

First, may I suggest you wrap any code blocks and output in the proper Markdown syntax, i.e. a pair of “```” or “`”? That would make them more readable :sweat_smile:. Thanks.

As to your question…

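The short answer is that self.opt_gen and self.opt_critic are created on the fly in GANTrainer.on_train_begin(), so they don’t exist yet right after learn2.load('gan-model'). Thus when learn2.gan_trainer.switch() refers to self.opt_gen immediately after a fresh load(), it blows up. And because GANTrainer forwards any attribute it doesn’t have to the learner (that’s the `__getattr__` in your traceback), the error is reported as the GANLearner having no opt_gen.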

Take a look at GANTrainer's definition in gan.py to see what I mean.
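To make the failure mode concrete, here is a dependency-free toy version of that pattern. `ToyLearner`, `ToyTrainer` and the `set_opt()` helper are hypothetical stand-ins, not fastai code: the trainer only creates its optimizers when training starts and forwards unknown attributes to its learner, so calling `switch()` first fails with an error that names the learner's class.

```python
class ToyLearner:
    """Hypothetical stand-in for GANLearner: a plain object with no opt_gen attribute."""


class ToyTrainer:
    """Hypothetical stand-in for GANTrainer, a callback that wraps a learner."""

    def __init__(self, learn):
        self.learn = learn
        self.gen_mode = False

    def __getattr__(self, k):
        # Called only when normal lookup fails: forward to the learner,
        # like the __getattr__ shown in the traceback above.
        return getattr(self.learn, k)

    def set_opt(self):
        # Hypothetical helper: the optimizer-creation step pulled out of on_train_begin().
        self.opt_gen = "generator optimizer"
        self.opt_critic = "critic optimizer"

    def on_train_begin(self):
        # In this toy, the optimizers only appear once training starts.
        self.set_opt()

    def switch(self, gen_mode):
        self.gen_mode = gen_mode
        # Reads opt_gen/opt_critic, which don't exist before on_train_begin()/set_opt().
        return self.opt_gen if gen_mode else self.opt_critic


trainer = ToyTrainer(ToyLearner())
try:
    trainer.switch(gen_mode=True)
except AttributeError as err:
    print(err)  # 'ToyLearner' object has no attribute 'opt_gen'

trainer.set_opt()  # create the optimizers without training
print(trainer.switch(gen_mode=True))  # generator optimizer
```

The message names `ToyLearner`, not `ToyTrainer`, because the lookup that finally fails happens on the learner.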

Things you may try:

Now, I’m not suggesting that learn2 should be trained before it is used to generate images.

I played around with moving the creation of self.opt_gen and self.opt_critic into a public method on GANTrainer, say set_opt(), and then used it with your example:

```python
learn2.load('gan-model')
learn2.gan_trainer.set_opt()  # my own method, created from the first part of GANTrainer.on_train_begin()
learn2.gan_trainer.switch(gen_mode=True)
learn2.show_results(ds_type=DatasetType.Train, rows=5, figsize=(8,8))
```

Voilà, it shows some pictures in the notebook.

However, I’m not sure that’s the right fix either, because it creates new optimizers for the generator and the critic, losing the state of the trained opt_gen and opt_critic.

Alternatively, have you considered using Learner.export() and load_learner() instead of save() and load() (see basic_train.py)? They seem to capture and restore the state of the callbacks (GANTrainer is one such callback), so load_learner() may properly restore opt_gen and opt_critic. But I haven’t had time to try it out. Good luck.

Thanks for your reply!

I’ll try these out… Ultimately, I want to be able to save a GAN after training and then continue training. I have tried .save() and .load(), but they don’t seem to restore the state as it was:

```python
learn.save(dest/'gan-model')
learn2 = GANLearner.wgan(data, generator, critic, switch_eval=False,
                         opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
learn2.load(dest/'gan-model')
learn2.fit(5, 2e-4)
```

But learn2’s generated images (shown at each epoch during training) seem way off from where learn had got to before it was saved. Should save/load be sufficient to save the GAN model and restore it to continue training?

The generator and the critic are PyTorch modules. Have you tried torch.save() and torch.load() on them?

https://pytorch.org/tutorials/beginner/saving_loading_models.html

Also, if you’re going to train learn2 again, I suspect you may need to save and restore GANTrainer.opt_gen and GANTrainer.opt_critic as well… it gets messy… :man_facepalming:
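On the optimizer side (the state that a freshly created optimizer, as in the set_opt() workaround, starts without), here is a plain-PyTorch sketch. It assumes tiny `nn.Linear` stand-ins for the lesson's generator and critic and reuses the Adam settings from the `GANLearner.wgan(...)` call. How this maps onto `GANTrainer.opt_gen`/`opt_critic` is an assumption: in fastai v1's `gan.py`, `switch()` reads `self.opt_gen.opt`, which suggests the underlying PyTorch optimizer lives in that `.opt` attribute.

```python
import io

import torch
import torch.nn as nn

# Tiny stand-ins for the lesson's generator and critic (assumption: any nn.Module behaves the same).
generator, critic = nn.Linear(4, 4), nn.Linear(4, 1)
opt_gen = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.0, 0.99))
opt_critic = torch.optim.Adam(critic.parameters(), lr=2e-4, betas=(0.0, 0.99))

# One fake training step so both optimizers accumulate Adam moment estimates.
critic(generator(torch.randn(8, 4))).mean().backward()
opt_gen.step()
opt_critic.step()

# Save both optimizer states (an in-memory buffer here; a file path works the same way).
buf = io.BytesIO()
torch.save({"opt_gen": opt_gen.state_dict(), "opt_critic": opt_critic.state_dict()}, buf)

# A freshly created optimizer starts with no per-parameter state...
new_opt_gen = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.0, 0.99))
new_opt_critic = torch.optim.Adam(critic.parameters(), lr=2e-4, betas=(0.0, 0.99))
print(len(new_opt_gen.state))  # 0

# ...until the saved state is loaded back into it.
buf.seek(0)
saved = torch.load(buf)
new_opt_gen.load_state_dict(saved["opt_gen"])
new_opt_critic.load_state_dict(saved["opt_critic"])
print(len(new_opt_gen.state))  # 2 (generator weight and bias)
```

In a real run you would save the model weights in the same checkpoint, so that the weights and the optimizer moments come from the same training step.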

Thanks for replying again!

Yes, unfortunately I can’t figure out from the documentation how to call save and load on the generator and critic models within the GAN learner, e.g. something like:

```python
learn.generator.save('generator')
```

Nothing like that seems to exist?
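As far as I can tell from fastai v1's `gan.py`, `GANLearner` wraps the two networks in a `GANModule` that becomes `learn.model`, so `learn.model.generator` and `learn.model.critic` should be ordinary PyTorch modules that `torch.save()`/`torch.load()` can handle. The sketch below is only an illustration under that assumption: `GANWrapper` is a hypothetical stand-in for `GANModule`, and `save_gan_parts`/`load_gan_parts` are hypothetical helpers, not fastai API. With a real learner you would presumably pass `learn.model` in place of `gan`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class GANWrapper(nn.Module):
    """Hypothetical stand-in for fastai's GANModule: holds a generator and a critic."""

    def __init__(self, generator, critic):
        super().__init__()
        self.generator, self.critic = generator, critic


def save_gan_parts(gan, prefix):
    """Hypothetical helper: save each sub-network's weights to its own file."""
    torch.save(gan.generator.state_dict(), f"{prefix}-generator.pth")
    torch.save(gan.critic.state_dict(), f"{prefix}-critic.pth")


def load_gan_parts(gan, prefix):
    """Hypothetical helper: load weights written by save_gan_parts into the same architecture."""
    gan.generator.load_state_dict(torch.load(f"{prefix}-generator.pth"))
    gan.critic.load_state_dict(torch.load(f"{prefix}-critic.pth"))


def build():
    # Same architecture every time, as when learn2 is rebuilt with GANLearner.wgan(...).
    return GANWrapper(nn.Linear(4, 4), nn.Linear(4, 1))


gan = build()
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "gan-model")
    save_gan_parts(gan, prefix)
    gan2 = build()
    load_gan_parts(gan2, prefix)

print(torch.equal(gan.generator.weight, gan2.generator.weight))  # True
```

Saving `state_dict()`s rather than whole modules avoids pickling the module classes themselves, but it means you have to rebuild the same architecture before loading, just as with learn2.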

Correct, those features may not exist yet, and I think fastai welcomes contributions for missing features :slight_smile: Not only would it be a fun hacking experience, but it would also help the many users who run into similar problems.

I guess I’m encouraging you to hack/tinker around and see what works.

Have you tried the Learner.export() and load_learner() methods? Perhaps you need to override Learner.save() for GANLearner so that it saves the generator and critic properly…

Hmm, yes, good idea! Thanks again for your thoughts.