Chapter 4 MNIST full: accuracy not improving

Hey, I’m working on the Further Research section of Chapter 4, where you have to train a model on the full MNIST dataset. I’m using cross-entropy loss to compute the loss and then using that loss to calculate the gradients, but my accuracy never increases or decreases. Here is my code:

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

class Optimizer:
    def __init__(self, init_params, learning_rate):
        self.params, self.learning_rate = init_params, learning_rate

    def calculate_gradient(self, training_set, labels, model):
        predictions = model(training_set)
        loss = criterion(predictions, labels)
        loss.backward()

    def batch_accuracy(self, pred, actual):
        # the index of the largest logit is the predicted digit
        digit_pred = pred.max(dim=1)[1]
        return (digit_pred == actual).float().mean()

    def step(self):
        for p in self.params:
            p.data -= p.grad.data * self.learning_rate

    def validate_epoch(self, model, test_dl):
        accuracy = [self.batch_accuracy(model(digits), label) for digits, label in test_dl]
        return round(torch.stack(accuracy).mean().item(), 4)

simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 10)
)
opt = Optimizer(simple_net.parameters(), 1e-5)

def train_model(model, epoch):
    for _ in range(epoch):
        train_epoch(model)
        print(f'Accuracy:{opt.validate_epoch(model, test_dl)}')

def train_epoch(model):
    # training_dl and test_dl are the DataLoaders built earlier in the notebook
    for x, y in training_dl:
        opt.calculate_gradient(x, y, model)
        opt.step()

train_model(simple_net, 40)
```

And here is the link to my full notebook: Google Colab

It turned out the problem was caused by these two lines:

```python
opt = Optimizer(simple_net.parameters(), 1e-5)
# in the Optimizer class:
self.params, self.learning_rate = init_params, learning_rate
```
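A quick way to see why these two lines matter, using a throwaway `nn.Linear(4, 2)` model purely for illustration (any `nn.Module` should behave the same way): the first loop over `parameters()` yields every parameter, and a second loop over the same object yields nothing.

```python
import types
from torch import nn

# Tiny illustrative model: it has exactly two parameters (weight and bias).
model = nn.Linear(4, 2)

params = model.parameters()
print(isinstance(params, types.GeneratorType))  # True: parameters() returns a generator

first_pass = [p.shape for p in params]   # this loop consumes the generator
second_pass = [p.shape for p in params]  # the generator is already exhausted
print(len(first_pass), len(second_pass))  # 2 0

# Materialising the generator once gives a sequence that can be looped over repeatedly.
param_list = list(model.parameters())
print(sum(1 for _ in param_list), sum(1 for _ in param_list))  # 2 2
```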

The value of `simple_net.parameters()` is a generator that yields the parameters, but a generator can only be consumed once: after the first loop over it, it is exhausted and yields nothing more. So in

```python
def step(self):
    for p in self.params:
        p.data -= p.grad.data * self.learning_rate
```

I was actually updating the weights only on the first call to `step`; every later call looped over an already-exhausted generator and did nothing, which is why I saw no change in accuracy at all. Casting the generator to a list with `list(init_params)` gave me a reference I can loop over as many times as I want.
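For reference, here is a minimal sketch of the corrected optimizer, trimmed to the parts that matter for the fix. Besides `list(init_params)` it adds a `zero_grad` method, which is an extra not present in the code above: `loss.backward()` accumulates into `.grad`, so without clearing it every step would use the sum of all earlier gradients. The random tensors and the learning rate of 0.1 are assumptions that stand in for the real MNIST DataLoaders, just to make the check runnable.

```python
import torch
from torch import nn

class Optimizer:
    def __init__(self, init_params, learning_rate):
        # list() materialises the generator so step() can iterate it on every call
        self.params, self.learning_rate = list(init_params), learning_rate

    def step(self):
        for p in self.params:
            p.data -= p.grad.data * self.learning_rate

    def zero_grad(self):
        # backward() adds to .grad, so reset it after each update
        for p in self.params:
            p.grad = None

# Quick check on random data shaped like a batch of flattened 28x28 images
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(28*28, 30), nn.ReLU(), nn.Linear(30, 10))
opt = Optimizer(model.parameters(), 0.1)
criterion = nn.CrossEntropyLoss()
x = torch.randn(64, 28*28)
y = torch.randint(0, 10, (64,))

changes = []
for _ in range(3):
    before = model[0].weight.detach().clone()
    loss = criterion(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()
    # total absolute change of the first layer's weights during this step
    changes.append((model[0].weight.detach() - before).abs().sum().item())

print(changes)  # every entry is non-zero: the weights now move on every step
```

PyTorch's built-in optimizers such as `torch.optim.SGD` call `list()` on the parameters they receive, which is why passing `model.parameters()` straight to them works.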