I'm implementing the VAE loss as a method of my nn.Module class:
    def vae_loss(self, predict, target):
        BCE = F.binary_cross_entropy(predict, target.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp())
        return BCE + KLD

    def forward(self, x):
        self.mu, self.logvar = self.encode(x)
        z = self.reparameterize(self.mu, self.logvar)
        return self.decode(z)
Training with MNIST and my loss function as shown:
    learn = Learner(loaders, mdl, loss_func=mdl.vae_loss)
This works in plain PyTorch, but I can’t get it to work in FastAI. The training loss is NaN, and when I debug, it appears that my loss function is never called. I tried wrapping it in a separate BaseLoss class, but I can’t work out how to configure it correctly or how to access the mu and logvar values from my model inside it.
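
One alternative I am considering, so the loss no longer depends on state stored on the module: have forward return the reconstruction together with mu and logvar, and make the loss a standalone function that unpacks that tuple. Below is a minimal sketch in plain PyTorch only. The names TinyVAE and vae_loss_fn and the layer sizes are hypothetical, and I have not confirmed how fastai's Learner handles a model that returns a tuple, so treat it as an illustration of the idea rather than a working fastai setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVAE(nn.Module):
    """Hypothetical minimal VAE for 28x28 inputs; layer sizes are illustrative."""

    def __init__(self, in_dim=784, hidden=64, latent=8):
        super().__init__()
        self.enc = nn.Linear(in_dim, hidden)
        self.fc_mu = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
        self.dec1 = nn.Linear(latent, hidden)
        self.dec2 = nn.Linear(hidden, in_dim)

    def encode(self, x):
        # Flatten each image to a vector before the encoder layer.
        h = F.relu(self.enc(x.view(x.size(0), -1)))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps drawn from a standard normal.
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        # Sigmoid keeps outputs in (0, 1), as binary_cross_entropy requires.
        return torch.sigmoid(self.dec2(F.relu(self.dec1(z))))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Return everything the loss needs instead of storing it on self.
        return self.decode(z), mu, logvar


def vae_loss_fn(output, target):
    """Summed BCE reconstruction term plus KL divergence to a unit Gaussian."""
    recon, mu, logvar = output
    bce = F.binary_cross_entropy(
        recon, target.view(target.size(0), -1), reduction='sum'
    )
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld


torch.manual_seed(0)
model = TinyVAE()
batch = torch.rand(4, 1, 28, 28)  # stand-in for an MNIST batch with values in [0, 1]
loss = vae_loss_fn(model(batch), batch)
loss.backward()
```

With this shape, the loss function needs no reference to the model, which is the part I could not work out when trying to wrap my original method in a separate class.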