learn.export() not working with GANs

I tried learner.save(), but learner.load() did not seem to recover the GAN model. I'm not loading for inference. I'm loading to continue training, and it did not restore the point in training it had reached previously. Just a mess!

My question to you was whether you had any code examples of how you saved and restored a GAN learner for further training, or whether you don't know how to do this.

For loading in a generator to continue training, I did this:

```python
learn_gen = create_gen_learner(path=PATH).load(gen_old_checkpoint_name)
```

create_gen_learner is a wrapper function for creating a unet learner:

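```python
from fastai.vision import *  # fastai v1: unet_learner, NormType, ...

# data_gen, arch, wd, y_range and loss_gen are defined earlier in the notebook.
def create_gen_learner(path):
    """
    Creates a fastai U-Net learner to use as the GAN generator.

    Parameters:
        path: Base directory where the learner saves its models
    """
    return unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                        self_attention=True, y_range=y_range, loss_func=loss_gen, path=path)
```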

Then I load the saved “checkpoint” weights into this new U-Net learner. When you say your model doesn't restore the saved one, do you mean the weights are not getting updated?
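The create_gen_learner(...).load(...) pattern works because loading only copies saved weights into a model that must first be rebuilt with exactly the same architecture. Here is a minimal plain-PyTorch sketch of that idea. make_generator, save_weights and load_weights are hypothetical helpers, and the tiny conv net is only a stand-in for the U-Net, not fastai's API.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_generator(n_channels: int = 3) -> nn.Module:
    """Hypothetical stand-in for create_gen_learner: always builds the same architecture."""
    return nn.Sequential(
        nn.Conv2d(n_channels, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, n_channels, kernel_size=3, padding=1),
        nn.Tanh(),
    )


def save_weights(model: nn.Module, path: str) -> None:
    # Only parameters and buffers are saved, not the code that builds the model.
    torch.save(model.state_dict(), path)


def load_weights(path: str) -> nn.Module:
    # Rebuild an identical architecture first, then copy the saved weights into it.
    model = make_generator()
    model.load_state_dict(torch.load(path, weights_only=True))
    return model


torch.manual_seed(0)
gen = make_generator()
ckpt_path = os.path.join(tempfile.mkdtemp(), "gen_checkpoint.pth")
save_weights(gen, ckpt_path)
restored = load_weights(ckpt_path)
x = torch.randn(1, 3, 8, 8)
print(torch.allclose(gen(x), restored(x)))  # True
```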

Ah, so I am using the WGAN as described in this example:

So with the WGAN, the generator and critic are created together with the learner, i.e.:

```python
from fastai.vision import *
from fastai.vision.gan import *

generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
```

The learn object that training runs on therefore wraps both the critic and the generator, i.e.:

```python
learn.fit(30, 2e-4)
```

As you will see from the notebook, each epoch of this learner (critic and generator) displays a generated image, which tells you how well it is doing. By the last epoch you can see how far it has got.
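To make the point about the learn object concrete: one WGAN fit alternates between several critic updates and one generator update, each with its own optimizer, so the training state is two networks plus two optimizers. Below is a minimal plain-PyTorch sketch of that alternation. It assumes toy MLPs in place of basic_generator/basic_critic and the n_crit=5, clip=0.01 values from the original WGAN paper; GANLearner.wgan schedules this internally and its details may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
Z_DIM, X_DIM = 8, 16  # hypothetical toy sizes

# Toy stand-ins for basic_generator / basic_critic (assumption, not fastai code).
generator = nn.Sequential(nn.Linear(Z_DIM, 32), nn.ReLU(), nn.Linear(32, X_DIM))
critic = nn.Sequential(nn.Linear(X_DIM, 32), nn.ReLU(), nn.Linear(32, 1))

# Separate optimizers, same Adam betas as the GANLearner.wgan call in the thread.
opt_gen = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.0, 0.99))
opt_crit = torch.optim.Adam(critic.parameters(), lr=2e-4, betas=(0.0, 0.99))


def wgan_step(real: torch.Tensor, n_crit: int = 5, clip: float = 0.01) -> float:
    """n_crit critic updates followed by one generator update; returns the generator loss."""
    for _ in range(n_crit):
        fake = generator(torch.randn(real.size(0), Z_DIM)).detach()
        # The critic maximises D(real) - D(fake), i.e. minimises the negation.
        loss_crit = critic(fake).mean() - critic(real).mean()
        opt_crit.zero_grad()
        loss_crit.backward()
        opt_crit.step()
        # Weight clipping, as in the original WGAN formulation.
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip, clip)
    # The generator tries to make the critic score its samples highly.
    loss_gen = -critic(generator(torch.randn(real.size(0), Z_DIM))).mean()
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
    return loss_gen.item()


real_batch = torch.randn(64, X_DIM)  # placeholder "real" data
for _ in range(3):
    wgan_step(real_batch)
```

Restoring only the network weights would leave both Adam optimizers starting from empty state, which is one reason a resumed GAN can behave differently from the run it was saved from.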

This is the learner I am trying to save with:

```python
learn.save()
```

When it is restored with:

```python
learn.load()
```

it does not seem to recover: when you continue training, the generated images at each epoch show that it has lost the progress from the earlier training (it's all just a pixelated mess).
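If training seems to restart from scratch after loading, it is worth checking whether all of the training state comes back: both networks and both optimizers' internal state (for Adam, its running moment estimates). I can't confirm here exactly what GANLearner's save/load covers, so this is a minimal plain-PyTorch sketch that writes everything into one checkpoint explicitly. build, save_gan and load_gan are hypothetical helpers, and the toy MLPs stand in for the real generator and critic.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build(z_dim: int = 8, x_dim: int = 16):
    """Hypothetical factory: rebuilds the same generator, critic and optimizers every time."""
    gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, x_dim))
    crit = nn.Sequential(nn.Linear(x_dim, 32), nn.ReLU(), nn.Linear(32, 1))
    opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.0, 0.99))
    opt_c = torch.optim.Adam(crit.parameters(), lr=2e-4, betas=(0.0, 0.99))
    return gen, crit, opt_g, opt_c


def save_gan(path, gen, crit, opt_g, opt_c, epoch):
    # Everything needed to resume: both networks AND both optimizers' internal state.
    torch.save({
        "gen": gen.state_dict(),
        "crit": crit.state_dict(),
        "opt_gen": opt_g.state_dict(),
        "opt_crit": opt_c.state_dict(),
        "epoch": epoch,
    }, path)


def load_gan(path):
    # Rebuild identical objects first, then overwrite their state from the checkpoint.
    gen, crit, opt_g, opt_c = build()
    ckpt = torch.load(path, weights_only=True)
    gen.load_state_dict(ckpt["gen"])
    crit.load_state_dict(ckpt["crit"])
    opt_g.load_state_dict(ckpt["opt_gen"])
    opt_c.load_state_dict(ckpt["opt_crit"])
    return gen, crit, opt_g, opt_c, ckpt["epoch"]


torch.manual_seed(0)
gen, crit, opt_g, opt_c = build()
crit(gen(torch.randn(4, 8))).mean().backward()
opt_g.step()  # populate Adam's moment buffers so there is optimizer state to save
opt_c.step()
path = os.path.join(tempfile.mkdtemp(), "wgan_checkpoint.pth")
save_gan(path, gen, crit, opt_g, opt_c, epoch=30)
gen2, crit2, opt_g2, opt_c2, epoch = load_gan(path)
```

load_gan rebuilds everything with the same factory before loading, which mirrors the create_gen_learner approach from earlier in the thread.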

So it seems I may be using a different approach to creating my WGAN, but it is the one shown in the notebooks, and this WGAN object does not seem to save its state well enough to continue training. Have you used a WGAN and saved its state so that training can continue later?

I haven't used a WGAN with saving/loading models, but again, what command are you using to save and load your models? Remember that those generated images come from the generator model, so when you load whatever saved model files you have, check whether the weights of the generator's last layers are actually updated. If they aren't, then that's your problem right there.
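The check suggested above can be done by snapshotting the parameters before loading and comparing them afterwards. A minimal sketch, where changed_after is a hypothetical helper and the small Sequential models stand in for the generator:

```python
import torch
import torch.nn as nn


def changed_after(model: nn.Module, state_dict: dict) -> list:
    """Load state_dict into model and return the names of parameters whose values changed."""
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    model.load_state_dict(state_dict)
    return [name for name, p in model.named_parameters()
            if not torch.equal(before[name], p.detach())]


torch.manual_seed(0)
trained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fresh = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(changed_after(fresh, trained.state_dict()))
# ['0.weight', '0.bias', '2.weight', '2.bias']
```

If the generator's last-layer parameters never show up as changed after a load into a freshly built model, the saved weights are probably not reaching the model that produces the images.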


Thanks for the tips again; it took a while to get back here!

Currently I have to save the generator learner, the critic learner and the GAN learner if I want to use a GAN for inference later. Has anyone figured out a way to actually export() a GAN learner for easy inference? Right now I have to run quite a lot of code to set everything up again before I can load all three models, which is not very production-ready.
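For inference only the generator is needed; the critic is only used during training. One option, sketched here under the assumption that plain PyTorch is acceptable in production, is to script just the generator with TorchScript and save it as a single file that loads without any of the model-building code. This swaps in TorchScript rather than fastai's export(); the toy generator is a stand-in, and whether a particular fastai generator scripts cleanly is untested here (torch.jit.trace is an alternative if it doesn't).

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for the trained generator (assumption: in practice this would be the
# generator module pulled out of the GAN learner).
generator = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 16), nn.Tanh())
generator.eval()  # inference mode: fixes dropout / batch-norm behaviour

# Compile to TorchScript: the saved file contains both the code and the weights.
scripted = torch.jit.script(generator)
export_path = os.path.join(tempfile.mkdtemp(), "generator_scripted.pt")
scripted.save(export_path)

# Later, in production, no model-building code is required:
loaded = torch.jit.load(export_path)
with torch.no_grad():
    sample = loaded(torch.randn(1, 8))
print(sample.shape)  # torch.Size([1, 16])
```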