Hi guys,
I’m running into something interesting. I have the following network:
import torch.nn as nn

def basic_tabular_critic(in_size: int, **tab_kwargs):
    layers = [
        nn.Linear(in_size, 100, bias=False),
        nn.BatchNorm1d(100),
        nn.LeakyReLU(0.2),
        nn.Linear(100, 100),
        nn.LeakyReLU(0.2),
        nn.Linear(100, 1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)
which is in essence a binary critic trained with the Wasserstein loss. The strange thing is that the middle linear layer has almost no gradients.
In the image below, the top plot is the mean of the gradients per layer and the bottom plot is the std of the gradients per layer. Matching the network definition, the linear layers are layers 0, 3, and 5.
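In case it helps, here is roughly how the per-layer gradient stats were collected (a minimal sketch with random data standing in for my real training loop, and a plain mean of the output standing in for the actual loss):

```python
# Sketch: compute per-parameter gradient mean/std after one backward pass.
# Random input and a dummy loss are used here purely for illustration.
import torch
import torch.nn as nn

def basic_tabular_critic(in_size: int, **tab_kwargs):
    layers = [
        nn.Linear(in_size, 100, bias=False),
        nn.BatchNorm1d(100),
        nn.LeakyReLU(0.2),
        nn.Linear(100, 100),
        nn.LeakyReLU(0.2),
        nn.Linear(100, 1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)

torch.manual_seed(0)
net = basic_tabular_critic(8)
x = torch.randn(32, 8)          # dummy batch
loss = net(x).mean()            # placeholder for the real critic loss
loss.backward()

# Gradient mean/std per parameter tensor; in nn.Sequential the names
# are prefixed by layer index, so the linear weights are 0/3/5.
stats = {
    name: (p.grad.mean().item(), p.grad.std().item())
    for name, p in net.named_parameters()
    if p.grad is not None
}
for name, (m, s) in stats.items():
    print(f"{name}: mean={m:.2e} std={s:.2e}")
```

The plots in the post are just these means and stds logged per training step for the three linear weight tensors.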
The network does seem to learn (or at least the loss goes down), so I’m wondering what’s happening with that middle layer. Thoughts?