Conditional GAN Question

I’m trying to implement a conditional GAN with the following architecture:

  • The generator is an autoencoder that takes in a modified version of the image and a class label, and tries to reconstruct the original image so that it contains an object of that class. Specifically, the decoder takes in the class label along with the encoding of the modified image.
  • The discriminator is like a standard discriminator, but also takes in the class label, and estimates the probability that the image contains an object of that class.

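In other words, the intended data flow looks like this (a rough sketch; ae and disc are the Autoencoder and Discriminator instances defined further down):

fake, label = ae(patched_img, label)   # generator: reconstruct the image, conditioned on the label
p_real = disc(real_img, label)         # discriminator: P(image contains an object of that class)
p_fake = disc(fake, label)
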
I constructed a DataBunch using the create factory method on a dataset class I wrote. Its inputs are a tuple (the modified image and the class label), and its target is the original image. However, when I construct the GANLearner using GANLearner.from_learners, I get an error because the discriminator only receives the first input (either the reconstructed or the original image) and not the class label. How can I fix this?

# imports for all the snippets below
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision import transforms
from fastai.vision import *          # fastai v1: DataBunch, Learner, etc.
from fastai.vision import gan

class TrigramPatchDataset(data.Dataset):
    def __init__(self, x_patched, x, x_letters, seed):
        self.x_tfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.x, self.x_patched, self.x_letters, self.seed = x, x_patched, x_letters, seed

    def _onehot(self, letter):
        # one-hot encode the letter index; note this returns shape (1, num_classes)
        idx = torch.tensor([letter])
        return torch.zeros(len(idx), len(self.seed)).scatter_(1, idx.unsqueeze(1), 1.)

    def __getitem__(self, index):
        # input: (patched image, one-hot label) tuple; target: the original image
        return (self.x_tfm(self.x_patched[index]).float(), self._onehot(self.x_letters[index])), self.x_tfm(self.x[index]).float()

    def __len__(self):
        return len(self.x)
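
One thing to note for later: _onehot returns a (1, num_classes) tensor rather than a flat vector, which is why the discriminator below calls x_c.squeeze(1). A quick sanity check, assuming for illustration that seed has 5 entries:

ds = TrigramPatchDataset(tr_ims_patched, tr_ims, tr_letter_condition, seed)
ds._onehot(2)   # tensor([[0., 0., 1., 0., 0.]]), shape (1, 5)

After the default collate, a batch of labels therefore has shape (bs, 1, num_classes).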


def ae_loss(output, target):
    # output is assumed to be (reconstruction, D(real), D(fake))
    _x, dreal, dfake = output
    # reconstruction loss plus the non-saturating generator loss, -log D(fake)
    return F.mse_loss(_x, target) - torch.log(dfake).mean()


def d_loss(output, target):
    _x, dreal, dfake = output
    # standard discriminator loss: -log D(real) - log(1 - D(fake))
    return -torch.log(dreal).mean() - torch.log(1 - dfake).mean()
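
As an aside (not the cause of the error below): taking the log of raw discriminator outputs produces inf/nan as soon as D saturates at exactly 0 or 1. If that bites, a clamped variant is a cheap fix; EPS here is an arbitrary floor I picked for illustration:

EPS = 1e-7

def d_loss_stable(output, target):
    _x, dreal, dfake = output
    # clamp before log so the loss stays finite when D saturates
    return -torch.log(dreal.clamp_min(EPS)).mean() - torch.log((1 - dfake).clamp_min(EPS)).mean()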

tr_ds, te_ds = TrigramPatchDataset(tr_ims_patched, tr_ims, tr_letter_condition, seed), \
               TrigramPatchDataset(te_ims_patched, te_ims, te_letter_condition, seed)
db = DataBunch.create(tr_ds, te_ds, bs=16)
ae = Autoencoder(3, len(seed))
disc = Discriminator(3, len(seed))

learn_gen = Learner(db, ae, loss_func=ae_loss)
learn_crit = Learner(db, disc, loss_func=d_loss)
learner = gan.GANLearner.from_learners(learn_gen=learn_gen, learn_crit=learn_crit)
learner.lr_find()

Running lr_find gives me the error forward() missing 1 required positional argument.

Here are the classes I’m using for the Autoencoder and Discriminator.

class ConvLayer(nn.Module):
    def __init__(self, num_inputs, num_filters, kernel_size=3, stride=1, padding=None, transpose=False, dilation=1):
        super(ConvLayer, self).__init__()
        # default to "same"-style padding for regular convs; transposed convs get explicit padding
        # (the original condition was "transpose is not None", which is always true for a bool)
        if padding is None: padding = (kernel_size-1)//2 if not transpose else 0
        if transpose:
            self.layer = nn.ConvTranspose2d(num_inputs, num_filters, kernel_size=kernel_size,
                                            stride=stride, padding=padding, dilation=dilation)
        else:
            self.layer = nn.Conv2d(num_inputs, num_filters, kernel_size=kernel_size,
                                   stride=stride, padding=padding)
        nn.init.kaiming_uniform_(self.layer.weight, a=np.sqrt(5))
        self.bn_layer = nn.BatchNorm2d(num_filters)

    def forward(self, x):
        out = self.layer(x)
        out = F.relu(out)
        return self.bn_layer(out)
      
class Encoder(nn.Module):
    def __init__(self, num_inputs):
        super(Encoder, self).__init__()
        self.conv1 = ConvLayer(num_inputs, num_inputs*2, stride=2)
        self.conv2 = ConvLayer(num_inputs*2, num_inputs*4, stride=2)
        self.conv3 = ConvLayer(num_inputs*4, num_inputs*8, stride=2)
        # on a 4-D activation, AdaptiveAvgPool3d treats the leading dim as channels,
        # so each sample's (C, H, W) block is pooled down to (24, 4, 8)
        self.pool = nn.AdaptiveAvgPool3d((24, 4, 8))

    def forward(self, x, x_c):
        # the class label x_c is ignored here; only the decoder consumes it
        out = self.conv3(self.conv2(self.conv1(x)))
        return self.pool(out)
      

class Decoder(nn.Module):
    def __init__(self, num_inputs, num_conditioning_inputs):
        super(Decoder, self).__init__()
        self.num_inputs = num_inputs
        # fuse the flattened encoding with the one-hot label, then project back to (num_inputs, 4, 8)
        self.fc = nn.Linear((4*8*num_inputs) + num_conditioning_inputs, 4*8*num_inputs)
        self.conv1 = ConvLayer(num_inputs, num_inputs, transpose=True, stride=2, padding=1, kernel_size=4)
        self.conv2 = ConvLayer(num_inputs, num_inputs//2, transpose=True, stride=2, padding=1, kernel_size=4)
        self.conv3 = ConvLayer(num_inputs//2, num_inputs//4, transpose=True, stride=2, padding=1, kernel_size=4)
        self.conv4 = ConvLayer(num_inputs//4, num_inputs//8, transpose=True, stride=2, padding=1, kernel_size=4)

    def forward(self, x, x_c):
        out = x.view(x.size(0), 1, -1)                      # (N, 1, 4*8*num_inputs)
        out = torch.cat([out, x_c], dim=-1)                 # append the (N, 1, num_classes) label
        out = self.fc(out)
        out = out.view(out.size(0), self.num_inputs, 4, 8)
        # four stride-2 transposed convs (k=4, p=1) each exactly double the spatial size
        return torch.tanh(self.conv4(self.conv3(self.conv2(self.conv1(out)))))

    
class Autoencoder(nn.Module):
    def __init__(self, num_inputs, num_conditioning_inputs):
        super(Autoencoder, self).__init__()
        self.enc = Encoder(num_inputs)
        self.dec = Decoder(num_inputs*8, num_conditioning_inputs)

    def forward(self, x, x_c):
        # return the label alongside the reconstruction so the discriminator can condition on it
        return self.dec(self.enc(x, x_c), x_c), x_c
    

class Discriminator(nn.Module):
    def __init__(self, num_inputs, num_conditioning_inputs):
        super(Discriminator, self).__init__()
        self.conv1 = ConvLayer(num_inputs, num_inputs*2, stride=2)
        self.conv2 = ConvLayer(num_inputs*2, num_inputs*4, stride=2)
        self.conv3 = ConvLayer(num_inputs*4, num_inputs*8, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((num_inputs, num_inputs))
        self.fc = nn.Linear(num_inputs*num_inputs*num_inputs*8, num_conditioning_inputs)

    def forward(self, x, x_c):
        out = self.conv3(self.conv2(self.conv1(x)))
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        # per-class scores; the dot product with the one-hot label picks out
        # the probability mass assigned to the conditioning class
        out = torch.softmax(self.fc(out), dim=-1)
        out = (out * x_c.squeeze(1)).sum(dim=-1)
        return out
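
The two-argument forward is exactly what the error message points at: call the discriminator with an image alone and Python complains about the missing x_c. A minimal repro with made-up shapes (3-channel 64x128 images, 5 classes):

disc = Discriminator(3, 5)
img, label = torch.randn(2, 3, 64, 128), torch.zeros(2, 1, 5)
label[:, 0, 0] = 1.
disc(img, label)   # fine: returns a (2,) tensor, the softmax mass on the conditioning class
disc(img)          # TypeError: forward() missing 1 required positional argument: 'x_c'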

After doing some digging, I found that the callback handler processes the input on batch begin, with the effect that:

  • Before processing, the input has both the modified/real image AND the class label
  • After processing, the input has only the modified/real image

How can I get around this?

The only callbacks I’ve found are:

['Recorder', 'FixedGANSwitcher', 'LRFinder']

but none of these defines its own on_batch_begin method. How can I trace where this modification is happening?
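
One way to trace it, since fastai v1's base Callback class defines every event as a no-op: compare each callback's on_batch_begin against the base class to see who actually overrides it (a sketch; check against your installed version):

from fastai.callback import Callback

for cb in learner.callbacks:
    if type(cb).on_batch_begin is not Callback.on_batch_begin:
        print(type(cb).__name__)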

Ah, it looks like I overlooked the GANTrainer, which is what swaps the input and target. That makes sense.
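
For anyone who lands here: in the fastai v1 source, GANTrainer.on_batch_begin returns the batch as-is in generator mode and swapped in critic mode, roughly like this (paraphrased from memory, so check your installed version):

# inside GANTrainer (fastai v1, paraphrased)
def on_batch_begin(self, last_input, last_target, **kwargs):
    if self.gen_mode:
        return {'last_input': last_input,  'last_target': last_target}
    return {'last_input': last_target, 'last_target': last_input}

So in critic mode the critic only ever receives last_target (the real image tensor), and the class label that travels inside the last_input tuple never reaches it.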