I have been banging my head against a wall on this for a few days now. I finished chapter 5 and decided to go back and have a crack at the full MNIST from-scratch problem, implementing softmax and NLL manually to help my understanding. However, I've hit a few snags.
First of all, when I initialize the parameters with an std of 1.0 using this code:
def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()
w1 = init_params((28*28,30))
b1 = init_params(30)
w2 = init_params((30,10))
b2 = init_params(10)
params = w1,b1,w2,b2
and then run this:
def train_epoch(model, lr, params):
    for xb,yb in train_dl:
        calc_grad(xb,yb, model)
        for p in params:
            print(p)
        opt.step()
        opt.zero_grad()
The second parameter, which is b1, prints as:
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], requires_grad=True)
and I get:
<ipython-input-79-20ce620bb67a> in step(self, *args, **kwargs)
22 def step(self, *args, **kwargs):
23 print()
---> 24 for p in self.params: p.data -= p.grad.data * self.lr
25
26 def zero_grad(self, *args, **kwargs):
AttributeError: 'NoneType' object has no attribute 'data'
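The traceback says that at least one tensor in self.params has a .grad of None when step runs, which happens when backward() never reached that tensor. One way to narrow this down is to list the parameters without a gradient just before the update. This is only a minimal sketch: the helper find_missing_grads and the toy tensors used and unused are hypothetical names for illustration and are not part of the code in this post.

```python
import torch

def find_missing_grads(named_params):
    """Return the names of parameters whose .grad is still None."""
    return [name for name, p in named_params.items() if p.grad is None]

# Hypothetical toy setup: `unused` never takes part in the loss,
# so backward() leaves its .grad as None.
used = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)

loss = (used * 2).sum()
loss.backward()

missing = find_missing_grads({"used": used, "unused": unused})
print(missing)
```

If a check like this reports a name, the optimizer may be holding tensors that differ from the ones the model actually uses, for example because the parameters were re-created after the optimizer was built. That is a guess, not something the traceback alone confirms.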
If I change std to 0.5, I get slightly different behavior:
def init_params(size, std=.5): return (torch.randn(size)*std).requires_grad_()
w1 = init_params((28*28,30))
b1 = init_params(30)
w2 = init_params((30,10))
b2 = init_params(10)
params = w1,b1,w2,b2
def train_epoch(model, lr, params):
    for xb,yb in train_dl:
        calc_grad(xb,yb, model)
        for p in params:
            print(p)
        opt.step()
        opt.zero_grad()
When I call train_epoch with the lower std value, the b1 tensor appears OK:
tensor([-0.7091, -0.2209, -0.4982, -0.1451, -0.0528, -0.0409, -0.3094, 0.3262, 0.0605, 1.4114, 0.5296, -0.1523, -0.1857, -0.2357, -0.0586, -0.0773, -0.0062, -0.1788, 0.5182, -0.0185, -0.4281,
-0.1036, 0.2513, 0.3971, -0.2471, -0.0673, 0.4250, 0.0252, -0.9616, -0.1530], requires_grad=True)
but I still get the same error afterward:
AttributeError: 'NoneType' object has no attribute 'data'
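One possible reason the nan values appear with an std of 1.0 but not with 0.5 is overflow in a hand-written softmax: larger initial weights give larger activations, exp() of a large activation becomes inf, and inf/inf is nan. The sketch below is an assumption about the cause, not a confirmed diagnosis. The functions naive_log_softmax, stable_log_softmax and nll are illustrative stand-ins, not the implementations from this post.

```python
import torch

def naive_log_softmax(x):
    # exp() overflows to inf for large activations, and inf / inf gives nan
    e = x.exp()
    return (e / e.sum(dim=1, keepdim=True)).log()

def stable_log_softmax(x):
    # Subtracting the row max leaves softmax unchanged but keeps exp() <= 1
    shifted = x - x.max(dim=1, keepdim=True).values
    return shifted - shifted.exp().sum(dim=1, keepdim=True).log()

def nll(log_probs, targets):
    # Pick the log-probability of the correct class for each row, then average
    return -log_probs[torch.arange(len(targets)), targets].mean()

# Deliberately extreme activations in the first row to trigger the overflow
acts = torch.tensor([[1000.0, 0.0, -1000.0], [1.0, 2.0, 3.0]])
targets = torch.tensor([0, 2])

naive_loss = nll(naive_log_softmax(acts), targets)
stable_loss = nll(stable_log_softmax(acts), targets)
print(naive_loss, stable_loss)
```

The stable version should agree with torch.log_softmax(acts, dim=1), which makes a convenient cross-check for a manual implementation.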
I think perhaps I've made a dog's breakfast of my code, but I've been banging my head against the wall for so long that my head is spinning. Here's my code: