Loss computes but gradient does not

import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.vision.all import *

class convVAE(nn.Module):
    def __init__(self, dim_z=20):
        super(convVAE, self).__init__()

        # encoder: 1x28x28 -> 32x13x13 -> 64x6x6, flattened to 2304
        self.cv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.cv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.fc31 = nn.Linear(2304, dim_z)   # mu
        self.fc32 = nn.Linear(2304, dim_z)   # logvar
        # decoder: mirror of the encoder
        self.fc4 = nn.Linear(dim_z, 2304)
        self.cv5 = nn.ConvTranspose2d(64, 32, 3, stride=2)
        self.cv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, output_padding=1)

    def encode(self, x):
        h1 = F.leaky_relu(self.cv1(x))
        h2 = F.leaky_relu(self.cv2(h1)).view(-1, 2304)
        return self.fc31(h2), self.fc32(h2)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h5 = F.leaky_relu(self.fc4(z)).view(-1, 64, 6, 6)
        h6 = F.leaky_relu(self.cv5(h5))
        return torch.sigmoid(self.cv6(h6))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z).view(-1, 784), mu, logvar

    def get_loss(res, y):
        y_hat, mu, logvar = res

        # reconstruction term
        BCE = F.binary_cross_entropy(
            y.view(-1, 784),
            y_hat,
            reduction='sum')

        # KL divergence term
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
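For context, the 2304 comes from the encoder output on 28x28 MNIST digits (28x28 -> 13x13 -> 6x6 spatially, so 64*6*6 = 2304). A quick sanity check on random input (illustrative only, not part of the training code) runs fine:

# sanity check: forward pass on a fake batch of MNIST-sized images
model = convVAE(dim_z=20)
x = torch.rand(4, 1, 28, 28)
y_hat, mu, logvar = model(x)
print(y_hat.shape, mu.shape, logvar.shape)
# torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])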

# autoencoder-style DataBlock: the target image is the input image itself
block = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=(lambda x: x),
    batch_tfms=aug_transforms(mult=2., do_flip=False))

path = untar_data(URLs.MNIST)
loaders = block.dataloaders(path/"training", num_workers=0, bs=32)
loaders.train.show_batch(max_n=4, nrows=1)

mdl = convVAE(5)   # latent dimension of 5
learn = Learner(loaders, mdl, loss_func=convVAE.get_loss)
learn.fit(1, cbs=ShortEpochCallback())   # ShortEpochCallback: only run part of the epoch
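For debugging, something like the following should show where the NaNs first appear (a sketch only, not part of the training script above; the parameter check at the end is just illustrative):

# make autograd raise at the op that first produces NaN/Inf during backward
torch.autograd.set_detect_anomaly(True)
learn.fit(1, cbs=ShortEpochCallback())

# afterwards, check which parameters actually ended up as NaN
for name, p in mdl.named_parameters():
    if torch.isnan(p).any():
        print(name, "contains NaN")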

This is a convolutional VAE for MNIST, adapted from the PyTorch examples, which I am trying to get working in fastai. The loss computes, but the gradient does not, and all parameters become NaN after one step. I can't work out what is causing this or how to fix it, since the loss function and model are identical to the native PyTorch implementation.
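For reference, the loss function in the pytorch/examples VAE that this was adapted from looks roughly like this (quoted from memory, so treat as approximate):

# loss function from the original pytorch/examples VAE, approximately as published
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD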