VAE protein design

Recently, I’ve embarked on a solo project on protein design, using Variational Autoencoders (VAEs) for gene therapy applications. Would anyone be willing to review my work and share thoughts or suggestions, especially on the network architecture and the training process? Is anyone also willing to coach me on this project?

My main concern is that after 3 epochs my KLD goes to inf and my BCE goes to NaN. How can I get to the root cause of this?
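So far I’ve been trying to localize it by failing fast and checking each term before it poisons the weights (a rough sketch; loader, model, optimizer, and vae_loss stand in for my real objects):

import torch

# Make autograd raise at the first backward op that produces NaN/inf
torch.autograd.set_detect_anomaly(True)

for step, (distance_matrix, mask, seq_embedding) in enumerate(loader):
    mu_pred, logvar_pred, mu, logvar = model(distance_matrix, seq_embedding)
    loss = vae_loss(distance_matrix, mask, mu_pred, logvar_pred, mu, logvar)

    # A runaway encoder logvar is the usual suspect, since the KLD
    # contains a logvar.exp() term that overflows long before NaN appears
    for name, t in [("mu", mu), ("logvar", logvar), ("loss", loss)]:
        if not torch.isfinite(t).all():
            raise RuntimeError(f"step {step}: non-finite {name}, "
                               f"max |logvar| = {logvar.abs().max().item():.1f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()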

Here is a working version that I hope someone will review.

Hi, I am doing some similar work. I’d like to see your project.

https://github.com/arjan-hada/protein-design-vae/tree/main

Your VAE model structure looks good to me; I have a similar VAE implemented with stacked convolutional layers. Currently, I’m using ESM embeddings to predict the 2D distance matrix.
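In case it helps, this is roughly how I go from per-residue ESM-2 embeddings to a pairwise (s, s, 1280) tensor (a sketch using the fair-esm package; the esm2_t33_650M_UR50D checkpoint and the mean-of-pairs lift are my own choices):

import torch
import esm

# Load a pretrained ESM-2 model (1280-dim per-residue embeddings)
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()
batch_converter = alphabet.get_batch_converter()

_, _, tokens = batch_converter([("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")])

with torch.no_grad():
    out = model(tokens, repr_layers=[33])
per_residue = out["representations"][33][0, 1:-1]  # (s, 1280), BOS/EOS dropped

# Lift to a pairwise map: the feature for pair (i, j) is the mean of the
# embeddings of residues i and j
pairwise = 0.5 * (per_residue[:, None, :] + per_residue[None, :, :])  # (s, s, 1280)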

Can you share your code?

import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    def __init__(self, latent_dim=64, s=seq_length):  # seq_length is assumed to be defined earlier
        super(CVAE, self).__init__()

        self.s = s
        # Assume s is the full size and we want to pool it to s//8
        self.target_pool_size = (s // 8, s // 8)

        # 1x1-conv transformations of the sequence embedding and the distance matrix
        self.seq_emb_transform = nn.Conv2d(1280, 64, kernel_size=1)
        self.delta_transform = nn.Conv2d(1, 64, kernel_size=1)

        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1),  # 128 input channels = 64 (distance) + 64 (embedding)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        # Latent space to shape (latent_dim)
        self.fc_mu = nn.Linear(16 * (s // 8) * (s // 8), latent_dim)
        self.fc_logvar = nn.Linear(16 * (s // 8) * (s // 8), latent_dim)

        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, 64 * (s // 8) * (s // 8))
        self.adaptive_pool = nn.AdaptiveAvgPool2d(self.target_pool_size)
        self.shared_upconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # (s//8) to (s//4)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # (s//4) to (s//2)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # (s//2) to s
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        self.mu_predict_conv = nn.Conv2d(16, 1, kernel_size=3, padding=1)
        self.logvar_predict_conv = nn.Conv2d(16, 1, kernel_size=3, padding=1)

    def encode(self, combined_input):
        x = self.encoder(combined_input)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, seq_embedding):
        seq_embedding_transformed_pooled = self.adaptive_pool(seq_embedding)  # (bs, 64, s//8, s//8)

        z_upscaled = F.relu(self.fc3(z))  # (bs, 64 * (s//8) * (s//8))
        z_reshaped = z_upscaled.view(-1, 64, self.s // 8, self.s // 8)  # (bs, 64, s//8, s//8)

        combined_input = torch.cat([z_reshaped, seq_embedding_transformed_pooled], dim=1)  # (bs, 128, s//8, s//8)
        x = self.shared_upconv(combined_input)
        mu_pred = F.relu(self.mu_predict_conv(x))
        logvar_pred = self.logvar_predict_conv(x)
        logvar_pred = logvar_pred.clamp(min=-3, max=3)  # keep the predicted variance in a sane range

        return mu_pred.squeeze(1), logvar_pred.squeeze(1)

    def forward(self, distance_matrix, seq_embedding):
        seq_embedding_transformed = seq_embedding.permute(0, 3, 1, 2)  # (bs, s, s, 1280) -> (bs, 1280, s, s)
        seq_embedding_transformed = self.seq_emb_transform(seq_embedding_transformed)  # -> (bs, 64, s, s)
        distance_matrix_transformed = self.delta_transform(distance_matrix.unsqueeze(1))  # (bs, 1, s, s) -> (bs, 64, s, s)

        combined_input = torch.cat([distance_matrix_transformed, seq_embedding_transformed], dim=1)  # (bs, 128, s, s)

        mu, logvar = self.encode(combined_input)
        z = self.reparameterize(mu, logvar)
        mu_pred, logvar_pred = self.decode(z, seq_embedding_transformed)
        return mu_pred, logvar_pred, mu, logvar
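A quick shape check of the forward pass (my own sketch; the batch size and s = 64 are arbitrary, and s must be divisible by 8 for the encode/decode round trip):

model = CVAE(latent_dim=64, s=64)

distance_matrix = torch.rand(2, 64, 64)        # (bs, s, s)
seq_embedding = torch.randn(2, 64, 64, 1280)   # (bs, s, s, 1280)

mu_pred, logvar_pred, mu, logvar = model(distance_matrix, seq_embedding)
print(mu_pred.shape, logvar_pred.shape)  # torch.Size([2, 64, 64]) twice
print(mu.shape, logvar.shape)            # torch.Size([2, 64]) twice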

def vae_loss(x, mask, mu_pred, logvar_pred, mu, logvar, beta=0.1):
    # Reconstruction loss: negative log-likelihood under a Gaussian with the
    # predicted mean/variance, masked so padded positions don't contribute
    nll_loss = 0.5 * torch.sum(mask * (logvar_pred + (x - mu_pred)**2 / logvar_pred.exp()))

    # KL divergence between the approximate posterior q(z|x) and the prior p(z)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return nll_loss + beta * kl_loss
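Back to the inf KLD: besides the logvar_pred clamp already in decode, two things that usually tame it are warming beta up from 0 and clipping gradients. A rough training-loop sketch (the optimizer, loader, and warm-up schedule are my assumptions, not the repo's settings):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs, warmup_epochs = 50, 10

for epoch in range(num_epochs):
    # Anneal beta linearly from 0 to 0.1 so the KL term cannot dominate early
    beta = 0.1 * min(1.0, epoch / warmup_epochs)
    for distance_matrix, mask, seq_embedding in loader:
        mu_pred, logvar_pred, mu, logvar = model(distance_matrix, seq_embedding)
        loss = vae_loss(distance_matrix, mask, mu_pred, logvar_pred, mu, logvar, beta=beta)

        optimizer.zero_grad()
        loss.backward()
        # Clip so one bad batch cannot blow up the encoder's logvar
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()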