# Initialisation

In the fast.ai part 2 lectures, Jeremy said that initialization has a great impact on a model's training. But the gain `a` is a parameter that can be changed to values other than sqrt(5). So I thought of training a model which finds the best `a` for a given tensor shape — but the backward computation graph is broken. Any help is appreciated; here's my code.

My Training Set:

```python
init_train = []
for i in range(6, 513):
for j in range(i+2, i+515):
for k in range(3, 8):
assert i != j
lst = [i, j, k]
init_train.append(lst)

My Labels Set:
init_lbl = torch.tensor([[0, 1]]*init_train.shape[0])
```

My Model:

``````class Gain(nn.Module):
def forward(self, a):
return math.sqrt((1 + a**2)/2)

class Kaiming_init(nn.Module):
def __init__(self, b = math.sqrt(3.), fan_out = False):
super().__init__()
self.gain = Gain()
self.b = b
self.fan_out = fan_out
def forward(self, x, a):
nf, ni, *_ = x.shape
rec_fs = x[0, 0].numel()
inp = nf*rec_fs if self.fan_out else ni*rec_fs
inp = self.gain(a)/math.sqrt(inp)
limit = self.b*inp
x.data.uniform_(-float(limit), float(limit))
return x

class Stats(nn.Module):
def forward(self, x):

class Init_Inference_Layer(nn.Module):
def __init__(self, initialiser, stat_inf):
super().__init__()
self.initialiser = initialiser
self.stat_inf = stat_inf
def forward(self, x, a):
data = [self.initialiser(nn.Conv2d(int(nf), int(ni), int(k)).weight, a) for (nf, ni, k), a in zip(x, a)]
stat = [self.stat_inf(data) for data in data]

class Init_Model(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(3, 1)
self.relu = nn.ReLU()
self.initialiser = Kaiming_init()
self.stat_inf = Stats()
self.inf_layer = Init_Inference_Layer(self.initialiser, self.stat_inf)

def __call__(self, x):
self.inp = x
a = self.relu(self.lin(x))
outs = self.inf_layer(self.inp, a)
return outs`

model = Init_Model()
preds = model(init_train[:1000])
loss = mae(preds, init_lbl[:1000])
loss.backward()

for p in model.parameters():
print(p)
try:
except Exception:
continue

My Output (the gradients of my parameters show up as `None`):
Parameter containing: