Hi all,
I’m trying to train a GAN to generate synthetic images. This is my first foray into GANs and I may have taken Jeremy’s “just try some code” advice too far, but I can’t work out the issue here.
Here’s my code:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import requests
import json

from fastbook import *
from fastai.vision import *
from fastai.vision.gan import *
from google.colab import drive

drive.mount('/content/gdrive', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'Colab/GAN/'
# I've tried making the batch size smaller and smaller; no effect
bs = 2
size = 32

dblock = DataBlock(
    blocks=(TransformBlock, ImageBlock),
    get_x=generate_noise,
    get_items=get_image_files,
    splitter=IndexSplitter([]),  # empty index list, so nothing goes to validation
    item_tfms=Resize(size, method=ResizeMethod.Crop),
    batch_tfms=[*aug_transforms(),
                Normalize.from_stats(torch.tensor([0.5, 0.5, 0.5]),
                                     torch.tensor([0.5, 0.5, 0.5]))])

dls = dblock.dataloaders(base_dir + 'Images/', bs=bs)

generator = basic_generator(out_size=size, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=size, n_channels=3, n_extra_layers=1)

learn = GANLearner.wgan(dls, generator, critic, switch_eval=False,
                        opt_func=Adam, wd=0.)
learn.fit(3, 2e-2)
I have 911 images. The fit runs to completion, but with the following warnings:
/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py:51: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py:51: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py:51: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
epoch  train_loss  valid_loss  gen_loss  crit_loss  time
0      -0.462891   None        None      None       05:35
1      -0.503794   None        None      None       05:31
2      -0.275025   None        None      None       05:35
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:2800: DecompressionBombWarning: Image size (99996755 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack. (this warning appears three times)
/usr/local/lib/python3.7/dist-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
Can anyone spot the problem? If I try to view results I get an error as well:
learn.gan_trainer.switch(gen_mode=True)
learn.show_results(rows=16, figsize=(8,8))
/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py:51: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
ValueError                                Traceback (most recent call last)
in ()
      1 learn.gan_trainer.switch(gen_mode=True)
----> 2 learn.show_results(rows=16, figsize=(8,8))

1 frames
/usr/local/lib/python3.7/dist-packages/fastai/data/load.py in one_batch(self)
    145     def to(self, device): self.device = device
    146     def one_batch(self):
--> 147         if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
    148         with self.fake_l.no_multiproc(): res = first(self)
    149         if hasattr(self, 'it'): delattr(self, 'it')

ValueError: This DataLoader does not contain any batches