Lesson 5: implementing DIY nn.Linear

Hey guys!

While implementing my own nn.Linear class, as Jeremy asked in Lesson 5, along with the single-layer Mnist_Logistics model, I sometimes had total training failures: HUGE losses, and the loss jumping up and down like crazy, even after playing a bit with the learning rate.
I went to look at PyTorch’s nn.Linear source code (https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear), and my class was mostly the same except for a ‘reset_parameters(self)’ method and a call to it at the end of `__init__`.

So I copied it into my own NNLinear class, and now it performs as well as nn.Linear.
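For reference, here is a minimal sketch of what my class looks like now (the class name and forward pass are my own; the init logic mirrors what I copied from PyTorch):

```python
import math

import torch
from torch import nn
from torch.nn import init


class NNLinear(nn.Module):
    """DIY version of nn.Linear (sketch; names are my own)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()  # the call I was missing

    def reset_parameters(self):
        # Same scheme as PyTorch's nn.Linear source
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        return x @ self.weight.t() + self.bias
```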

[Batch / Loss plot]

So my question is: What is ‘reset_parameters’ doing?

import math
from torch.nn import init

def reset_parameters(self):
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

I can understand that it’s modifying the weight values with some kind of fancy uniform function.
But why does this make all the difference?
Maybe I was living in a dream world where parameters could be initialized to ANY value.
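A quick experiment I tried (my own sketch, not from the lesson) hints at why not: with weights drawn from a plain standard normal and no scaling, the activations blow up as you stack layers, while the Kaiming-style scaling from reset_parameters keeps them in a sane range:

```python
import math

import torch
from torch.nn import init

torch.manual_seed(0)
x = torch.randn(64, 784)

# Naive init: standard-normal weights, no scaling.
# Each matmul multiplies the std by roughly sqrt(784) ~ 28.
naive = x
for _ in range(5):
    w = torch.randn(784, 784)
    naive = naive @ w

# Kaiming-style init, the same call reset_parameters makes.
scaled = x
for _ in range(5):
    w = torch.empty(784, 784)
    init.kaiming_uniform_(w, a=math.sqrt(5))
    scaled = scaled @ w

# After only 5 layers the naive activations are astronomically
# larger than the Kaiming-initialized ones.
print(naive.std().item(), scaled.std().item())
```

With losses computed on activations that huge, it is no surprise the training jumps around or diverges completely.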