Critic's loss not going down in lesson 7

I’ve replicated the steps Jeremy took in https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb

However, I took a different approach. I’m using only 20 images in total, from 2 dog breeds (10 each). I set the validation split to 30 percent of the ‘image_gen’ folder, and I increased the lr to 1e-2 once I got to pretraining the critic.
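To put the size of this setup in numbers, here is a small sketch of the split arithmetic. It assumes, hypothetically, about 40 images in the critic’s folders (20 real plus 20 generated; the post doesn’t give the exact count) and that the validation size is rounded down as `int(valid_pct * n)`. `split_sizes` is a hypothetical helper, not a fastai function.

```python
def split_sizes(n_items: int, valid_pct: float, bs: int) -> tuple:
    """Return (n_train, n_valid, full_train_batches) for a percentage split."""
    n_valid = int(valid_pct * n_items)  # assumption: validation size is rounded down
    n_train = n_items - n_valid
    full_batches = n_train // bs        # complete batches of size bs per epoch
    return n_train, n_valid, full_batches


# Hypothetical count: 20 real + 20 generated images, 30% validation, bs=4
print(split_sizes(40, 0.3, 4))  # (28, 12, 7)
```

With only 12 validation images, a single misclassified image moves validation accuracy by about 8 percentage points, so that metric will be noisy from epoch to epoch.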

This gave similar results, but the output below is what happened when I got to pretraining the critic. This might be more of a question for @jeremy, but why doesn’t the critic’s loss go down when the dataset is small? I would expect it to be easier to tell fake images from real ones, since there are fewer images to learn. It’s not as if the generator’s MSE loss was lower than in Jeremy’s notebook either, so its fakes shouldn’t be any more convincing. Shouldn’t the critic then have an easier time distinguishing real images from fakes?

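```python
from fastai.vision import *      # fastai v1, as used in course-v3
from fastai.vision.gan import *

# path, name_gen, size, bs and wd are defined earlier in the notebook.
def get_crit_data(classes, bs, size):
    # Hold out 30% of the images for validation
    src = ImageList.from_folder(path, include=classes).split_by_rand_pct(0.3, seed=42)
    # Label each image by its folder: generated vs. real
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=size)
            .databunch(bs=bs).normalize(imagenet_stats))
    data.c = 3
    return data

# Critic data: generated images (name_gen) vs. real images ('images')
data_crit = get_crit_data([name_gen, 'images'], bs=4, size=size)
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
plt.show()

# BCE on logits; AdaptiveLoss expands the target to the critic's output shape
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

def create_critic_learner(data, metrics):
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)

learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)

# Pretrain the critic: 25 epochs at lr 1e-2
learn_critic.fit_one_cycle(25, 1e-2)
learn_critic.save('critic-1')

# Reload the pretrained critic for the GAN stage
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-1')
```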

The output:

```
----------------------------------GAN--------------------------------------
epoch     train_loss  valid_loss  accuracy_thresh_expand  time
0         0.686626    0.702948    0.356364                00:10
1         0.691571    0.719672    0.367273                00:11
2         0.684327    0.766719    0.450909                00:11
3         0.694961    0.740039    0.374545                00:10
4         0.697699    0.698146    0.545455                00:10
5         0.692488    1.168247    0.545455                00:10
6         0.705878    0.690986    0.545455                00:10
7         0.705758    0.716483    0.483636                00:10
8         0.710638    0.705156    0.476364                00:10
9         0.710643    0.693130    0.443636                00:10
10        0.708336    0.695476    0.454545                00:10
11        0.706554    0.691266    0.545455                00:10
12        0.705258    0.690534    0.545455                00:10
13        0.704851    0.697592    0.454545                00:11
14        0.703753    0.696524    0.454545                00:11
15        0.702643    0.694397    0.454545                00:10
16        0.701649    0.693619    0.454545                00:10
17        0.700480    0.693081    0.450909                00:10
18        0.699535    0.693055    0.450909                00:11
19        0.698590    0.692627    0.538182                00:10
20        0.697668    0.692416    0.629091                00:10
21        0.697229    0.692438    0.600000                00:11
22        0.696679    0.692444    0.581818                00:10
23        0.696413    0.692449    0.581818                00:11
24        0.695998    0.692447    0.581818                00:10
```
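Two accuracy values keep recurring in that log, 0.545455 and 0.454545, which equal 6/11 and 5/11. One possible reading, not confirmed by the post: if the validation set held 11 images split 6/5 between the two classes, a critic that outputs the same class for every image would score exactly these values. `accuracy_thresh_expand` averages over every position of the critic’s output map, but for a constant prediction that average still equals the class fraction. The helper and labels below are hypothetical.

```python
def constant_prediction_accuracy(labels, predicted_label):
    """Accuracy of a classifier that outputs `predicted_label` for every input."""
    hits = sum(1 for y in labels if y == predicted_label)
    return hits / len(labels)


# Hypothetical validation labels: 1 = real ('images'), 0 = generated.
# The 6/5 split is an assumption, not taken from the post.
valid_labels = [1] * 6 + [0] * 5
print(round(constant_prediction_accuracy(valid_labels, 1), 6))  # 0.545455
print(round(constant_prediction_accuracy(valid_labels, 0), 6))  # 0.454545
```

If the logged accuracy sits on one of these values for several epochs, the critic is most likely not separating the classes at all.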

I have the same problem. The critic can’t learn to distinguish between 166 pairs of images, half of which are greyscale and half b/w. Its loss also stays at 0.69, and the lr plot is FLAT up to 1e-02.
I see no errors in my setup; I’ve even removed all transforms.
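The 0.69 plateau is a recognizable number: ln 2 ≈ 0.6931 is the binary cross-entropy of a prediction of exactly 0.5 (a logit of 0), whatever the target. A loss pinned there means the critic is outputting chance-level probabilities. The sketch below computes BCE on a raw logit in the usual numerically stable form; it is the same quantity `nn.BCEWithLogitsLoss` computes, but not PyTorch’s actual code, and `bce_with_logits` is a hypothetical name.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy for one raw logit x and target y in [0, 1].

    Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]:
    max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# A critic that outputs logit 0 (probability 0.5) gets ln 2 on every image.
for target in (0.0, 1.0):
    print(target, round(bce_with_logits(0.0, target), 6))  # 0.693147 for both
print(round(math.log(2), 6))                              # 0.693147

# A confident, correct prediction scores far lower.
print(round(bce_with_logits(3.0, 1.0), 4))                # 0.0486
```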

Hi @mlnoob, did you ever figure out what the issue was? I’m having the exact same problem now.

Oh, I have the same problem. Has it been solved?
I suspect it is caused by gan_critic(), but I don’t know how to fix it.

I tried ResNet34 instead. It learned to classify, but it doesn’t work in a GAN. Maybe that’s because it doesn’t use spectral normalization?
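For reference, spectral normalization divides a layer’s weight matrix W by an estimate of its largest singular value σ(W), usually found with a few steps of power iteration, which keeps the linear map roughly 1-Lipschitz. PyTorch provides this as `torch.nn.utils.spectral_norm`. The pure-Python sketch below only illustrates the computation on a plain matrix; the helper names are hypothetical, and it says nothing about how fastai wires spectral norm into `gan_critic()`.

```python
import math
import random


def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def transpose(W):
    return [list(col) for col in zip(*W)]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spectral_norm_estimate(W, n_iters=50, seed=0):
    """Estimate the largest singular value sigma(W) by power iteration."""
    rng = random.Random(seed)
    Wt = transpose(W)
    u = normalize([rng.gauss(0.0, 1.0) for _ in range(len(W))])  # random start
    for _ in range(max(n_iters, 1)):
        v = normalize(matvec(Wt, u))  # v <- W^T u / ||W^T u||
        u = normalize(matvec(W, v))   # u <- W v / ||W v||
    Wv = matvec(W, v)
    return sum(a * b for a, b in zip(u, Wv))  # sigma ~ u^T W v


def spectrally_normalize(W, n_iters=50):
    """Return W / sigma(W), whose largest singular value is approximately 1."""
    sigma = spectral_norm_estimate(W, n_iters)
    return [[w / sigma for w in row] for row in W]


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(round(spectral_norm_estimate(W), 4))                        # 9.5255
print(round(spectral_norm_estimate(spectrally_normalize(W)), 4))  # 1.0
```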

I also tried finding a good learning rate in the lesson 7 repo notebook, and the plot looks similar. Maybe this time we can’t find a good learning rate directly this way?
Still not resolved…
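On the flat LR-finder plot: a learning rate is usually read off the part of the curve where the loss falls fastest. If the loss barely moves across the whole sweep (for example, it stays near 0.69), there is no such region to pick from. A hypothetical helper, not fastai’s own suggestion logic, that makes this explicit:

```python
import math


def steepest_descent_lr(lrs, losses, flat_tol=0.01):
    """Hypothetical helper: return the learning rate at the start of the segment
    where the loss falls fastest per decade of lr, or None if the curve is flat
    (the loss varies by less than `flat_tol` over the whole sweep)."""
    if max(losses) - min(losses) < flat_tol:
        return None  # no usable signal, e.g. a loss pinned near ln 2
    log_lrs = [math.log10(lr) for lr in lrs]
    slopes = [(losses[i + 1] - losses[i]) / (log_lrs[i + 1] - log_lrs[i])
              for i in range(len(lrs) - 1)]
    steepest = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[steepest]


lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
print(steepest_descent_lr(lrs, [0.693, 0.693, 0.694, 0.693, 0.692]))  # None
print(steepest_descent_lr(lrs, [0.70, 0.69, 0.50, 0.45, 2.00]))       # 0.0001
```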