Fastai v2 Variational AutoEncoder

Hi there,
I’m working on spinning up a Convolutional Variational AutoEncoder. I have a proof of concept, but I’m struggling with a few things as I try to fully integrate it into the API. I was hoping to get some guidance on the overall design patterns in v2.

Specifically:

  • Passing tuples of tensors versus single tensors to the model’s .forward() function. This causes problems when we get to the API’s .pred(), which expects tensors. For example, the loss function needs access to the “mu” and “logvar” layers for the KL divergence calculation, so the typical pattern would be to return them as part of the model output and pass them into the loss function (a sketch of this pattern follows the list).
  • I have experimented with adding and removing them with callbacks, but this feels pretty hacky and also breaks other callbacks such as the crucial .to_fp16().
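For concreteness, here is a minimal sketch of the pattern I mean. The names, shapes (64×64 RGB inputs), and the MSE reconstruction term are just assumptions for illustration, not my actual model: the forward() returns a (reconstruction, mu, logvar) tuple and the loss function unpacks it.

import torch
from torch import nn
import torch.nn.functional as F

class ConvVAE(nn.Module):
    "Toy conv VAE whose forward() returns a (reconstruction, mu, logvar) tuple."
    def __init__(self, latent_dim=32):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.ReLU(),   # 64 -> 32
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),  # 32 -> 16
            nn.Flatten())
        self.fc_mu     = nn.Linear(64*16*16, latent_dim)
        self.fc_logvar = nn.Linear(64*16*16, latent_dim)
        self.dec_fc    = nn.Linear(latent_dim, 64*16*16)
        self.dec = nn.Sequential(
            nn.Unflatten(1, (64, 16, 16)),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()   # reparameterization trick
        return self.dec(self.dec_fc(z)), mu, logvar            # tuple output

def vae_loss(pred, target, kl_weight=1.0):
    "The loss receives the whole tuple as `pred`, so it can see mu and logvar."
    recon, mu, logvar = pred
    recon_loss = F.mse_loss(recon, target)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kld

This trains fine, but the tuple output is exactly what trips up the parts of the API that expect a single prediction tensor.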

Would creating a new data class be the right way to do it? Right now I’m simply using ImageBlocks, i.e. blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImage)).

Creating a totally new Learner subclass seems unwarranted.

It’s been a great experience so far getting acquainted with the API, which is incredibly powerful, but some of the deeper logic is hard to decipher.

thanks!
Andy

1 Like

Hi Andy, have you seen this tutorial?: https://docs.fast.ai/tutorial.siamese.html

It also uses a tuple of images as input.

Thanks, Darek

Hi Darek,

Thanks for the reply. Yes I am familiar with that tutorial, and have been thinking that creating a similar class for the output block would work…

something like:

class VAE_OUTTuple(fastuple):
    @classmethod
    def create(cls, fns):
        # Y, z, mu, logvar = fns
        return cls(tuple(fns))

    def show(self, ctx=None, **kwargs):
        Y, z, mu, logvar = self
        if not all(isinstance(t, Tensor) for t in (Y, z, mu, logvar)): return ctx
        # TODO: also show z, mu, logvar
        return show_image(Y, ctx=ctx, **kwargs)

vaeblock = DataBlock(blocks=(ImageBlock(cls=PILImage),
                             TransformBlock(type_tfms=VAE_OUTTuple.create)),
                     get_x=get_x,
                     get_y=get_y,
                     splitter=ColSplitter('is_valid'),
                     item_tfms=Resize(IMG_SIZE, method='pad', pad_mode='border'),
                     batch_tfms=btfms,
                    )
    

But I’m not sure if that’s the best approach, or whether it’s better to keep the forward() function clean and pass just a single tensor.
e.g. I also experimented with the callback patterns from the TabularAE by @etremblay and @muellerzr (https://github.com/EtienneT/TabularVAE/blob/master/TabularAE.ipynb) to “inject” the additional parameters. Their annealing kl_weight parameter, injected via a callback, is particularly crucial (a rough sketch of that pattern is below).
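Something like this hypothetical sketch of an annealing callback, assuming the loss function object exposes a mutable kl_weight attribute (this is not the exact code from that notebook):

from fastai.callback.core import Callback

class KLAnnealCallback(Callback):
    "Hypothetical: linearly ramp the loss function's kl_weight over the first few epochs."
    def __init__(self, warmup_epochs=3, max_weight=1.0):
        self.warmup_epochs, self.max_weight = warmup_epochs, max_weight
    def before_epoch(self):
        frac = min(1.0, self.epoch / self.warmup_epochs)
        # assumes the loss function object exposes a mutable kl_weight attribute
        self.learn.loss_func.kl_weight = frac * self.max_weight

# usage sketch: Learner(dls, model, loss_func=my_vae_loss, cbs=KLAnnealCallback())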

I’m really looking for some guidance about what’s the better overall solution for the fastai API. Eventually I will want to parameterize the latent space to generate mixed outputs.

thanks again for your input,
Andy

If you are interested, I had been working locally on a toy conditional convolutional VAE with MNIST. So basically a convolutional VAE, but I also pass the label in to the model, so later on you can ask the VAE for a specific number: you pass the decoder a 2 and a randomly generated sample from the unit Gaussian, and it generates a 2. Everything uses fastai. Nothing is documented, though, and it’s been a while since I worked on it, but maybe it can help you.
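Roughly, sampling a specific digit looks like this (a hypothetical sketch, not the exact code in my notebook; the decoder signature and sizes are assumptions):

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_digit(decoder, digit, latent_dim=16, n_classes=10, device='cpu'):
    "Hypothetical sketch: draw z ~ N(0, I), concatenate a one-hot label, and decode."
    z = torch.randn(1, latent_dim, device=device)                  # unit-Gaussian sample
    y = F.one_hot(torch.tensor([digit], device=device), n_classes).float()
    return decoder(torch.cat([z, y], dim=1))                       # e.g. digit=2 -> an image of a 2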

I tried other things too, for example using MMD instead of KL in the loss. You can check that here:

For the CNN VAE, MMD didn’t work well for some reason, but for the tabular data VAE I had some good results. I just committed it here:
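For reference, the MMD term is a Gaussian-kernel comparison between the encoded latents and samples from the prior. A generic InfoVAE-style sketch (not necessarily the exact code I committed):

import torch

def gaussian_kernel(a, b):
    "Pairwise RBF kernel between two batches of latent codes."
    dim = a.shape[1]
    a_, b_ = a.unsqueeze(1), b.unsqueeze(0)            # (n,1,d) and (1,m,d)
    return torch.exp(-((a_ - b_)**2).mean(-1) / dim)

def mmd(z, z_prior):
    "Maximum Mean Discrepancy between encoded latents and samples from N(0, I)."
    return (gaussian_kernel(z_prior, z_prior).mean()
            + gaussian_kernel(z, z).mean()
            - 2 * gaussian_kernel(z_prior, z).mean())

# usage sketch: replace the KL term with  kl_weight * mmd(z, torch.randn_like(z))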

3 Likes

Thanks Etienne, I’ll have a look.

Essentially I think that your MNIST toy does what I am trying for. I want to pass a latent vector and have it “sample” from it to generate an image. Making traversals in the latent space between two examples, as well as mixtures, is the goal (something like the sketch below).
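A hypothetical sketch of what I mean by a traversal, assuming the encoder returns (mu, logvar):

import torch

@torch.no_grad()
def latent_traversal(encoder, decoder, img_a, img_b, steps=8):
    "Decode points along the straight line between two encoded examples."
    mu_a, _ = encoder(img_a)                       # assumes encoder returns (mu, logvar)
    mu_b, _ = encoder(img_b)
    alphas = torch.linspace(0, 1, steps, device=mu_a.device)
    return torch.stack([decoder((1 - a) * mu_a + a * mu_b) for a in alphas])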

1 Like

I just had a look. I can’t wait to try out replacing KLD with MMD!!

Your TabularVAE code has already been incredibly helpful. I just keep going back and forth between thinking I should customize the Learner and doing everything with Hooks and Callbacks…

My project was born out of an experiment in using notebooks, nbdev, and fastai. So far I have stayed true to the experiment and haven’t opened VSCode, but I’ve struggled at times with making modules from notebooks and especially with the auto-generated documentation.

I’m working on an expository notebook of my code which I’m excited to share, and then I can take stock of how the development environment experiment pans out.

Thanks again!

2 Likes

MMD is significantly slower than KL divergence, though. It worked well for the tabular VAE, but not so much for my CNN experiment; your mileage may vary. Glad my code could be useful! Keep us updated with what you find out.

1 Like

Hi @etremblay! I stumbled upon this thread as someone who was fairly well versed in v1 but is only now getting into v2. I am also trying to build a convolutional VAE, and I want to thank you for sharing your example - it’s a huge help. However, upon downloading and running the notebook, I ran into a couple of issues, and I was hoping you might be able to help out with those.

The first error I encountered was the following:

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.

I swapped the function to binary_cross_entropy_with_logits in the BCE term and the loss function definition.
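For anyone hitting the same error: the change just amounts to feeding the raw decoder logits to the loss instead of applying a sigmoid first. Roughly (a sketch, assuming a reconstruction term like this; names are illustrative):

import torch.nn.functional as F

def recon_term(recon_logits, target):
    "Autocast-safe reconstruction term: pass raw decoder logits, no explicit sigmoid."
    # was: F.binary_cross_entropy(torch.sigmoid(recon_logits), target, reduction='sum')
    return F.binary_cross_entropy_with_logits(recon_logits, target, reduction='sum')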

The second error I encountered was the following:

TypeError: Object of type TensorImageTarget is not JSON serializable

It seems to come from the W&B callback, so I commented that out. This at least allowed me to train the model.

Finally, when trying to use the plotting code after training, I receive the following error:

TypeError: no implementation found for 'torch.cat' on types that implement __torch_function__: [TensorImageBWNoised, TensorImageTarget]

TensorImageBWNoised and TensorImageTarget seem to be the problematic classes. Based on the definitions of those classes at the beginning of the notebook, as I understand it, they are virtually the same as the superclass TensorImageBW - so it seems like torch.cat should be able to run on those classes?

I hope my questions make sense. Your example is very valuable and I would like to be able to reproduce your notebook without errors before I start modifying it, but it seems like there’s something I am missing. Any help is appreciated!

These seem to be new errors since the last time I touched this project.

If you remove the to_fp16, the autocast error seems to go away… which is a shame, since we want fp16… but I don’t know why binary_cross_entropy does not support autocast. So just remove this line to fix that error: learn = learn.to_fp16().

Yeah, comment out the W&B callback 🙂

For the last error, this seems to be something new with PyTorch 1.7; I think it supports tensor subclassing now… Not sure how to fix it from my quick Google searches. Maybe @muellerzr would know about this particular error?

TypeError: no implementation found for 'torch.cat' on types that implement __torch_function__: [TensorImageBWNoised, TensorImageTarget]

Those tensor classes are super dumb:
class TensorImageBWNoised(TensorImageBW): pass
class TensorImageTarget(TensorImageBW): pass

They are mainly used for type dispatching. I guess we would need to cast them to simple tensors before calling torch.cat, but I’m not sure how to do that, and Google is not helpful…

1 Like

Alright, I asked on the Discord server about the last error and @muellerzr came to the rescue! Basically, since PyTorch 1.8, PyTorch is much more explicit about tensor types. So to fix this, you can simply create new TensorBase objects from the specific tensor subclasses we are trying to concat:

img = torch.cat([TensorBase(input_img.clamp(0,1)), TensorBase(targs), imgs], dim=3)[i].squeeze(0).cpu()

This seems to do the trick. I pushed the changes to GitHub.

2 Likes

That makes sense! Thanks so much to both of you for your help!

Edit: One last question - what GPU are you using to run your notebook? I’ve tried Colab and Kaggle with the free GPUs they provide, and it takes about a minute and a half per epoch, whereas in your example it was 25-ish seconds per epoch. I was wondering if the difference in training time is just down to a more powerful GPU?