Middle layers of model having zero gradients

Hi guys,

I’m running into something interesting. I have the following network:

import torch.nn as nn

def basic_tabular_critic(in_size: int, **tab_kwargs):
    # Stack of linear layers; the final layer outputs a single critic score.
    layers = [nn.Linear(in_size, 100, bias=False),
              nn.Linear(100, 100),
              nn.Linear(100, 1)]
    return nn.Sequential(*layers)

which is in essence a binary classifier trained with the Wasserstein loss. The strange thing is that the middle linear layer has almost no gradient.
In the image below, the top plot is the mean of the gradients per layer and the bottom plot is the std of the gradients per layer. Following the network definition, the linear layers are layers 0, 3 and 5.
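For reference, a minimal sketch of how per-layer gradient statistics like these could be collected after a backward pass. The helper name `grad_stats_per_layer`, the input size of 8 and the dummy loss are assumptions for illustration only, not the code behind the plots.

```python
import torch
import torch.nn as nn

def grad_stats_per_layer(model: nn.Sequential):
    """Return {layer_index: (grad_mean, grad_std)} for layers that have a weight gradient."""
    stats = {}
    for idx, layer in enumerate(model):
        weight = getattr(layer, "weight", None)
        if weight is None or weight.grad is None:
            continue  # skip layers without weights (e.g. activations) or without gradients
        g = weight.grad
        stats[idx] = (g.mean().item(), g.std().item())
    return stats

# Hypothetical usage: input size 8 and a dummy loss, purely for illustration.
torch.manual_seed(0)
critic = nn.Sequential(
    nn.Linear(8, 100, bias=False),
    nn.Linear(100, 100),
    nn.Linear(100, 1),
)
critic(torch.randn(16, 8)).mean().backward()
stats = grad_stats_per_layer(critic)
for idx, (mean, std) in stats.items():
    print(f"layer {idx}: grad mean={mean:.3e}, grad std={std:.3e}")
```

Calling this once per training step and logging the result would give one curve per layer for the mean and one for the std.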

The network does seem to learn (or at least reduce loss), so I’m wondering what’s happening with the middle layer. Thoughts?

I just read somewhere that you should not use batch normalization in the critic when using a gradient penalty as in WGAN-GP (which I am using). That might be part of the problem; I will report back.
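To make the interaction concrete: the WGAN-GP penalty is computed on the gradient norm of the critic with respect to each input sample separately, whereas batch normalization makes each output depend on the whole batch. Below is a rough sketch of the penalty term, assuming a critic without batch normalization; the function name `gradient_penalty`, the input size and the random data are made up for illustration.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP style penalty: mean of (||grad_x critic(x)||_2 - 1)^2 over interpolated samples."""
    real, fake = real.detach(), fake.detach()
    # One random interpolation coefficient per sample.
    eps = torch.rand(real.size(0), 1)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    # create_graph=True so the penalty itself can be backpropagated to the critic weights.
    (grads,) = torch.autograd.grad(outputs=scores.sum(), inputs=interp, create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Hypothetical usage with random tabular data of width 8.
torch.manual_seed(0)
critic = nn.Sequential(
    nn.Linear(8, 100, bias=False),
    nn.Linear(100, 100),
    nn.Linear(100, 1),
)
gp = gradient_penalty(critic, torch.randn(16, 8), torch.randn(16, 8))
gp.backward()
```

In a training loop this term would typically be scaled by a penalty weight and added to the critic loss. If normalization is still wanted in the critic, layer normalization is the usual suggestion because it does not mix samples within a batch.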

It seems removing the batch normalization has indeed fixed part of the problem, as can be seen below.
The top two images are the gradient mean and std, and the bottom two are the weight mean and std.