Fastai V2 WGAN loss functions flipped

It seems to me that in the vision.gan module, the loss functions that implement a WGAN are flipped. Below are the generator and critic loss functions, respectively.

def _tk_mean(fake_pred, output, target): return fake_pred.mean()                 # generator loss
def _tk_diff(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()   # critic loss

In WGAN, the expectation of the critic's outputs on the real distribution (real_pred.mean()) minus its outputs on the generated distribution (fake_pred.mean()) is an approximation of the Wasserstein distance, which we need to ascend with respect to the critic's parameters. The generator should then minimize this estimated Wasserstein distance with respect to its own parameters; since only the fake term depends on the generator, that means descending -fake_pred.mean().
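For reference, the objectives described above can be written as losses to minimize like this (a minimal sketch in PyTorch; these are my own function names, not fastai's code):

```python
import torch

def critic_loss(real_pred, fake_pred):
    # The critic ascends E[D(real)] - E[D(fake)] (the Wasserstein estimate),
    # so as a loss to *minimize* we negate it.
    return -(real_pred.mean() - fake_pred.mean())

def generator_loss(fake_pred):
    # The generator minimizes the estimated distance w.r.t. its parameters;
    # only the fake term depends on them, giving the loss -E[D(fake)].
    return -fake_pred.mean()

print(critic_loss(torch.tensor([2.0]), torch.tensor([1.0])).item())  # -1.0
print(generator_loss(torch.tensor([3.0])).item())                    # -3.0
```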

Unless I’m missing something, the two expressions are flipped. Since we minimize the loss, the critic loss should be -1 * _tk_diff, and the generator loss should likewise be negated.

Please let me know if I’m mistaken.

I realize this question is pretty old, so you may have already worked this out. However, in case someone else stumbles across this in the future with the same confusion: your suggestion of negating both the critic loss and the generator loss is actually equivalent to fastai’s current implementation. With the current implementation, the critic’s loss is minimized by giving low scores to real data (small real_pred.mean()) and high scores to fake/generated data (larger fake_pred.mean()). The generator’s loss is minimized when it generates data such that fake_pred.mean() is small; in other words, the critic believes the data came from the real distribution and gives it a low score, as it has learned to do for real data.
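You can see this sign convention directly by inspecting the gradient on a toy one-weight "critic" D(x) = w * x (a hypothetical example just to check gradient signs, not fastai code):

```python
import torch

# Toy linear critic D(x) = w * x with a single weight.
w = torch.tensor(1.0, requires_grad=True)
real, fake = torch.tensor(2.0), torch.tensor(-2.0)

# fastai's critic loss: real_pred.mean() - fake_pred.mean()
loss = w * real - w * fake
loss.backward()

# dloss/dw = real - fake = 4.0 > 0, so gradient descent lowers w:
# D(real) = 2w shrinks while D(fake) = -2w grows, i.e. the critic
# learns low scores for real data and high scores for fake data.
print(w.grad.item())  # 4.0
```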

Negating the output of the current implementation is equivalent to using the following functions:

def _tk_mean(fake_pred, output, target):
    return -fake_pred.mean()

def _tk_diff(real_pred, fake_pred):
    return fake_pred.mean() - real_pred.mean()

Now, in this scenario, the critic’s loss is minimized when fake_pred.mean() is small and real_pred.mean() is larger, so the critic learns to give low scores to fake data and high scores to real data. The generator must therefore learn to generate data that the critic scores highly, since its goal is to produce data resembling the real distribution. However, recall that we want to minimize the generator’s loss. That is why we take the negative of the critic’s output: as the generator produces more realistic data, which receive higher and higher scores, its loss becomes more negative (smaller).
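The equivalence is just a relabeling of the critic's scores: substituting s -> -s maps one convention onto the other, which you can check numerically (the helper names below are mine, not fastai's):

```python
import torch

# fastai's convention (as in the current implementation):
def fastai_gen_loss(fake_pred): return fake_pred.mean()
def fastai_crit_loss(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()

# The negated convention discussed above:
def neg_gen_loss(fake_pred): return -fake_pred.mean()
def neg_crit_loss(real_pred, fake_pred): return fake_pred.mean() - real_pred.mean()

real = torch.tensor([1.5, -0.3, 0.7])
fake = torch.tensor([0.2, -1.1, 0.4])

# Flipping the sign of every critic score turns one convention into the
# other, so the two setups induce identical training dynamics.
assert torch.isclose(fastai_crit_loss(real, fake), neg_crit_loss(-real, -fake))
assert torch.isclose(fastai_gen_loss(fake), neg_gen_loss(-fake))
```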

Thus, in both implementations, we’re minimizing the loss for both the critic and the generator.