Conditional GANs combine a normal GAN, which maps a noise sample to a generated image, with labels that are fed to both the generator and the discriminator. Adding labels is often mentioned as one of the best ways to improve the quality of a GAN (@jeremy mentioned a while back that implementing a conditional WGAN would be interesting, which is ultimately where I’m going with this…)
You can read a bit more about Conditional GANs (often called “CGAN”, not to be confused with CycleGAN) in this medium post, and here is a nice, succinct implementation in PyTorch.
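For readers new to the idea, here is a minimal sketch of the conditioning trick. It assumes MNIST-like 28x28 grayscale images, 10 classes and small fully connected networks. All of these sizes are hypothetical, and this is not the implementation linked above. Each label goes through a learned embedding and is concatenated both to the generator’s noise and to the discriminator’s flattened image, so both networks see it.

```python
import torch
import torch.nn as nn

# Hypothetical MNIST-like sizes, chosen only for illustration.
N_CLASSES, Z_DIM, IMG_SHAPE = 10, 100, (1, 28, 28)
IMG_SIZE = IMG_SHAPE[0] * IMG_SHAPE[1] * IMG_SHAPE[2]

class Generator(nn.Module):
    "Maps (noise, label) -> image; the embedded label is concatenated to the noise."
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(N_CLASSES, N_CLASSES)
        self.net = nn.Sequential(
            nn.Linear(Z_DIM + N_CLASSES, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, IMG_SIZE), nn.Tanh())

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(x).view(-1, *IMG_SHAPE)

class Discriminator(nn.Module):
    "Maps (image, label) -> real/fake logit; the critic also sees the label."
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(N_CLASSES, N_CLASSES)
        self.net = nn.Sequential(
            nn.Linear(IMG_SIZE + N_CLASSES, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1))

    def forward(self, img, labels):
        flat = img.view(img.size(0), -1)
        return self.net(torch.cat([flat, self.label_emb(labels)], dim=1))

z = torch.randn(4, Z_DIM)
labels = torch.randint(0, N_CLASSES, (4,))
fake = Generator()(z, labels)          # shape (4, 1, 28, 28)
score = Discriminator()(fake, labels)  # shape (4, 1)
```

An `nn.Embedding` is used here for brevity. Concatenating a plain one-hot vector instead is the other common choice.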
I’m currently building a custom “LabeledGANItemList”, based on GANItemList, that would support (multi-)labels, but I’m stuck on how to do this and could use some help.
The way I see it, there are two approaches you could take:
1. Extend the noise (NoisyItem) data tensors (and the image data tensors) with a one-hot dimension for the labels.
2. Label the NoisyItem and ImageItem intelligently, in a way that plays well with the rest of the DataLoader API.
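The first approach (extending the tensors with a one-hot dimension) can be sketched in plain PyTorch. This sketch assumes flat noise vectors of shape (bs, nz) and image batches of shape (bs, c, h, w). The helper names are hypothetical and not part of fastai. The noise gets the one-hot vector appended directly. The image gets one constant feature map per class, and only the map for its own label is filled with ones.

```python
import torch
import torch.nn.functional as F

def noise_with_onehot(z, labels, n_classes):
    "Hypothetical helper: (bs, nz) -> (bs, nz + n_classes) by appending one-hot labels."
    return torch.cat([z, F.one_hot(labels, n_classes).float()], dim=1)

def image_with_onehot(img, labels, n_classes):
    "Hypothetical helper: (bs, c, h, w) -> (bs, c + n_classes, h, w), one constant map per class."
    bs, _, h, w = img.shape
    onehot = F.one_hot(labels, n_classes).float()                 # (bs, n_classes)
    maps = onehot.view(bs, n_classes, 1, 1).expand(bs, n_classes, h, w)
    return torch.cat([img, maps], dim=1)

z = torch.randn(2, 100)
imgs = torch.randn(2, 3, 64, 64)
labels = torch.tensor([0, 2])
zc = noise_with_onehot(z, labels, n_classes=3)     # shape (2, 103)
xc = image_with_onehot(imgs, labels, n_classes=3)  # shape (2, 6, 64, 64)
```

If the noise is shaped (bs, nz, 1, 1) for a convolutional generator, the one-hot tensor would need the same trailing singleton dimensions before concatenating.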
Doing #2 is proving difficult for me to wrap my head around. The DataLoader/ItemList assumes you have inputs (x) and targets (y), which correspond to the image (x) and the labels (y). In this case, we want something more like this:
Any thoughts on how to achieve this elegantly, in a way that plays nicely with the rest of the DataLoader API? Here’s a gist of my work to date, which is based on ImageBBox / GANItemList. I’m running into issues with how to store and label the ImageItems, since they are handled implicitly in GANItemList’s superclass methods.
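Outside of fastai’s ItemList machinery, one way to picture the data layout that approach #2 is after is a plain PyTorch Dataset in which each item is ((noise, label), image). The noise plus label form the generator’s input x, and a real image carrying that label is the target y. The class name `LabeledNoiseDataset` and the sizes below are hypothetical. This is a sketch, not a fastai API.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LabeledNoiseDataset(Dataset):
    """Hypothetical sketch: each item is ((noise, label), image).
    x = fresh noise plus the label; y = a real image with that same label."""
    def __init__(self, images, labels, noise_sz=100):
        assert len(images) == len(labels), "every image needs exactly one label"
        self.images, self.labels, self.noise_sz = images, labels, noise_sz

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        noise = torch.randn(self.noise_sz)  # new noise on every access
        return (noise, self.labels[i]), self.images[i]

images = torch.randn(8, 3, 32, 32)  # stand-in for real image tensors
labels = torch.randint(0, 5, (8,))
dl = DataLoader(LabeledNoiseDataset(images, labels), batch_size=4)
(noise, y_label), real = next(iter(dl))  # (4, 100), (4,), (4, 3, 32, 32)
```

Because the default collate function recurses into nested tuples, the (noise, label) pair comes out of the DataLoader already batched.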
Hey @sgugger, do you have any tips or recommendations for building a custom ItemList for a Conditional GAN? (See full details above.) Happy to submit a PR to the fast.ai repo with the class once I’ve got it working! Thanks in advance!
It’s a bit hard to understand from your gist: CGanItemList isn’t defined. In any case, you have to use no_split (or any split you like) after creating your ItemList, then label it with a class like ImageLabeledList (to go with your ImageLabeled item).
@sgugger thanks for taking the time to look over it! Sorry about the misnamed ref in the gist…
I’ve got this working thanks to your helpful suggestions. I’m now running into this warning once I create the DataBunch and call normalize(): fastai/basic_data.py:260: UserWarning: It's not possible to collate samples of your dataset together in a batch.
Is that necessarily a bad thing, given that we’re not trying to match our Xs to our Ys in this case?
I’ve updated the gist with the most recent code (and cleaned up the naming some more).
I figured out the cause of the warning (fastai/basic_data.py:260: UserWarning: It's not possible to collate samples of your dataset together in a batch.): the labels were mismatched.
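To see the underlying failure in isolation, you can call PyTorch’s `default_collate` directly. As far as I know, fastai v1’s own collate function delegates to it. This sketch assumes that “mismatched” means the samples’ label tensors don’t share a shape. The samples are made up for illustration.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two samples whose labels share a shape (one-hot vectors of length 3) collate fine.
good = [(torch.randn(3, 8, 8), torch.tensor([1, 0, 0])),
        (torch.randn(3, 8, 8), torch.tensor([0, 1, 0]))]
xb, yb = default_collate(good)  # xb: (2, 3, 8, 8), yb: (2, 3)

# Hypothetical mismatch: the second label is a bare class index instead of a
# one-hot vector, so the labels can't be stacked into a single tensor.
bad = [(torch.randn(3, 8, 8), torch.tensor([1, 0, 0])),
       (torch.randn(3, 8, 8), torch.tensor(2))]
collate_failed = False
try:
    default_collate(bad)
except RuntimeError as err:  # torch.stack rejects tensors of different sizes
    collate_failed = True
    print("cannot collate:", err)
```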
@sgugger I discovered two methods, dls() and batch_stats(), in the DataBunch/ImageDataBunch classes that break if you set databunch.valid_ds = None.
I’ve coded up fixes for both of them. Would you like me to submit a PR for this to the fastai repo? (I can imagine there are good reasons not to mess with dls()…) Here’s a quick overview of what the changed methods would look like:
```python
# Assumes fastai v1 (e.g. `from fastai.vision import *`), which provides torch,
# ifnone, channel_view, DatasetType and the typing aliases used below.

# In DataBunch...
# Only add valid_dl to the list of DataLoaders if it exists, so that other methods never
# iterate over a None value where actual DataLoaders are expected.
# The order (train, valid, fix, single, test) is preserved in either case.
@property  # dls is accessed as a property on fastai v1's DataBunch
def dls(self):
    res = [self.train_dl, self.fix_dl, self.single_dl]
    if self.valid_dl is not None: res.insert(1, self.valid_dl)
    return res if self.test_dl is None else res + [self.test_dl]

# In ImageDataBunch...
# If no validation DataLoader exists, fall back to the training DataLoader.
def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
    "Grab a batch of the target (y) data and call reduction function `func` per channel"
    funcs = ifnone(funcs, [torch.mean, torch.std])
    ds_type = DatasetType.Valid if self.valid_dl is not None else DatasetType.Train
    # one_batch returns (x, y): [1] selects y, then [0] takes its first element
    x = self.one_batch(ds_type=ds_type, denorm=False)[1][0].cpu()
    return [func(channel_view(x), 1) for func in funcs]
```
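To check what the `batch_stats` fallback computes without building a DataBunch, here is a standalone sketch in plain PyTorch. `per_channel_view` mirrors my understanding of what fastai’s `channel_view` does: it moves the channel axis first and flattens everything else. `stats_with_fallback` is a hypothetical stand-in for the patched method.

```python
import torch

def per_channel_view(x):
    "Reshape (bs, c, h, w) to (c, bs*h*w) so that each row holds every value of one channel."
    return x.transpose(0, 1).contiguous().view(x.shape[1], -1)

def stats_with_fallback(valid_x, train_x, funcs=(torch.mean, torch.std)):
    "Hypothetical helper: use the validation batch if there is one, else the training batch."
    x = valid_x if valid_x is not None else train_x
    return [f(per_channel_view(x), dim=1) for f in funcs]

# A one-image training batch whose three channels are constant 0, 1 and 4.
train_x = torch.stack([torch.zeros(2, 2), torch.ones(2, 2), torch.full((2, 2), 4.0)]).unsqueeze(0)
mean, std = stats_with_fallback(None, train_x)  # no validation batch, so train_x is used
```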
I’m working on a similar project, where I need to condition the GAN on a dense, real-valued vector instead of a one-hot encoding (OHE). Were you able to implement the ItemList for the Conditional GAN you were working on? Is there a repo where I could check out your implementation? It would be very helpful. Thanks!
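A common way to condition on a dense, real-valued vector is to pass it through a small learned projection and concatenate the result with the noise, in the same place a one-hot label would go. The sketch below follows that approach. It is not code from this thread, and the class name and sizes are hypothetical.

```python
import torch
import torch.nn as nn

class DenseCondGenerator(nn.Module):
    """Hypothetical sketch: condition on a real-valued vector c instead of a one-hot label.
    c is projected by a small linear layer and concatenated with the noise z."""
    def __init__(self, z_dim=100, c_dim=16, c_emb=32, out_dim=784):
        super().__init__()
        self.cond = nn.Sequential(nn.Linear(c_dim, c_emb), nn.ReLU())
        self.net = nn.Sequential(
            nn.Linear(z_dim + c_emb, 256), nn.ReLU(),
            nn.Linear(256, out_dim), nn.Tanh())

    def forward(self, z, c):
        return self.net(torch.cat([z, self.cond(c)], dim=1))

g = DenseCondGenerator()
z, c = torch.randn(4, 100), torch.rand(4, 16)
out = g(z, c)  # shape (4, 784), values in [-1, 1] because of the Tanh
```

The discriminator can receive the same vector in the same way, by concatenating it with the flattened image features.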