In the part - 2 lectures Jeremy said that initialization has a great impact in a model’s training. But the gain ‘a’ is a parameter that can be changed…I guess…other than sqrt(5).So i thought of training a model which finds the best ‘a’ for a tensor shape…but the backward computation graph is broken any help is appreciated…here’s my code.

My Training Set:
init_train = []
    for i in range(6, 513):
      for j in range(i+2, i+515):
        for k in range(3, 8):
          assert i != j
          lst = [i, j, k]

My Labels Set:
init_lbl = torch.tensor([[0, 1]]*init_train.shape[0])

My Model:

class Gain(nn.Module):
  def forward(self, a):
    return math.sqrt((1 + a**2)/2)

class Kaiming_init(nn.Module):
  def __init__(self, b = math.sqrt(3.), fan_out = False):
    self.gain = Gain()
    self.b = b
    self.fan_out = fan_out
  def forward(self, x, a):
    nf, ni, *_ = x.shape
    rec_fs = x[0, 0].numel()
    inp = nf*rec_fs if self.fan_out else ni*rec_fs
    inp = self.gain(a)/math.sqrt(inp)
    limit = self.b*inp, float(limit))
    return x

class Stats(nn.Module):
  def forward(self, x):
    return torch.tensor([x.mean(), x.std()])

class Init_Inference_Layer(nn.Module):
      def __init__(self, initialiser, stat_inf):
        self.initialiser = initialiser
        self.stat_inf = stat_inf
      def forward(self, x, a):
        data = [self.initialiser(nn.Conv2d(int(nf), int(ni), int(k)).weight, a) for (nf, ni, k), a in zip(x, a)]
        stat = [self.stat_inf(data) for data in data] 
        return torch.tensor(stat, requires_grad = True) 

class Init_Model(nn.Module):
  def __init__(self):
    self.lin = nn.Linear(3, 1)
    self.relu = nn.ReLU()
    self.initialiser = Kaiming_init()
    self.stat_inf = Stats()
    self.inf_layer = Init_Inference_Layer(self.initialiser, self.stat_inf)

  def __call__(self, x):
    self.inp = x
    a = self.relu(self.lin(x))
    outs = self.inf_layer(self.inp, a)
    return outs`

model = Init_Model()
preds = model(init_train[:1000])
loss = mae(preds, init_lbl[:1000])

for p in model.parameters():
    if hasattr(p, 'grad'):
  except Exception:

My Output:(showing None values for grads for my parameters)
    Parameter containing:
    tensor([[ 0.2282, -0.4730,  0.0291]], requires_grad=True)
    Parameter containing:
    tensor([0.4678], requires_grad=True)

My backprop is broken please help.Any help is appreciated.thanks