I’m trying to implement WGAN with gradient penalty (WGAN-GP), and I need some help understanding how the loss calculation is structured.
In the GANLoss module, the loss for the critic is defined as:
    def critic(self, real_pred, input):
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred)
The loss for the generator is defined as:
    def generator(self, output, target):
        fake_pred = self.gan_model.critic(output)
        return self.loss_funcG(fake_pred, target, output)
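For reference, here is a minimal sketch of what loss_funcC and loss_funcG might compute for a plain WGAN, written only to match the call signatures above. The function names are hypothetical, and the sign convention (the critic minimizes the mean fake score minus the mean real score) is the standard WGAN one; fastai's own WGAN losses may use a different convention.

```python
import torch

# Hypothetical loss functions matching the call signatures used in GANLoss:
#   loss_funcC(real_pred, fake_pred) and loss_funcG(fake_pred, target, output).
# Assumed convention: the critic should score real samples higher than fakes.

def wgan_critic_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
    # Critic minimizes E[D(fake)] - E[D(real)], i.e. it widens the score gap.
    return fake_pred.mean() - real_pred.mean()

def wgan_generator_loss(fake_pred: torch.Tensor, target: torch.Tensor,
                        output: torch.Tensor) -> torch.Tensor:
    # Generator minimizes -E[D(G(z))]. target and output are unused here;
    # they are accepted only so the signature matches loss_funcG.
    return -fake_pred.mean()
```

Note that neither function ever sees an image, only critic scores, which is exactly why a gradient penalty does not fit into this signature.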
So the generator loss receives output, the images produced by the generator, and target, a batch of example (real) images. The critic loss receives real_pred, the critic’s output on real images, and input, a batch of latent vectors used to generate fake images.
To implement a gradient penalty, I need to pass both real and fake images to the critic loss function. I can’t do this with the current structure: the real images themselves never reach the critic loss, only real_pred, the result of running them through the critic. So I’m trying to figure out where exactly real_pred is calculated.
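A minimal PyTorch sketch of the penalty shows why the images themselves are needed: it interpolates between real and fake samples and differentiates the critic’s score with respect to those interpolated inputs, which cannot be done from real_pred alone. The names gradient_penalty and critic_loss_with_gp are hypothetical, and the default weight of 10 (the value used in the WGAN-GP paper) is an assumption; none of this is fastai’s API.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                     weight: float = 10.0) -> torch.Tensor:
    """WGAN-GP penalty: weight * E[(||grad_x D(x_hat)||_2 - 1)^2]."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Random points on straight lines between real and fake samples (a leaf tensor).
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # Gradient of the scores w.r.t. the interpolated inputs; create_graph=True
    # keeps it differentiable so the penalty can be backpropagated to the critic.
    grads, = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                 grad_outputs=torch.ones_like(scores),
                                 create_graph=True)
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return weight * ((grad_norm - 1) ** 2).mean()

def critic_loss_with_gp(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                        weight: float = 10.0) -> torch.Tensor:
    # Hypothetical critic loss that takes the images themselves, which the
    # current critic(real_pred, input) signature does not provide.
    real_pred = critic(real)
    fake_pred = critic(fake.detach())  # no gradient flows into the generator here
    return fake_pred.mean() - real_pred.mean() + gradient_penalty(critic, real, fake, weight)
```

With real and fake batches of the same shape, calling loss = critic_loss_with_gp(critic, real, fake) followed by loss.backward() produces gradients for the critic only; the generator step can keep an ordinary generator loss.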
I suspect this happens in the model’s forward pass, but the only forward I can find is in the GANModule class, defined as:
    def forward(self, *args):
        return self.generator(*args) if self.gen_mode else self.critic(*args)
I don’t know what *args are or where they come from. Does anyone know where this calculation happens, or how to get the target variable from the generator loss into the critic (discriminator) loss?
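The toy sketch below (plain Python, not fastai code) illustrates what *args does in that forward and where real_pred would come from, under my assumption about the training loop: that it calls the model on the batch inputs as model(*xb), passes the result to the loss as loss_func(out, *yb), and, in critic mode, arranges the batch so that xb holds real images and yb holds latent vectors. ToyGANModule, toy_training_step and the lambda "networks" are all hypothetical.

```python
class ToyGANModule:
    """Toy stand-in for GANModule: forward dispatches on gen_mode."""
    def __init__(self, generator, critic, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def forward(self, *args):
        # *args packs all positional arguments into a tuple; *args in the call
        # unpacks it again, so forward(x) becomes self.critic(x) unchanged.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    # Mimics nn.Module, whose __call__ runs forward (plus hooks).
    __call__ = forward


def toy_training_step(model, loss_func, xb, yb):
    # Assumed loop shape: the model output becomes the loss's first argument
    # and the batch targets are appended after it.
    out = model(*xb)
    return loss_func(out, *yb)


gen = lambda z: [2 * v for v in z]   # toy "generator"
crit = lambda imgs: sum(imgs)        # toy "critic" returning one score
model = ToyGANModule(gen, crit, gen_mode=False)

real_images, latent = [1.0, 2.0], [0.5, 0.5]

def toy_critic_loss(real_pred, z):
    # Same shape as GANLoss.critic: real_pred arrives already computed.
    fake_pred = crit(gen(z))
    return fake_pred - real_pred

# In critic mode, real_pred is simply model(real_images), computed before the loss.
loss = toy_training_step(model, toy_critic_loss, xb=(real_images,), yb=(latent,))
print(loss)  # -1.0
```

If the loop really works this way, real_pred is just the model’s output on the real batch, and the real images themselves are only available in the loop or in a callback, not inside the loss.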