Understanding GANLoss Module

I’m trying to implement WGAN with a gradient penalty, and I need some help working through how the loss calculation is structured.

In the GANLoss module, the loss calculation for the critic is defined as:

def critic(self, real_pred, input):
    fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
    fake_pred = self.gan_model.critic(fake)
    return self.loss_funcC(real_pred, fake_pred)

The loss for the generator is defined as:

def generator(self, output, target):
    fake_pred = self.gan_model.critic(output)
    return self.loss_funcG(fake_pred, target, output)

So the generator loss receives output (the images generated by the generator) and target (real example images). The critic loss receives real_pred (the result of passing real images through the critic) and input (a batch of latent vectors used to generate fake images).

To implement a gradient penalty, I need to pass both real and fake images to the critic loss function. I can’t do this with the current structure: the critic loss never receives the real images themselves, only real_pred, the result of running them through the critic. So I’m trying to figure out where exactly real_pred is calculated.
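For reference, here is a minimal sketch of the gradient penalty itself in plain PyTorch, which shows why both real and fake images are needed: the penalty is evaluated at random interpolations between them. The helper name gradient_penalty and the toy critic are hypothetical (not fastai code); lambda_gp = 10 is the default from the WGAN-GP paper. Since fastai's critic loss only sees real_pred, calling something like this would require a custom loss or trainer that has access to the real batch.

```python
import torch
from torch import nn

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """Hypothetical helper: lambda * mean((||grad critic(x_hat)||_2 - 1)^2),
    where x_hat is a random per-sample interpolation of real and fake."""
    # One mixing coefficient per sample, broadcast over the remaining dims.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True so the penalty can itself be backpropagated into the critic.
    grads, = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                 grad_outputs=torch.ones_like(scores),
                                 create_graph=True)
    grad_norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Usage with a toy linear critic on 3x8x8 "images".
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
gp = gradient_penalty(critic, real, fake)
gp.backward()  # gradients flow into the critic's weights
```

For a linear critic the input gradient is just the weight vector, so the penalty reduces to lambda_gp * (||w|| - 1)^2, which is a handy sanity check.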

I suspect this happens in the model’s forward pass, but the only forward I can find is in the GANModule class:

def forward(self, *args):
    return self.generator(*args) if self.gen_mode else self.critic(*args)

I don’t know what *args is or where it comes from. Does anyone know where this calculation happens, or how to get the target variable from the generator loss passed to the critic loss?
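My understanding (please check against the GANTrainer source) is that in critic mode fastai v1’s GANTrainer swaps the batch’s input and target before the forward pass. If so, *args is simply the batch input: the real images in critic mode, so model(real) is where real_pred comes from, while the noise is passed to the loss as its second argument. The sketch below mimics that flow in plain PyTorch; ToyGANModule, critic_step and generator_step are hypothetical names, not fastai’s API.

```python
import torch
from torch import nn

class ToyGANModule(nn.Module):
    "Hypothetical stand-in for GANModule: forward dispatches on gen_mode."
    def __init__(self, generator, critic, gen_mode=False):
        super().__init__()
        self.generator = generator
        self.critic = critic
        self.gen_mode = gen_mode

    def forward(self, *args):
        # *args is simply whatever batch input the training loop calls the model with.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

def critic_step(model, noise, real):
    # Assumed flow: in critic mode the loop calls the model on the REAL images...
    model.gen_mode = False
    real_pred = model(real)                     # <- this is where real_pred comes from
    # ...and the noise reaches the critic loss as its second argument ("input").
    fake = model.generator(noise).detach()
    fake_pred = model.critic(fake)
    return real_pred.mean() - fake_pred.mean()  # WassersteinLoss-style

def generator_step(model, noise):
    model.gen_mode = True
    output = model(noise)                       # fake images from the generator
    return model.critic(output).mean()          # NoopLoss-style

gen, crit = nn.Linear(8, 16), nn.Linear(16, 1)  # toy generator and critic
model = ToyGANModule(gen, crit)
z, x = torch.randn(4, 8), torch.randn(4, 16)
loss_c = critic_step(model, z, x)
loss_g = generator_step(model, z)
```

Under this assumption the real images are available only at the point where real_pred is computed, which is why a gradient penalty needs a custom trainer or loss that keeps a reference to the real batch.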

Also, a more general question about GAN loss functions.

In the fastai library, we use NoopLoss for the generator loss and WassersteinLoss for the critic.

NoopLoss returns critic(generator(z)).mean()

WassersteinLoss returns critic(x).mean() - critic(generator(z)).mean()
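As a sketch, the two formulas above can be written to match the call signatures quoted earlier, loss_funcC(real_pred, fake_pred) and loss_funcG(fake_pred, target, output). These are illustrative functions under that assumption, not fastai’s source.

```python
import torch

def noop_loss(fake_pred, target=None, output=None):
    # Generator loss: just the mean critic score of the fakes.
    # target/output are accepted (and ignored) to match loss_funcG(fake_pred, target, output).
    return fake_pred.mean()

def wasserstein_loss(real_pred, fake_pred):
    # Critic loss: mean score on real images minus mean score on fakes.
    return real_pred.mean() - fake_pred.mean()

real_pred = torch.tensor([2.0, 4.0])
fake_pred = torch.tensor([1.0, 0.0])
print(wasserstein_loss(real_pred, fake_pred).item())  # 3.0 - 0.5 = 2.5
print(noop_loss(fake_pred).item())                    # 0.5
```

Both are minimized, so the critic is pushed to score real images lower than fakes, and the generator is pushed to lower the fakes’ scores toward the real ones.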

In the WGAN paper’s algorithm (Algorithm 1), the critic loss is critic(x).mean() - critic(generator(z)).mean(), the same as fastai. For the generator, however, they use -critic(generator(z)).mean(), which has the opposite sign to NoopLoss.

Edit: I see that in the critic’s update they add the gradient term instead of subtracting it (gradient ascent rather than descent). Does this mean everything is flipped?
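A sketch of one WGAN iteration in plain PyTorch makes the sign bookkeeping explicit. Because PyTorch optimizers always descend, the paper’s ascent on mean f(x) - mean f(g(z)) is implemented by minimizing its negation. The hyperparameters in the usage lines are the defaults listed in the paper’s Algorithm 1; wgan_step and the toy linear models are hypothetical.

```python
import torch
from torch import nn

def wgan_step(generator, critic, opt_g, opt_c, real_batches, z_dim, clip=0.01):
    "Sketch of one generator iteration of WGAN; len(real_batches) plays the role of n_critic."
    for real in real_batches:
        fake = generator(torch.randn(real.size(0), z_dim)).detach()
        # Paper: ASCEND on mean f(x) - mean f(g(z)). Optimizers descend,
        # so minimize the negation, i.e. mean f(g(z)) - mean f(x).
        loss_c = critic(fake).mean() - critic(real).mean()
        opt_c.zero_grad()
        loss_c.backward()
        opt_c.step()
        # Weight clipping, as in the original WGAN (WGAN-GP replaces this with a penalty).
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip, clip)
    # Paper: DESCEND on -mean f(g(z)); reuses the last real batch size.
    loss_g = -critic(generator(torch.randn(real.size(0), z_dim))).mean()
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_c.item(), loss_g.item()

# Toy usage with the paper's defaults (lr = 0.00005, c = 0.01, m = 64, n_critic = 5).
gen, crit = nn.Linear(8, 16), nn.Linear(16, 1)
opt_g = torch.optim.RMSprop(gen.parameters(), lr=5e-5)
opt_c = torch.optim.RMSprop(crit.parameters(), lr=5e-5)
batches = [torch.randn(64, 16) for _ in range(5)]
loss_c, loss_g = wgan_step(gen, crit, opt_g, opt_c, batches, z_dim=8)
```

Read this way, the paper’s critic minimizes fake - real and its generator minimizes -fake, while fastai uses real - fake and +fake: both signs are reversed together.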

The WGAN-GP paper uses yet another variant. Its critic loss is critic(generator(z)).mean() - critic(x).mean() + gradient_penalty, so the orientation is flipped (fake - real instead of real - fake). The generator loss is again -critic(generator(z)).mean().
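Written out as a math sketch, with D standing for the critic, G the generator, m the batch size and λ the penalty weight (10 in the WGAN-GP paper), the two quantities that are minimized are:

```latex
\begin{aligned}
L_C &= \frac{1}{m}\sum_{i=1}^{m}\Big[D\big(G(z^{(i)})\big) - D\big(x^{(i)}\big)
      + \lambda\big(\lVert \nabla_{\hat{x}} D(\hat{x}^{(i)}) \rVert_2 - 1\big)^2\Big],\\
\hat{x}^{(i)} &= \epsilon_i\, x^{(i)} + (1-\epsilon_i)\, G(z^{(i)}), \qquad \epsilon_i \sim U[0,1],\\
L_G &= -\frac{1}{m}\sum_{i=1}^{m} D\big(G(z^{(i)})\big).
\end{aligned}
```

Note that the penalty term depends only on the norm of the critic’s input gradient, so it is unchanged if the critic’s sign is flipped (replacing D by -D).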

Does anyone have perspective on the different loss variants and the implications of them for the optimization process?


Yes, they flipped their update, which I didn’t find very natural; hence fastai does it the other way round.

For your other question, I think you’ll probably need to write your custom GANTrainer for this.

I’m not (yet) familiar with Wasserstein GANs, but there is a very well-made depthfirstlearning.com curriculum on WGANs that may help you.

(For general GAN training I can highly recommend the NIPS 2016 Tutorial on GANs by Ian Goodfellow as a starting point, especially the part around figure 16 on p. 26 and chapter 4 starting at p. 30.)


Sorry to hijack the thread, but I have a question. As I understand it, fastai does it this way: WassersteinLoss = critic(real) - critic(fake) and NoopLoss = critic(fake).

But if these loss functions are minimized, don’t they encourage the critic to give real images low scores and fake images high scores, the opposite of the usual convention? In fact, the original WGAN paper maximizes the WassersteinLoss, hence the + sign in the gradient update.

I was confused by this. Can someone explain whether I’m wrong?
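A tiny numeric sketch of what minimizing critic(real) - critic(fake) does, using a hypothetical one-parameter linear critic f(x) = w * x with real samples at x = 2 and fakes at x = 1. It assumes plain gradient descent with no weight clipping or gradient penalty.

```python
# Toy critic f(x) = w * x on 1-D "samples": real data at x = 2, fakes at x = 1.
real, fake = [2.0, 2.0], [1.0, 1.0]
w, lr = 0.0, 0.1

def mean(xs):
    return sum(xs) / len(xs)

for _ in range(100):
    # Minimize the fastai-style critic loss: mean f(real) - mean f(fake)
    # = w * (mean(real) - mean(fake)); its derivative w.r.t. w is constant.
    grad_w = mean(real) - mean(fake)
    w -= lr * grad_w  # plain gradient descent, no clipping or penalty

real_score, fake_score = w * mean(real), w * mean(fake)
```

After 100 steps w is about -10, so real samples score about -20 and fakes about -10. Real images do end up with the lower score, but the scores are not confined to 0 and 1: without clipping or a gradient penalty they keep drifting, which is exactly what the Lipschitz constraint is there to prevent.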

You are correct, and I’m confused about the same thing. Other implementations do the opposite of fastai (both losses are flipped).

It seems that the Wasserstein distance is symmetric, though, so I guess the two are equivalent.
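One way to check the equivalence: flipping both losses is the same as training the negated critic -f, and -f is 1-Lipschitz exactly when f is. The plain-Python sketch below, with an arbitrary hypothetical critic and toy samples, shows that the paper-style losses evaluated on f equal the fastai-style losses evaluated on -f.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

# fastai convention (both minimized)
def fastai_critic_loss(f, real, fake):
    return mean([f(x) for x in real]) - mean([f(x) for x in fake])

def fastai_gen_loss(f, fake):
    return mean([f(x) for x in fake])

# WGAN / WGAN-GP paper convention, written as losses to minimize
def paper_critic_loss(f, real, fake):
    return mean([f(x) for x in fake]) - mean([f(x) for x in real])

def paper_gen_loss(f, fake):
    return -mean([f(x) for x in fake])

f = lambda x: 3.0 * x - 1.0      # any critic (hypothetical)
neg_f = lambda x: -f(x)          # the same critic with its sign flipped
real, fake = [0.5, 1.5, 2.0], [0.0, 0.25, 1.0]

# Paper losses on f equal fastai losses on -f, for both the critic and the generator.
assert math.isclose(paper_critic_loss(f, real, fake), fastai_critic_loss(neg_f, real, fake))
assert math.isclose(paper_gen_loss(f, fake), fastai_gen_loss(neg_f, fake))
```

So, assuming the critic’s architecture can represent -f as easily as f (a final linear layer can), the two conventions should lead to the same generator updates, just with the critic’s scores negated.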