Intuition Behind KL-Divergence Regularization in VAEs


(Vincent Marron) #1

While GANs seem to be overtaking VAEs as the leading class of generative model, I’m still struggling to catch up and fully understand the mechanism behind VAEs before I get started with GANs. If you’re new to VAEs, these tutorials applied to MNIST data helped me understand the encoding/decoding engines, latent space arithmetic potential, etc.:

  1. Miriam Shiffman, code in Tensorflow http://blog.fastforwardlabs.com/2016/08/12/introducing-variational-autoencoders-in-prose-and.html

  2. Francois Chollet, code in Keras https://blog.keras.io/building-autoencoders-in-keras.html

The thing that I can’t get my head fully around is the use of the Gaussian KL-Divergence term in the overall cost function. From Francois’ code:

> def vae_loss(x, x_decoded_mean):
>     xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
>     kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
>     return xent_loss + kl_loss

I think I understand that the purpose of the kl_loss term is to ensure that the encoded (or latent-space) variables are an efficient descriptor of the input (in this instance, that the model efficiently utilizes its allotted two unit gaussians as best as possible to describe the set of hand-written numerals). What I can’t understand is the intuition behind the derivation of the kl_loss function applied to the latent variables… it seems to want to reduce both the realized z_mean and z_log_sigma.
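For anyone who wants to poke at the term outside of Keras, here is a minimal plain-Python sketch of the same expression. It assumes that `z_log_sigma` plays the role of the log-variance in the closed-form Gaussian KL from Appendix B of the paper linked below; the function names `kl_to_standard_normal` and `kl_loss` are made up for illustration and are not part of Keras.

```python
import math

def kl_to_standard_normal(mean, log_var):
    """KL( N(mean, exp(log_var)) || N(0, 1) ) for a single latent dimension."""
    # Same expression as -0.5 * (1 + log_var - mean**2 - exp(log_var)),
    # with the minus sign distributed. log_var is ASSUMED to be the log-variance.
    return 0.5 * (math.exp(log_var) + mean ** 2 - 1.0 - log_var)

def kl_loss(means, log_vars):
    """Average the per-dimension KL terms, mirroring K.mean(..., axis=-1)."""
    terms = [kl_to_standard_normal(m, lv) for m, lv in zip(means, log_vars)]
    return sum(terms) / len(terms)

# Two latent dimensions, as in the MNIST example.
print(kl_loss([0.0, 0.0], [0.0, 0.0]))   # exactly standard normal: 0.0
print(kl_loss([1.0, 0.0], [0.0, 0.0]))   # first mean shifted to 1: 0.25
print(kl_loss([0.0, 0.0], [2.0, -2.0]))  # variances too wide and too narrow: positive
```

The term is zero only when every latent has mean 0 and log-variance 0, and it grows as either one moves away from that point.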

The appendices of what I believe to be the original VAE paper ( https://arxiv.org/pdf/1312.6114.pdf ) include a formal derivation but I’ve been struggling to get my head around it… Is there anyone with a deeper math/statistics background to whom this is intuitive? Perhaps @rachel ?

Thanks a lot and I apologize if this is a distraction from the main course material… I had been struggling with it before the course started and figured someone here might know what’s going on.


#2

Ahh VAEs :slight_smile: Spent 3 lovely days being utterly confused by how this works during hackerweek at work :smiley:

I can’t comment about the math as the theory behind variational learning is something that I do not understand, but once you implement it, the concept turns out to be extremely simple!

At the heart of the VAE you have a standard normal distribution, or at least you want to have a standard normal. Why? If you let this be a normal distribution with any params, the algorithm can effectively ‘cheat’ - it can become very picky about the std dev and mean values it might want for a particular input, which is counter to the idea of regularization and minimum description length I guess. So we want to somehow tell it - hey algo, learn how to put the standard normal in your center and learn how you can leverage values that it produces for representing the input vector.

The KL divergence is just some fancy thing from information theory that consists of simple math that can be differentiated. It also just happens to be able to indicate how far off the distribution you are getting is away from a standard normal. Don’t remember now - would need to probably read on that again - but it has some nice theoretical properties.

But really it doesn’t matter :slight_smile: It is just some error measure, some measure of difference between the standard normal and the distribution you are getting. Like an L2 or L1 for normal distribution :slight_smile: So we get to penalize our algorithm for straying away from the 0 mean 1 std dev and we get to differentiate that :slight_smile:
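To make the "measure of difference" idea concrete, here is a rough sketch that estimates the KL divergence straight from its definition, the average of log q(z) - log p(z) over samples z drawn from q, and compares it with the closed form for a Gaussian against the standard normal. The helper names and the sample count are illustrative assumptions, and only the standard library is used.

```python
import math
import random

def log_normal_pdf(z, mean, std):
    """Log density of N(mean, std**2) evaluated at z."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((z - mean) / std) ** 2

def kl_closed_form(mean, std):
    """Closed-form KL( N(mean, std**2) || N(0, 1) )."""
    return 0.5 * (std ** 2 + mean ** 2 - 1.0 - 2.0 * math.log(std))

def kl_monte_carlo(mean, std, n_samples=100_000, seed=0):
    """Estimate KL(q || p) as the mean of log q(z) - log p(z) for z ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mean, std)
        total += log_normal_pdf(z, mean, std) - log_normal_pdf(z, 0.0, 1.0)
    return total / n_samples

# Columns: mean, std, closed form, sampled estimate.
for mean, std in [(0.0, 1.0), (1.0, 1.0), (0.0, 0.5), (1.0, 0.5)]:
    print(mean, std, round(kl_closed_form(mean, std), 4), round(kl_monte_carlo(mean, std), 4))
```

One caveat on the L1/L2 analogy: KL is not symmetric in its two arguments, so it behaves like a penalty for straying from the standard normal rather than a true distance.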

If you would be interested in reading a not so serious write up I did on my experiences with VAE, here it is. Don’t think it is accurate in any mathematical sense but maybe you’ll find it useful.

I found this tutorial super helpful (and the code that came with it) for sort of getting my head around how this works.


(Jeremy Howard) #3

Note that KL divergence and cross-entropy loss (which we used throughout part 1) are very nearly the same thing. KL divergence just has one extra component, the entropy of the target distribution, which you can ignore anyway when the targets are fixed labels since it’s constant in the optimization.
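A small sketch of that relationship for discrete distributions, where KL(p || q) equals the cross-entropy H(p, q) minus the entropy H(p). The distributions `one_hot`, `soft`, and `pred` are made-up values for illustration.

```python
import math

def entropy(p):
    """H(p) = -sum p_i * log p_i, skipping zero-probability entries."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)

def cross_entropy(p, q):
    """H(p, q) = -sum p_i * log q_i."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0.0)

def kl_divergence(p, q):
    """KL(p || q) = sum p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

one_hot = [0.0, 1.0, 0.0]  # a hard label, as in classification
soft = [0.1, 0.7, 0.2]     # a target distribution with non-zero entropy
pred = [0.2, 0.5, 0.3]     # model prediction

for name, target in [("one_hot", one_hot), ("soft", soft)]:
    kl = kl_divergence(target, pred)
    ce = cross_entropy(target, pred)
    h = entropy(target)
    print(name, "KL:", kl, "CE - H:", ce - h)
```

With one-hot targets the entropy term is zero, so the two losses coincide. In the VAE term the first argument is the encoder's own distribution, which changes during training, so there the entropy part is not a constant.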

I tossed around the idea of teaching VAE’s but I don’t think I’m going to - I’m not sure that they add anything at the moment. They certainly are in no way a prerequisite to understanding GANs.

(In general, spending more time coding and running experiments is likely to give you more value than studying more mathematical background, IMHO.)


(Vincent Marron) #4

Thanks for this response. That Doersch paper helped a bit but I think you are right to say:

I spent a little time working with the variables and got to a pretty intuitive result. From the algorithm:

kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)

Let’s simplify and say we are encoding to a single latent (using symbols mean = μ, stdev = σ). This becomes:

kl_loss = - 0.5 * (1 + ln(σ) - μ^2 - σ) = -.5 + (μ^2)/2 + ( σ - ln (σ))/2

With partials (or gradient):

d(kl_loss)/dμ = μ
d(kl_loss)/dσ = (σ-1)/(2σ)

So, looking to minimize that kl_loss, gradient descent would work to form a network such that latent variable μ is brought to zero while σ is brought to 1… in other words - the latent space is brought to a standard normal!
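As a sanity check, here is a small sketch that runs plain gradient descent on the single-latent expression above. Differentiating -0.5 * (1 + ln(σ) - μ^2 - σ) gives μ for the mean and (σ - 1)/(2σ) for σ, and the code compares those against finite differences before using them. The starting point, learning rate, and step count are arbitrary choices for illustration. Here σ stands for exp(z_log_sigma) as in the simplification above; in the paper's closed form that quantity is the variance, but the minimum sits at 1 either way.

```python
import math

def kl_single_latent(mu, sigma):
    """Single-latent loss from the thread: -0.5 * (1 + ln(sigma) - mu**2 - sigma)."""
    return -0.5 * (1.0 + math.log(sigma) - mu ** 2 - sigma)

def analytic_grad(mu, sigma):
    """Partials of kl_single_latent with respect to mu and sigma."""
    return mu, (sigma - 1.0) / (2.0 * sigma)

def numeric_grad(mu, sigma, eps=1e-6):
    """Central finite differences, used only to check analytic_grad."""
    d_mu = (kl_single_latent(mu + eps, sigma) - kl_single_latent(mu - eps, sigma)) / (2 * eps)
    d_sigma = (kl_single_latent(mu, sigma + eps) - kl_single_latent(mu, sigma - eps)) / (2 * eps)
    return d_mu, d_sigma

print(analytic_grad(2.0, 3.0), numeric_grad(2.0, 3.0))

# Plain gradient descent from an arbitrary starting point.
mu, sigma = 2.0, 3.0
learning_rate = 0.1
for _ in range(1000):
    g_mu, g_sigma = analytic_grad(mu, sigma)
    mu -= learning_rate * g_mu
    sigma -= learning_rate * g_sigma

print(mu, sigma, kl_single_latent(mu, sigma))
```

Both parameters drift to the standard normal values (μ near 0, σ near 1), where the loss reaches its minimum of zero.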


(Arvind Nagaraj) #5

There’s a ton of deep Bayesian work coming out and YW Teh gave a fantastic overview at NIPS : https://youtu.be/9saauSBgmcQ

Here is a wonderful pytorch tutorial: https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

Uncertainty modeling seems useful in the deep learning context and I hope we cover some of it in p2v2 next year.


(Arvind Nagaraj) #6

I am going to add something to what I wrote above almost 6 months ago. :slight_smile:
I am seeing more and more people talk about prediction uncertainties in deep learning in real-world projects, so this Bayesian DL stuff is important but can be hard to understand intuitively.

I was looking for an introductory tutorial that is close to Jeremy’s style of top-down teaching.

And I found this video: https://youtu.be/s0S6HFdPtlA

Pretty amazing talk!