Learn.export() not working with GANs

I’m actually seeing the recursion issue when fitting this model as well.

I’m not sure where this originates, but this method in OptimWrapper keeps getting called with k = 'opt'. There is no opt defined on the instance, so when it hits the line return getattr(self.opt, k, None), the lookup of self.opt calls __getattr__ again and you get stuck in an infinite loop.

#Passthrough to the inner opt
def __getattr__(self, k:str)->Any:
    return getattr(self.opt, k, None)
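
For anyone curious why this passthrough recurses: __getattr__ is only called when normal attribute lookup fails, so if opt is ever missing from the instance (for example, because unpickling bypassed __init__ and the attribute was never restored), the self.opt inside the body fails too and re-enters __getattr__. A minimal standalone sketch (not fastai code) that reproduces it:

class Wrapper:
    def __init__(self, opt):
        self.opt = opt

    # Only invoked when normal attribute lookup fails
    def __getattr__(self, k):
        return getattr(self.opt, k, None)

w = Wrapper(opt=object())
del w.__dict__['opt']   # simulate 'opt' never being restored after a load
w.lr                    # RecursionError: looking up self.opt re-enters __getattr__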

I’m not familiar enough with GANs to say where this is happening or how to fix it … but that is the problem.

This is always where the infinite recursion happens, and I don’t really know why, since there is an opt defined in OptimWrapper.

It’s something in here: GANDiscriminativeLR

I removed this line:

learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

and fit ran with no problem.

Of course, removing it is probably not the right fix … but it does isolate the problem.

OK, this is getting weirder now … I’m wondering if some kind of “race condition” is being introduced somewhere.

If I set some breakpoints in GANDiscriminativeLR and step through the code … it runs fine.

I take the breakpoints out and run again … and it runs fine.

Here is the code I run each time before fitting to create the learner:

switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd, path=PATH)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

I’m totally confused now :)

OK, another update …

It looks like the recursion problem ONLY happens after you load a saved model.

learn.save(f'{prefix}gan-1c')
learn.load(f'{prefix}gan-1c');

When I call load I get this error:

/home/wgilliam/anaconda3/envs/fastai-ohmeow/lib/python3.7/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type testGANLoss. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
/home/wgilliam/development/projects/fastai/fastai/basic_train.py:324: UserWarning: Wasn't able to properly load the optimizer state again.
  except: warn("Wasn't able to properly load the optimizer state again.")

That last one about not being able to load the optimizer state seems suspect.

UPDATED:

Yeah, I just saw that … so very strange.

This is happening only after loading a saved GANLearner … almost like something in the load is removing the attribute entirely.

learn.load("gan-1") and learn.load("gan-1", purge=False) both fail. The former gives the error AttributeError: Can't pickle local object 'gan_loss_from_func.<locals>._loss_G',
while the latter gives the error AttributeError: 'GANLearner' object has no attribute 'gen_mode' on learn.show_results() or any other inference method.
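
The pickling failure, at least, is a plain Python limitation rather than anything GAN-specific: functions defined inside another function (like _loss_G inside gan_loss_from_func) can’t be pickled, because pickle stores functions by their module-level name. Presumably load with the default purge=True trips over this because the purge step pickles the learner. A quick standalone illustration (make_loss is just a stand-in for the demo):

import pickle

def make_loss():
    def _loss_G(output):      # defined locally, like fastai's _loss_G
        return output.mean()
    return _loss_G

try:
    pickle.dumps(make_loss())
except AttributeError as e:
    print(e)  # Can't pickle local object 'make_loss.<locals>._loss_G'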

No, this isn’t supposed to work, as no one has implemented or tested it.
You should save your model and use it directly for inference.

Then how can I load it again after shutting down my notebook?

You can save your model with Learner.save or directly use the torch.save function, then load it again with Learner.load.
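
For example, both routes look something like this (the checkpoint name is a placeholder):

# fastai route: writes models/gan-stage-1.pth under the learner's path
learn.save('gan-stage-1')
# ... later, recreate the learner the same way you did originally, then:
learn.load('gan-stage-1')

# raw PyTorch route: persist just the model weights
import torch
torch.save(learn.model.state_dict(), 'gan-stage-1.pth')
learn.model.load_state_dict(torch.load('gan-stage-1.pth'))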

Hi everyone, has anybody found a temporary solution to load weights into a new GANLearner, or any clue that might help? I am trying to export my model and write a clean inference script using it. Thanks in advance!

Since the GAN learner just switches between training the generator and the critic over and over, for inference you can simply load the most recent generator weights, as the generator is what actually gives you new generated images. Do a learn.load() on the generator learner object, not the GAN learner object.

You can verify that the weights are actually changing and being updated by checking the last layer weights of the generator before and after GAN training.
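
A sketch of that inference path, assuming learn_gen is the generator learner and the checkpoint name is a placeholder (open_image is from fastai.vision):

# load the latest generator weights into the generator learner,
# not into the GANLearner that wraps it
learn_gen.load('gen-checkpoint')

# inference then goes through the generator learner directly
pred = learn_gen.predict(open_image('test_input.jpg'))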

Thank you very much for your clear explanation. It’s clearer to me now; I didn’t know that the weight updates during GAN training were applied directly, in place, to the generator and critic objects.

Could you show an example of how, from a GAN learner, you can save out the generator and critic models separately? I can’t seem to find any examples!

I’m just trying to train my GAN for a while, save its state, then reload it another time and continue training - but when I use save & load on the GAN learner itself, training on the loaded model seems to have lost everything it learnt.

Are you just calling learner.export() on the generator and the critic learner objects? How are you loading them back in?

I tried learner.save(), which did not seem to recover the model on learner.load() for the GAN model. I’m not loading for inference; I’m loading to continue training, and it did not restore to the point in training it had reached previously - just a mess!

My question to you was: do you have any code examples of how you saved and restored a GAN learner for further training? Or do you not know how to do this?

For loading in a generator to continue training, I did this:

learn_gen = create_gen_learner(path=PATH).load(gen_old_checkpoint_name)

create_gen_learner is a wrapper function for creating a unet learner:

def create_gen_learner(path):
    """
    Creates the generator: a fastai unet_learner.

    Assumes data_gen, arch, wd, y_range and loss_gen are defined globally.

    Parameters:
        path: Path where the learner saves and loads its models
    """
    return unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                        self_attention=True, y_range=y_range, loss_func=loss_gen, path=path)

Then, I load the saved “checkpoint” model into this new unet learner. When you say your model doesn’t restore the saved one, are the weights not getting updated?

Ah so I am using the wgan as described in this example:

So with the wgan, the generator and critic are created for you along with the learner, i.e.

generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)

The learn object here, which the training happens on, is thus the critic & generator together, i.e.

learn.fit(30,2e-4)

As you will see from the notebook, each epoch of this learner (critic & generator) shows a generated image, which tells you how well it is doing. By the last epoch you can see where it has got to.

This is the learner I am trying to save with:

learn.save()

Which when restored with:

learn.load()

Does not seem to recover, i.e. when you continue training, you can see from the generated images at each epoch that it has lost the progress gained from earlier training (it’s all just a pixelated mess).

So it seems like I may be using a different approach to creating my wgan than you - but it is the one shown in the notebooks, and this wgan object does not seem to save state well enough to continue training. Have you used the wgan and saved its state so that training can continue later?

I haven’t used the wgan with saving/loading models, but again, what command are you using to save/load your models? Remember that those generated images are created from the generator model, so when you load in whatever saved model files you have, check whether the weights of the last layers of the generator model are being updated or not. If they’re not, then that’s your problem right there.
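
That last-layer check is easy to script; a minimal sketch (the checkpoint name is a placeholder):

import torch

# snapshot the generator's last-layer weights before loading
before = list(learn_gen.model.parameters())[-1].detach().clone()

learn_gen.load('gen-checkpoint')

after = list(learn_gen.model.parameters())[-1]
# True means the load changed nothing, i.e. the weights were not restored
print(torch.equal(before, after))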

Thanks for the tips again - took a while to get back here!

Currently I have to save the gen learner, the critic learner, and the GAN learner if I want to be able to use a GAN for inference later. Has anyone figured out a way to actually export() a GAN learner for easy inference? Right now I have to run quite a lot of code setting everything back up again so I can load all three models, and it is not very production ready.
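
One interim workaround, assuming inference only needs the generator (create_gen_learner is the wrapper from earlier in the thread): persist just the generator’s state dict and rebuild only the generator learner in the inference script, skipping the critic and the GANLearner entirely.

import torch

# training side: keep only the generator weights
torch.save(learn_gen.model.state_dict(), 'generator-final.pth')

# inference side: rebuild the generator learner and restore the weights
learn_gen = create_gen_learner(path=PATH)
learn_gen.model.load_state_dict(torch.load('generator-final.pth'))
learn_gen.model.eval()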