Fastai v2 Variational AutoEncoder

Hi there,
I’m working on spinning up a Convolutional Variational AutoEncoder. I have proof of concept, but I’m struggling with a few things to fully integrate it into the API. I was hoping to get some guidance on overall design patterning in v2.


  • Passing tuples of tensors versus plain tensors to the model's .forward() function. This causes problems downstream in the API's prediction methods, which expect tensors. For example, the loss function needs access to the “mu” and “logvar” layers for the KL-divergence calculation, and the typical pattern is simply to return them from forward() and pass them into the loss function.
  • I have experimented with adding and removing them with callbacks, but this feels pretty hacky and also breaks other callbacks, such as the crucial .to_fp16().
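For reference, the tuple-output pattern I mean looks roughly like this in plain PyTorch (a minimal sketch with placeholder layer sizes, not my actual model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    """Minimal VAE whose forward returns a tuple (recon, mu, logvar)."""
    def __init__(self, in_dim=64, latent_dim=8):
        super().__init__()
        self.enc = nn.Linear(in_dim, 32)
        self.mu = nn.Linear(32, latent_dim)
        self.logvar = nn.Linear(32, latent_dim)
        self.dec = nn.Linear(latent_dim, in_dim)

    def forward(self, x):
        h = F.relu(self.enc(x))
        mu, logvar = self.mu(h), self.logvar(h)
        # reparameterization trick: z = mu + sigma * eps
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        return self.dec(z), mu, logvar  # tuple output

def vae_loss(pred, target, kl_weight=1.0):
    """Loss that unpacks the tuple output: reconstruction + KL divergence."""
    recon, mu, logvar = pred
    rec = F.mse_loss(recon, target)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + kl_weight * kld
```

This works fine for training, but it's exactly this tuple that the prediction/show machinery trips over.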

Would creating a new data class be the right way to do it? Right now I’m simply using ImageBlocks, i.e. blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImage)).

Creating a totally new Learner subclass seems unwarranted.

Its been a great experience so far getting acquainted with the API which is incredibly powerful, but some of the deeper logic is hard to decipher.


Hi Andy, have you seen this tutorial?:

It also uses a tuple of images as input.

Thanks, Darek

Hi Darek,

Thanks for the reply. Yes I am familiar with that tutorial, and have been thinking that creating a similar class for the output block would work…

something like:

class VAE_OUTTuple(fastuple):
    @classmethod
    def create(cls, fns): 
        # Y, z, mu, logvar = fns
        return cls(tuple(fns))

    def show(self, ctx=None, **kwargs): 
        Y, z, mu, logvar = self
        if not all(isinstance(t, Tensor) for t in (Y, z, mu, logvar)): return ctx
        # TODO: make fn show z, mu, logvar
        return show_image(Y, ctx=ctx, **kwargs)

vaeblock = DataBlock(blocks=(ImageBlock(cls=PILImage), VAE_OUTTuple),
                     item_tfms=Resize(IMG_SIZE, method='pad', pad_mode='border'),
                     batch_tfms=btfms)

But I’m not sure whether that’s the best approach, or whether it would be better to keep the forward() function clean and pass just a single tensor.
e.g. I also experimented with the callback patterns from the TabularAE by @etremblay and @muellerzr ( ) to “inject” the additional parameters. Their kl_weight annealing parameter, injected via callback, is particularly crucial.
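The annealing idea itself is easy to sketch outside of any callback machinery. Something like the following (a hypothetical linear warmup schedule, not the actual one from the TabularAE):

```python
def kl_weight_schedule(step, total_steps, max_weight=1.0, warmup_frac=0.3):
    """Linearly anneal kl_weight from 0 up to max_weight over the first
    warmup_frac of training, then hold it constant."""
    warmup = max(1, int(total_steps * warmup_frac))
    return max_weight * min(1.0, step / warmup)
```

A callback would then just compute the current step each batch and update the loss function's kl_weight attribute before the loss is evaluated.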

I’m really looking for some guidance on what the better overall solution is within the fastai API. Eventually I will want to parameterize the latent space to generate mixed outputs.

thanks again for your input,

If you are interested, I had been working locally on a toy conditional convolutional VAE with MNIST. So basically a convolutional VAE, but I also pass the label into the model, so later on you can ask the VAE for a specific number: you pass the decoder a 2 plus a randomly generated sample from the unit Gaussian, and it generates a 2. Everything uses fastai, so it might help you. Nothing is documented, though, and it’s been a while since I worked on it.
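The conditioning itself is just concatenating a one-hot label onto the latent code before decoding. A simplified sketch of that pattern (placeholder sizes, not the exact code from my notebook):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CondDecoder(nn.Module):
    """Decoder conditioned on a class label: the one-hot label is
    concatenated onto the latent code before decoding."""
    def __init__(self, latent_dim=8, n_classes=10, out_dim=784):
        super().__init__()
        self.n_classes = n_classes
        self.net = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 128), nn.ReLU(),
            nn.Linear(128, out_dim), nn.Sigmoid())

    def forward(self, z, label):
        onehot = F.one_hot(label, self.n_classes).float()
        return self.net(torch.cat([z, onehot], dim=1))

# "Ask the VAE for a 2": sample z from the unit Gaussian, pass label 2
dec = CondDecoder()
img = dec(torch.randn(1, 8), torch.tensor([2]))
```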

I tried other things too. For example using MMD instead of KL in the loss. You can check that here:

For a CNN VAE, MMD didn’t work well for some reason, but for a tabular-data VAE I had some good results. I just committed it here:
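For reference, the MMD term is a kernel two-sample statistic between the encoded batch and samples from the prior, used in place of the KL term (InfoVAE / MMD-VAE style). A minimal RBF-kernel version (simplified, not the exact code from the repo):

```python
import torch

def rbf_kernel(a, b, sigma=1.0):
    # Gaussian kernel on pairwise squared distances between rows of a and b
    d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
    return torch.exp(-d2 / (2 * sigma ** 2))

def mmd_loss(z, sigma=1.0):
    """Biased MMD^2 estimate between encoded samples z and samples
    from the unit-Gaussian prior."""
    prior = torch.randn_like(z)
    return (rbf_kernel(z, z, sigma).mean()
            + rbf_kernel(prior, prior, sigma).mean()
            - 2 * rbf_kernel(z, prior, sigma).mean())
```

The pairwise-distance matrices are what make it slower than the closed-form KL term.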


Thanks Etienne, I’ll have a look.

Essentially I think your MNIST toy does what I am aiming for. I want to pass a latent vector and have it “sample” from it to generate an image. The goal is to make traversals in the latent space between two examples, as well as mixtures.
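The traversal part is straightforward once the decoder exists: interpolate between the two latent codes and decode each intermediate point. A sketch using linear interpolation (spherical interpolation is another common choice):

```python
import torch

def latent_traversal(z1, z2, steps=8):
    """Return `steps` latent codes linearly interpolated from z1 to z2;
    decoding each row yields the traversal images."""
    alphas = torch.linspace(0, 1, steps).view(-1, 1)
    return (1 - alphas) * z1 + alphas * z2  # shape: (steps, latent_dim)
```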


I just had a look. I can’t wait to try out replacing KLD with MMD!!

Your TabularVAE code has already been incredibly helpful. I just keep going back and forth between thinking I should customize the Learner and doing everything with Hooks and Callbacks…

My project was born out of an experiment in using notebooks, nbdev, and fastai together. So far I have stayed true to the experiment and haven’t opened VSCode, but I’ve struggled at times with making modules from notebooks, and especially with the auto-generated documentation.

I’m working on an expository notebook of my code which I’m excited to share, and then I can take stock of how the development environment experiment pans out.

Thanks again!


MMD is significantly slower than KL divergence, though. It worked well for the tabular VAE but not so much for my CNN experiment; your mileage may vary. Glad my code could be useful! Keep us updated with what you find out.
