I am implementing the VAE loss as a method of my nn.Module class:
    def vae_loss(self, predict, target):
        BCE = F.binary_cross_entropy(predict, target.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + self.logvar -
                               self.mu.pow(2) - self.logvar.exp())
        return BCE + KLD

    def forward(self, x):
        self.mu, self.logvar = self.encode(x)
        z = self.reparameterize(self.mu, self.logvar)
        return self.decode(z)
I train on MNIST and pass my loss function to the Learner as shown:
    learn = Learner(loaders, mdl, loss_func=mdl.vae_loss)
This works in plain PyTorch, but I can’t get it to work in FastAI. The training loss is NaN, and when I debug, it appears that my loss function is never called. I tried wrapping it separately in a BaseLoss class, but I can’t work out how to configure it correctly or how to access the parameters (mu and logvar) from my model.
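To make the last point concrete, here is a minimal sketch of one possible way to decouple the loss from the model, not a confirmed fix for the NaN. It assumes forward can be changed to return (recon, mu, logvar) instead of storing mu and logvar on self. The names VAELoss and TinyVAE are hypothetical and exist only to make the sketch runnable in plain PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAELoss(nn.Module):
    """Standalone VAE loss: summed reconstruction BCE plus KL divergence.

    Assumption: the model's forward returns (recon, mu, logvar), so the
    loss never has to read attributes stored on the model.
    """

    def forward(self, output, target):
        recon, mu, logvar = output
        bce = F.binary_cross_entropy(
            recon, target.view(-1, 784), reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld


class TinyVAE(nn.Module):
    """Hypothetical stand-in model, not the model from the question."""

    def __init__(self, latent=8):
        super().__init__()
        self.enc = nn.Linear(784, 2 * latent)  # produces mu and logvar
        self.dec = nn.Linear(latent, 784)

    def forward(self, x):
        mu, logvar = self.enc(x.view(-1, 784)).chunk(2, dim=1)
        # Reparameterization trick: z = mu + eps * std
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return torch.sigmoid(self.dec(z)), mu, logvar


model, loss_func = TinyVAE(), VAELoss()
x = torch.rand(4, 1, 28, 28)    # fake MNIST-shaped batch with values in [0, 1]
loss = loss_func(model(x), x)   # the target is the input image itself
loss.backward()
```

As far as I understand, the Learner passes the model output and the batch targets to loss_func, so a tuple output should reach VAELoss intact when used as Learner(loaders, mdl, loss_func=VAELoss()). That would still need the loaders to yield the image, not the digit label, as the target, which I have not verified.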