class CVAE(nn.Module):
    """Conditional VAE over pairwise distance matrices.

    Encodes an (s, s) distance matrix together with an (s, s, 1280)
    per-pair sequence embedding into a ``latent_dim`` Gaussian latent,
    then decodes — conditioned on the same embedding — back to a
    per-entry Gaussian (mu_pred, logvar_pred) over the distance matrix.
    """

    def __init__(self, latent_dim=64, s=None):
        # BUG FIX: was `def init` / `super(CVAE, self).init()`, so the
        # constructor never ran and no layers were registered.
        super().__init__()
        # BUG FIX: the old default `s=seq_length` was evaluated at
        # class-definition time; defer the global lookup so the module is
        # importable and callers can always pass `s` explicitly.
        self.s = s if s is not None else seq_length
        s = self.s
        # The encoder halves the spatial grid three times: s -> s // 8.
        self.target_pool_size = (s // 8, s // 8)
        # 1x1 convs projecting both conditioning inputs to a common
        # 64-channel space before concatenation.
        self.seq_emb_transform = nn.Conv2d(1280, 64, kernel_size=1)
        self.delta_transform = nn.Conv2d(1, 64, kernel_size=1)
        # NOTE(review): seq_w / delta_w are never used by forward() — the
        # conv projections above superseded them. Kept only so existing
        # checkpoints still load; consider deleting once checkpoints are
        # migrated.
        self.seq_w = nn.Parameter(torch.randn(1280, 64))
        self.delta_w = nn.Parameter(torch.randn(1, 64))
        # Encoder: (bs, 128, s, s) -> (bs, 16, s//8, s//8).
        self.encoder = nn.Sequential(
            # 128 in-channels = 64 transformed distance + 64 transformed embedding
            nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        # Flattened encoder features -> latent Gaussian parameters.
        self.fc_mu = nn.Linear(16 * (s // 8) * (s // 8), latent_dim)
        self.fc_logvar = nn.Linear(16 * (s // 8) * (s // 8), latent_dim)
        # Decoder entry: latent vector -> (bs, 64, s//8, s//8) feature map.
        self.fc3 = nn.Linear(latent_dim, 64 * (s // 8) * (s // 8))
        # Pools the (bs, 64, s, s) conditioning embedding down to the
        # latent grid so it can be concatenated with the decoded latent.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(self.target_pool_size)
        self.shared_upconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # s//8 -> s//4
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # s//4 -> s//2
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # s//2 -> s
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        # Per-entry output Gaussian heads (mean and log-variance maps).
        self.mu_predict_conv = nn.Conv2d(16, 1, kernel_size=3, padding=1)
        self.logvar_predict_conv = nn.Conv2d(16, 1, kernel_size=3, padding=1)

    def encode(self, combined_input):
        """Map the (bs, 128, s, s) combined input to latent (mu, logvar)."""
        x = self.encoder(combined_input)
        x = torch.flatten(x, start_dim=1)
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, seq_embedding):
        """Decode latent ``z`` conditioned on the (bs, 64, s, s) transformed
        embedding into per-entry (mu_pred, logvar_pred), each (bs, s, s)."""
        pooled_emb = self.adaptive_pool(seq_embedding)        # (bs, 64, s//8, s//8)
        z_up = F.relu(self.fc3(z))                            # (bs, 64 * (s//8)**2)
        z_grid = z_up.view(-1, 64, self.s // 8, self.s // 8)  # (bs, 64, s//8, s//8)
        combined = torch.cat([z_grid, pooled_emb], dim=1)     # (bs, 128, s//8, s//8)
        x = self.shared_upconv(combined)                      # (bs, 16, s, s)
        # ReLU on the mean: predicted distances are non-negative.
        mu_pred = F.relu(self.mu_predict_conv(x))
        # Clamp log-variance for numerical stability of the Gaussian NLL.
        logvar_pred = self.logvar_predict_conv(x).clamp(min=-3, max=3)
        return mu_pred.squeeze(1), logvar_pred.squeeze(1)

    def forward(self, distance_matrix, seq_embedding):
        """Run the full CVAE.

        Args:
            distance_matrix: (bs, s, s) pairwise distances.
            seq_embedding: (bs, s, s, 1280) per-pair sequence embedding.

        Returns:
            (mu_pred, logvar_pred, mu, logvar) — decoder output Gaussian
            maps of shape (bs, s, s) and latent Gaussian parameters of
            shape (bs, latent_dim).
        """
        # BUG FIX: removed dead einsum projections through self.seq_w /
        # self.delta_w whose results were immediately overwritten by the
        # conv projections below — pure wasted compute, identical outputs.
        seq_embedding_transformed = seq_embedding.permute(0, 3, 1, 2)  # (bs, 1280, s, s)
        seq_embedding_transformed = self.seq_emb_transform(seq_embedding_transformed)  # (bs, 64, s, s)
        distance_matrix_transformed = self.delta_transform(distance_matrix.unsqueeze(1))  # (bs, 64, s, s)
        combined_input = torch.cat([distance_matrix_transformed, seq_embedding_transformed], dim=1)  # (bs, 128, s, s)
        mu, logvar = self.encode(combined_input)
        z = self.reparameterize(mu, logvar)
        mu_pred, logvar_pred = self.decode(z, seq_embedding_transformed)
        return mu_pred, logvar_pred, mu, logvar
def vae_loss(x, mask, mu_pred, logvar_pred, mu, logvar, beta=0.1):
    """CVAE training loss: Gaussian NLL reconstruction + beta-weighted KL.

    Args:
        x: target tensor (same shape as ``mu_pred``).
        mask: accepted for API compatibility but currently UNUSED.
        mu_pred, logvar_pred: decoder's per-entry Gaussian parameters.
        mu, logvar: latent posterior parameters q(z|x).
        beta: weight on the KL term (beta-VAE style).

    Returns:
        Scalar tensor: nll_loss + beta * kl_loss.
    """
    # BUG FIX: the original line had an unbalanced parenthesis
    # (`torch.sum(` was never closed) — a SyntaxError.
    # Negative log-likelihood of x under N(mu_pred, exp(logvar_pred)),
    # dropping the constant 0.5*log(2*pi) term.
    nll_loss = 0.5 * torch.sum(logvar_pred + (x - mu_pred) ** 2 / logvar_pred.exp())
    # NOTE(review): `mask` is never applied — presumably intended to zero
    # out padded entries of the NLL; confirm and apply before training.
    # Closed-form KL( q(z|x) || N(0, I) ).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return nll_loss + beta * kl_loss