In lesson 5 we were asked to write our own linear function. Can someone help me figure out why my loss isn't being plotted and why my parameters are nan? My linear function and update function are below.

Here is my linear function:

```
import torch
from torch import nn
from torch.nn.parameter import Parameter

def linear_mul(input, weights, bias):
    # (batch, 784) @ (784, 10) + (10,) -> (batch, 10)
    return input @ weights + bias

class Mnist_Logistic_mine(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda')  # stored but not used anywhere below
        self.weights = Parameter(torch.Tensor(784, 10))
        self.bias = Parameter(torch.Tensor(10))

    def forward(self, xb):
        return linear_mul(xb, self.weights, self.bias)
```
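One likely culprit, offered as a hypothesis rather than a confirmed diagnosis: `torch.Tensor(784, 10)` behaves like `torch.empty(784, 10)`, allocating memory without initializing it, so the starting weights are whatever values happened to be in that memory, which can include huge numbers, inf or nan. A minimal sketch of the same model with explicit initialization follows; the class name `Mnist_Logistic_init` and the 0.01 scale are my own illustrative choices, not something from the lesson.

```python
import torch
from torch import nn
from torch.nn.parameter import Parameter

def linear_mul(input, weights, bias):
    return input @ weights + bias

class Mnist_Logistic_init(nn.Module):
    # Hypothetical variant: same model, but the parameters are explicitly initialized.
    def __init__(self):
        super().__init__()
        # small random weights instead of uninitialized memory
        self.weights = Parameter(torch.randn(784, 10) * 0.01)
        # biases start at zero
        self.bias = Parameter(torch.zeros(10))

    def forward(self, xb):
        return linear_mul(xb, self.weights, self.bias)

model = Mnist_Logistic_init()
xb = torch.randn(64, 784)
out = model(xb)
print(out.shape)  # torch.Size([64, 10])
print(torch.isfinite(model.weights).all().item())  # True
```

`torch.nn.init` also provides in-place initializers (for example `nn.init.kaiming_uniform_`) if you would rather keep `torch.Tensor(...)` and fill it afterwards.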

Here is my update function:

```
import torch

def update(x, y, lr):
    wd = 1e-5
    y_hat = model(x)  # model and loss_func are defined elsewhere in the notebook
    # weight decay: sum of squared parameters
    w2 = 0.
    for p in model.parameters():
        w2 += (p**2).sum()
    # add to regular loss
    loss = loss_func(y_hat, y) + w2 * wd
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p.sub_(lr * p.grad)  # plain SGD step
            p.grad.zero_()
    return loss.item()
```
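For reference, the loss and SGD step that `update` implements can be written out as below, with $\ell$ standing for whatever `loss_func` computes and the outer sum running over every parameter tensor $p$. One consequence worth noting: if any parameter already contains nan, `w2`, and therefore the whole loss, becomes nan as well.

$$
L = \ell(\hat{y}, y) + \mathrm{wd} \sum_{p} \sum_{i} p_i^2,
\qquad
p \leftarrow p - \mathrm{lr}\,\bigl(\nabla_p \ell + 2\,\mathrm{wd}\,p\bigr)
$$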

My parameters somehow end up with nan values, and I don't know why:
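To pin down where the nan first appears, one option is to check the loss and every parameter after each step. This is a minimal sketch under several assumptions: `loss_func` is `nn.CrossEntropyLoss()`, a random batch stands in for MNIST, the model uses explicitly initialized parameters (`Mnist_Logistic_init` is a hypothetical name), and `check_finite` is a hypothetical helper, not a PyTorch or fastai function.

```python
import math

import torch
from torch import nn
from torch.nn.parameter import Parameter

torch.manual_seed(0)

class Mnist_Logistic_init(nn.Module):
    # Hypothetical variant of the model with explicitly initialized parameters.
    def __init__(self):
        super().__init__()
        self.weights = Parameter(torch.randn(784, 10) * 0.01)
        self.bias = Parameter(torch.zeros(10))

    def forward(self, xb):
        return xb @ self.weights + self.bias

model = Mnist_Logistic_init()
loss_func = nn.CrossEntropyLoss()  # assumption: the loss used in the lesson

def update(x, y, lr):
    # Same logic as the update function in the post.
    wd = 1e-5
    y_hat = model(x)
    w2 = 0.
    for p in model.parameters():
        w2 += (p**2).sum()
    loss = loss_func(y_hat, y) + w2 * wd
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p.sub_(lr * p.grad)
            p.grad.zero_()
    return loss.item()

def check_finite(model, step):
    # Hypothetical helper: fail on the first step where a parameter goes non-finite.
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            raise ValueError(f"step {step}: non-finite values in {name}")

# Random stand-in for one MNIST batch: 64 images of 784 pixels in [0, 1].
x = torch.rand(64, 784)
y = torch.randint(0, 10, (64,))

losses = []
for step in range(20):
    loss = update(x, y, lr=5e-3)
    if not math.isfinite(loss):
        raise ValueError(f"step {step}: loss is {loss}")
    check_finite(model, step)
    losses.append(loss)

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Swapping `torch.randn(784, 10) * 0.01` back to `torch.Tensor(784, 10)` and rerunning may reproduce the problem, which would point at the uninitialized weights. Note also that matplotlib leaves nan values out of a line plot, so a loss that is nan from the first step would be consistent with an empty loss plot.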

Here is the plot: