I'm not sure whether the following code is correct: the validation accuracy is erratic and keeps jumping up and down from epoch to epoch. It starts at 13%, shoots up to 25% (sometimes 40%), and then drops back to 13% or worse, 6%. I have 5 classes. Training accuracy, on the other hand, reaches 80–90%, which is common on the training set with other models too; those other models, however, keep their validation accuracy around 60–70%. Is the code incorrect somewhere?
# TRAINING SECTION
# Trains for up to `num_epochs`. After every epoch it:
#   - evaluates on `val_loader`
#   - logs metrics into `train_results`
#   - writes a checkpoint (model + optimizer state)
#   - saves the weights with the best validation accuracy to 'best_model.pth'
#   - stops early after `patience` epochs without improvement
#
# Relies on names defined elsewhere in the script: model, device, criterion,
# optimizer, train_loader, val_loader, num_epochs, patience,
# best_val_accuracy, counter, train_results (a mapping of append-able lists),
# torch and tqdm.
#
# NOTE(review): the metric bookkeeping below is the standard pattern. If
# validation accuracy swings wildly while training accuracy keeps rising, the
# cause is probably outside this loop. Possible causes to check:
#   - val-set transforms and labels
#   - the learning rate
#   - a very small validation set
# None of these can be confirmed from this code alone.
checkpoint = 'kaggle_KNEEOA_without_pretraining adam.pth'  # loop-invariant path

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()  # training mode (affects layers such as dropout / batch-norm)
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # Show 1-based epoch numbers so the bar agrees with the summary line below.
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)  # class index with the highest score
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            pbar.set_postfix(loss=loss.item())
            pbar.update(1)

    # Loss is the mean of per-batch losses; accuracy is over all training samples.
    avg_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct_predictions / total_samples

    # ---- VALIDATION SUB-SECTION ----
    model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():  # no gradient tracking needed for evaluation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
    train_results['epoch'].append(epoch)
    train_results['train_loss'].append(avg_loss)
    train_results['train_accuracy'].append(train_accuracy)
    train_results['val_loss'].append(avg_val_loss)
    train_results['val_accuracy'].append(val_accuracy)

    # CHECKPOINT SAVING
    # Done before the early-stopping check so the last completed epoch is
    # always checkpointed, even when `break` below ends training.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint)
    print('Checkpoint {} saved'.format(checkpoint))

    # EARLY STOPPING
    # Keep the weights with the best validation accuracy seen so far; stop
    # after `patience` consecutive epochs without improvement.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break

torch.cuda.empty_cache()