Generating from CycleGAN

With help from @sgugger and @jcreinhold I managed to get a CycleGAN loading my custom data and training, using only the fastai library. (That’s kind of an objective of mine, since I’d rather focus on a single platform/library, as much as possible.) However, I’m really not sure how to go about generating an output. My application is a “sketch-to-image” type of model, so I want to provide an image from my test set (a single 64x64 image, not a pair), and just get a single “fake” image from the GAN’s generator. It seems like I need something like cycle_gan.G_B(image.data), but I haven’t been able to get anywhere with that approach.

Any tips greatly appreciated.

Indeed, you need to use either G_A or G_B. You should add a dimension to image.data to make it a batch, normalize it like the rest of your data, and then pass that to the model.
Note that CycleGAN is under development; there might be some helper functions for this in fastai later on.
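A minimal sketch of those steps in plain PyTorch. It assumes the image is a (3, H, W) float tensor in [0, 1] and that training used a per-channel mean and std of 0.5 (the convention in the original CycleGAN implementation); substitute whatever normalization stats your data actually used. The generate helper and the stand-in generator are hypothetical; in practice you would pass cycle_gan.G_B.

```python
import torch
import torch.nn as nn


def generate(generator: nn.Module, img: torch.Tensor,
             mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> torch.Tensor:
    """Translate one (3, H, W) image in [0, 1] with a CycleGAN generator.

    The mean/std defaults are an assumption: use the stats your data was normalized with.
    """
    mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    x = ((img - mean_t) / std_t).unsqueeze(0)  # normalize, then add a batch dim: (1, 3, H, W)
    device = next(generator.parameters()).device
    generator.eval()  # switch norm/dropout layers to inference behaviour
    with torch.no_grad():
        out = generator(x.to(device))
    out = out.squeeze(0).cpu() * std_t + mean_t  # drop the batch dim, undo normalization
    return out.clamp(0.0, 1.0)


# Hypothetical stand-in for cycle_gan.G_B: any module mapping (N, 3, H, W) -> (N, 3, H, W)
stand_in = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())
sketch = torch.rand(3, 64, 64)  # a single 64x64 test image, not a pair
fake = generate(stand_in, sketch)
print(fake.shape)  # torch.Size([3, 64, 64])
```

The result can be turned back into a picture with any tensor-to-image utility, since it is already de-normalized to [0, 1].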

Ah, okay, great; I missed the batch step. I’ll keep poking around with it. Thanks!

Okay, it looks like I have the generator working now! Thinking toward production, I’d like to save just the generator for my application (of course, I don’t need the discriminator), and ultimately convert it to Core ML. How can I go about saving the generator alone?

UPDATE: I see that this is just PyTorch (not fastai per se): it looks like torch.save(cycle_gan.G_B, "name") will work. However, I’m not sure how to load that again… Presumably it wouldn’t be quite the same as loading the whole CycleGAN (??)

A related question: When I save/reload the learner (learner.save("filename") and learner.load("filename")) I get the error: AttributeError: 'TargetTupleList' object has no attribute 'valid_ds'. The model does appear to function correctly after reloading, so obviously this isn’t a show-stopper, but I’m curious about the error. Is it literally as simple as writing a getter for valid_ds?

You need to save the PyTorch model's state dict with

torch.save(cycle_gan.G_B.state_dict(), "name")

Then you can reload it in two steps:
1. Create the model, which would be something like generator = resnet_generator(3,3) (possibly with kwargs, if you changed the defaults).
2. Load the weights:

generator.load_state_dict(torch.load("name"))
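The two-step reload above as a self-contained round trip. TinyGenerator is a hypothetical stand-in for whatever builds your generator (for example resnet_generator(3,3) with the same kwargs used in training), and the file name is just an example. The key point is that load_state_dict only restores weights, so the architecture has to be rebuilt identically first.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the real generator architecture."""

    def __init__(self, ch_in: int = 3, ch_out: int = 3, n_ftrs: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(ch_in, n_ftrs, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_ftrs, ch_out, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


trained = TinyGenerator()  # plays the role of cycle_gan.G_B
path = os.path.join(tempfile.mkdtemp(), "generator_B.pth")

# 1. Save only the weights (a dict of parameter/buffer tensors), not the module object
torch.save(trained.state_dict(), path)

# 2. Rebuild the same architecture (same constructor arguments), then load the weights
generator = TinyGenerator()
# map_location="cpu" lets weights saved from a GPU model load on a CPU-only machine
generator.load_state_dict(torch.load(path, map_location="cpu"))
generator.eval()

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    same = torch.equal(trained(x), generator(x))
print(same)  # True
```

Saving the state dict rather than the whole module avoids pickling the class itself, so the file does not break if the module's source location changes.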

Excellent, thanks!
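For the Core ML goal mentioned earlier, one common route (an assumption, not something confirmed in this thread) is to trace the loaded generator to TorchScript with a fixed-size example input and hand the traced module to coremltools, whose unified ct.convert API accepts traced PyTorch models. This sketch covers only the tracing step; the coremltools call is left as a comment because the exact options depend on your coremltools version. The generator here is a hypothetical stand-in.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in; in practice, the generator rebuilt and loaded via load_state_dict
generator = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())
generator.eval()  # trace in inference mode so norm/dropout layers behave as at inference

example = torch.rand(1, 3, 64, 64)  # fixed input shape: one 64x64 RGB image
traced = torch.jit.trace(generator, example)

# Serialize to an in-memory buffer (use a file path in practice) and reload it
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    print(torch.allclose(generator(example), reloaded(example)))  # True

# Conversion step (assumes coremltools >= 4 is installed; not run here):
# import coremltools as ct
# mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=example.shape)])
```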

Have there been any recent changes to the source code related to CycleGAN? The CycleGAN notebook is not producing the expected results…
After 20 epochs, the output looks something like this:


A month ago, the same code gave good results, as expected.

Not that I know of. I’d be curious to know the cause of the regression if you manage to track it down.


I think it has something to do with weight initialization, because when I ran the code, say, 10 times, it converged only about 6 times. It may also be due to other reasons.
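If initialization is the suspect, one way to check is to fix the random seed so that repeated runs start from identical weights, and to apply an explicit initializer. This sketch assumes the N(0, 0.02) convolution init used by default in the reference CycleGAN implementation (junyanz/pytorch-CycleGAN-and-pix2pix); fastai's own default init may differ, and the small model is a hypothetical stand-in. Fully deterministic GPU training can also require cuDNN settings not shown here.

```python
import torch
import torch.nn as nn


def init_normal(module: nn.Module, gain: float = 0.02) -> None:
    """N(0, gain) init for conv weights and zero biases (mirrors the reference default for conv layers)."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=gain)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def build_model(seed: int) -> nn.Module:
    torch.manual_seed(seed)  # seeds the RNGs the initializers draw from
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, 3, kernel_size=3, padding=1),
    )
    model.apply(init_normal)  # apply() visits every submodule recursively
    return model


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))


print(same_weights(build_model(0), build_model(0)))  # True: identical starting point
print(same_weights(build_model(0), build_model(1)))  # False: a different draw
```

With the seed fixed, a run that converges and one that does not can be compared from the same starting weights, which helps separate initialization effects from other sources of randomness such as data shuffling.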

Finally, I can conclude that sometimes the discriminator fails to pick up on what we intend it to learn: separating generated (fake) images from real ones. This is a very common problem with GANs (I’d loosely call it a form of mode collapse). I came to this conclusion because the generated images differ from the real images in many ways, and our desired attribute is only one of them, so the discriminator can tell real from fake without focusing on that attribute. When I made the discriminator more complex, the problem was sorted out. We encounter this problem because, for every batch, we train the discriminator after training the generator.

Another interesting observation: when I made the generator more complex (with a U-Net and so on) and trained the model for more epochs, the generator started producing random colored dots (yellow in my case), which made it easy for the discriminator to distinguish fake from real images, since only the generated images had these dots. In this model the discriminator loss was very, very low, even though the generator and identity losses were high.

Lesson learned: the relative complexity of the generator and the discriminator also matters in GANs. By complexity I mean a deeper network.
I have also implemented a U-Net with CycleGAN, and the results were pretty good.
Ping me for further details 🙃
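The per-batch ordering described above (update the generators first, then the discriminators on the same batch) can be sketched as a single CycleGAN training step. This is a minimal illustration under assumptions, not fastai's implementation: tiny hypothetical stand-in networks, a least-squares GAN loss, Adam with lr 2e-4 and betas (0.5, 0.999), and loss weights of 10 (cycle) and 5 (identity) as in the reference implementation's defaults. The reference code also keeps a history buffer of past fakes for the discriminator, which is omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tiny_gen() -> nn.Module:  # hypothetical stand-in generator
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1), nn.Tanh())


def tiny_disc() -> nn.Module:  # hypothetical stand-in patch discriminator
    return nn.Sequential(nn.Conv2d(3, 8, 4, stride=2, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(8, 1, 4, padding=1))


def lsgan(pred: torch.Tensor, is_real: bool) -> torch.Tensor:
    """Least-squares GAN loss: push scores toward 1 (real) or 0 (fake)."""
    target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
    return F.mse_loss(pred, target)


G_AB, G_BA, D_A, D_B = tiny_gen(), tiny_gen(), tiny_disc(), tiny_disc()
opt_G = torch.optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(list(D_A.parameters()) + list(D_B.parameters()), lr=2e-4, betas=(0.5, 0.999))
LAMBDA_CYC, LAMBDA_ID = 10.0, 5.0


def train_step(real_A: torch.Tensor, real_B: torch.Tensor):
    # 1) Generator update: fool both discriminators, plus cycle-consistency and identity terms
    opt_G.zero_grad()
    fake_B, fake_A = G_AB(real_A), G_BA(real_B)
    loss_gan = lsgan(D_B(fake_B), True) + lsgan(D_A(fake_A), True)
    loss_cyc = F.l1_loss(G_BA(fake_B), real_A) + F.l1_loss(G_AB(fake_A), real_B)
    loss_id = F.l1_loss(G_AB(real_B), real_B) + F.l1_loss(G_BA(real_A), real_A)
    loss_G = loss_gan + LAMBDA_CYC * loss_cyc + LAMBDA_ID * loss_id
    loss_G.backward()
    opt_G.step()

    # 2) Discriminator update on the same batch. zero_grad discards the gradients the
    #    generator step left in D; detach() keeps this loss from reaching the generators.
    opt_D.zero_grad()
    loss_D_A = 0.5 * (lsgan(D_A(real_A), True) + lsgan(D_A(fake_A.detach()), False))
    loss_D_B = 0.5 * (lsgan(D_B(real_B), True) + lsgan(D_B(fake_B.detach()), False))
    loss_D = loss_D_A + loss_D_B
    loss_D.backward()
    opt_D.step()
    return loss_G.item(), loss_D.item()


torch.manual_seed(0)
a = torch.rand(2, 3, 32, 32) * 2 - 1  # dummy batch from domain A, scaled to [-1, 1]
b = torch.rand(2, 3, 32, 32) * 2 - 1  # dummy batch from domain B
print(train_step(a, b))
```

Watching loss_G and loss_D separately over training is one way to spot the imbalance described above, such as a discriminator loss that collapses toward zero while the generator and identity losses stay high.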

Thanks, fastai!

For my use-cases the standard PatchGAN and 9-ResBlock architectures worked well.
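For reference, a compact sketch of those two architectures, following the layout of the reference CycleGAN implementation: a 70x70 PatchGAN discriminator (4x4 convolutions, LeakyReLU 0.2, instance norm) and a ResNet generator with two downsampling steps, n_blocks residual blocks, and two upsampling steps. The function names are hypothetical (make_resnet_generator is not fastai's resnet_generator), and details such as the norm layer and padding are assumptions that may differ from fastai's version.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Two 3x3 convs with reflection padding and a skip connection."""

    def __init__(self, ch: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.ReflectionPad2d(1), nn.Conv2d(ch, ch, 3), nn.InstanceNorm2d(ch), nn.ReLU(True),
            nn.ReflectionPad2d(1), nn.Conv2d(ch, ch, 3), nn.InstanceNorm2d(ch),
        )

    def forward(self, x):
        return x + self.body(x)


def make_resnet_generator(ch_in=3, ch_out=3, nf=64, n_blocks=9) -> nn.Module:
    layers = [nn.ReflectionPad2d(3), nn.Conv2d(ch_in, nf, 7), nn.InstanceNorm2d(nf), nn.ReLU(True)]
    for mult in (1, 2):  # downsample twice: nf -> 2nf -> 4nf, halving H and W each time
        layers += [nn.Conv2d(nf * mult, nf * mult * 2, 3, stride=2, padding=1),
                   nn.InstanceNorm2d(nf * mult * 2), nn.ReLU(True)]
    layers += [ResBlock(nf * 4) for _ in range(n_blocks)]
    for mult in (4, 2):  # upsample twice back to the input resolution
        layers += [nn.ConvTranspose2d(nf * mult, nf * mult // 2, 3, stride=2, padding=1, output_padding=1),
                   nn.InstanceNorm2d(nf * mult // 2), nn.ReLU(True)]
    layers += [nn.ReflectionPad2d(3), nn.Conv2d(nf, ch_out, 7), nn.Tanh()]
    return nn.Sequential(*layers)


def make_patchgan_discriminator(ch_in=3, nf=64) -> nn.Module:
    """70x70 PatchGAN: outputs a grid of real/fake scores, one per overlapping patch."""
    return nn.Sequential(
        nn.Conv2d(ch_in, nf, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
        nn.Conv2d(nf, nf * 2, 4, stride=2, padding=1), nn.InstanceNorm2d(nf * 2), nn.LeakyReLU(0.2, True),
        nn.Conv2d(nf * 2, nf * 4, 4, stride=2, padding=1), nn.InstanceNorm2d(nf * 4), nn.LeakyReLU(0.2, True),
        nn.Conv2d(nf * 4, nf * 8, 4, stride=1, padding=1), nn.InstanceNorm2d(nf * 8), nn.LeakyReLU(0.2, True),
        nn.Conv2d(nf * 8, 1, 4, stride=1, padding=1),
    )


x = torch.rand(1, 3, 64, 64) * 2 - 1
G, D = make_resnet_generator(n_blocks=9), make_patchgan_discriminator()
with torch.no_grad():
    print(G(x).shape)  # torch.Size([1, 3, 64, 64])
    print(D(x).shape)  # torch.Size([1, 1, 6, 6])
```

For a 256x256 input the same discriminator produces a 30x30 score map, each score covering a 70x70 receptive field, which is where the PatchGAN name comes from.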

Have you also tried training with this repository?
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix