Passing arguments from model to the loss function

I’m working on a variational autoencoder, which requires passing extra variables from model.forward to the loss function. In plain PyTorch I can write it like this:

recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()

Previously I implemented a traditional autoencoder by simply passing it to the Learner class, but now it feels like I need to rewrite everything starting from the train_epoch method. Any ideas on how to tackle this?

You should use a Callback to do this 🙂
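
To make the idea concrete, here is a rough, framework-agnostic sketch of the pattern in plain Python. It is not fastai's actual API: the hook name `after_pred`, the classes `StashExtras` and `VAELoss`, and the `train_step` and `toy_model` functions are all made up for illustration, and the squared-error reconstruction term is just a stand-in. The point is only the data flow: the callback intercepts the model's tuple output, keeps `mu` and `logvar`, and passes the reconstruction on, so the loss keeps the usual `(pred, target)` signature and the training loop stays untouched. Check the callback docs of the fastai version you use for the real event names.

```python
import math


class StashExtras:
    """Hypothetical callback: keeps the extra model outputs for the loss."""

    def __init__(self):
        self.mu = None
        self.logvar = None

    def after_pred(self, pred):
        # The model returns (recon, mu, logvar). Keep the extras here and
        # hand only the reconstruction on to the rest of the loop.
        recon, self.mu, self.logvar = pred
        return recon


class VAELoss:
    """Loss with the usual (pred, target) signature; reads extras from the callback."""

    def __init__(self, stash):
        self.stash = stash

    def __call__(self, recon, target):
        # Stand-in reconstruction term (sum of squared errors).
        rec = sum((r - t) ** 2 for r, t in zip(recon, target))
        # KL divergence between N(mu, exp(logvar)) and N(0, 1).
        kld = -0.5 * sum(
            1 + lv - m ** 2 - math.exp(lv)
            for m, lv in zip(self.stash.mu, self.stash.logvar)
        )
        return rec + kld


def train_step(model, data, loss_func, callbacks):
    """Generic loop body: it knows nothing about mu or logvar."""
    pred = model(data)
    for cb in callbacks:
        pred = cb.after_pred(pred)
    return loss_func(pred, data)


def toy_model(data):
    # Stand-in for model.forward: returns (recon, mu, logvar).
    return [0.5 * x for x in data], [0.0, 1.0], [0.0, 0.0]


stash = StashExtras()
loss = train_step(toy_model, [2.0, 4.0], VAELoss(stash), [stash])
```

With real tensors, `loss.backward()` would follow as usual, since the stashed `mu` and `logvar` are still part of the same computation graph.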