Multi-term losses and cyclic annealing schedule

Hi guys,

Just wanted to bring attention to a recent observation I made. I have a multi-term loss function: one of the terms is a KL divergence and the other is a more traditional classification loss. From a recent paper I was studying, I followed a reference to "Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing".

While the technique was created to deal with variational autoencoders, I had much the same problem in a different setting. I plugged the cyclical schedule into the loss function and all my vanishing problems went away. It doesn't end there: I also have a similar network that, instead of optimizing the KL divergence, optimizes the MSE of the distribution (AlphaZero style). Funnily enough, this trick enabled better convergence on that network too.
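
For reference, the schedule from the paper splits training into a few cycles and, within each cycle, ramps the KL weight linearly from 0 to 1 before holding it at 1. Roughly something like this (just a sketch; the name cyclical_kl_weight and the n_cycles / ramp_ratio defaults are my own, loosely following the paper's typical settings):

def cyclical_kl_weight(step, total_steps, n_cycles=4, ramp_ratio=0.5):
    # Position within the current cycle, normalised to [0, 1)
    cycle_len = total_steps / n_cycles
    pos = (step % cycle_len) / cycle_len
    # Ramp linearly from 0 to 1 over the first ramp_ratio of the cycle,
    # then hold the weight at 1 for the rest of the cycle.
    return min(pos / ramp_ratio, 1.0)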

Now for the far-fetched hypothesis: what if this trick alone can mitigate GAN mode collapse? Is anyone working on the topic who could test it out?

EDIT: I am applying it per batch instead of per epoch, but this is how the schedule looks:

def kl_schedule(epoch, kl_annealing):
    # Linear warm-up: ramp the weight from 0 to 1 over the first
    # `kl_annealing` steps (batches, in my case), then hold it at 1.
    if epoch < kl_annealing:
        return epoch / kl_annealing
    return 1.0

total_loss = l_v + kl_schedule(i, kl_annealing) * l_pi
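
As written this is effectively a single warm-up cycle; if you want the multi-cycle behaviour from the paper, you could swap kl_schedule(i, kl_annealing) for something like the cyclical_kl_weight(i, total_steps) sketch above.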
