Swift for TensorFlow Online Walkthrough

I’m going through the Swift for TensorFlow walkthrough (https://www.tensorflow.org/swift/tutorials/model_training_walkthrough) and I was curious about how they calculate accuracy in their training loop. It seems they update the model’s weights, run the model on the same batch again, and only then calculate accuracy. Shouldn’t the accuracy be calculated from the same forward pass that produced the loss, rather than after updating the model? It seems like the final reported loss and accuracy won’t fully agree. Also, because you just took an update step on that exact batch, the reported accuracy will look better than it really is. Am I missing something? It probably doesn’t matter since it’s just an example, but I’m wondering for myself what best practice is. Thoughts?

```swift
func accuracy(predictions: Tensor<Int32>, truths: Tensor<Int32>) -> Float {
    return Tensor<Float>(predictions .== truths).mean().scalarized()
}

for epoch in 1...epochCount {
    var epochLoss: Float = 0
    var epochAccuracy: Float = 0
    var batchCount: Int = 0
    for batch in trainDataset {
        // Forward pass #1: loss and gradient from the current (pre-update) weights.
        let (loss, grad) = model.valueWithGradient { (model: IrisModel) -> Tensor<Float> in
            let logits = model.applied(to: batch.features, in: trainingContext)
            return softmaxCrossEntropy(logits: logits, labels: batch.labels)
        }
        optimizer.update(&model.allDifferentiableVariables, along: grad)

        // Forward pass #2: accuracy uses the *updated* weights,
        // while `loss` above came from the pre-update weights.
        let logits = model.applied(to: batch.features, in: trainingContext)
        epochAccuracy += accuracy(predictions: logits.argmax(squeezingAxis: 1), truths: batch.labels)
        epochLoss += loss.scalarized()
        batchCount += 1
    }
    epochAccuracy /= Float(batchCount)
    epochLoss /= Float(batchCount)
    trainAccuracyResults.append(epochAccuracy)
    trainLossResults.append(epochLoss)
    if epoch % 50 == 0 {
        print("Epoch \(epoch): Loss: \(epochLoss), Accuracy: \(epochAccuracy)")
    }
}
```
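For comparison, here is a minimal sketch of the ordering the question suggests: run one forward pass, take both the loss and the accuracy from those logits, and only then apply the update. To keep it self-contained and runnable, it does not use Swift for TensorFlow at all. `LinearClassifier`, `trainStep`, `BatchMetrics`, and `evalAccuracy` are hypothetical stand-ins (a plain linear softmax classifier with hand-written gradients) for `IrisModel`, `valueWithGradient`, and the optimizer in the tutorial, not its actual API.

```swift
import Foundation

// Hypothetical stand-in for IrisModel: a linear softmax classifier on plain arrays.
struct LinearClassifier {
    var weights: [[Double]]  // [numClasses][numFeatures]
    var bias: [Double]       // [numClasses]

    func logits(_ x: [Double]) -> [Double] {
        var z = bias
        for c in z.indices {
            for j in x.indices { z[c] += weights[c][j] * x[j] }
        }
        return z
    }
}

func softmax(_ z: [Double]) -> [Double] {
    let m = z.max()!                       // subtract the max for numerical stability
    let e = z.map { exp($0 - m) }
    let s = e.reduce(0, +)
    return e.map { $0 / s }
}

func argmax(_ v: [Double]) -> Int {
    return v.indices.max { v[$0] < v[$1] }!  // first index wins on ties
}

struct BatchMetrics {
    let loss: Double
    let accuracy: Double
}

// One SGD step (hypothetical helper). Loss and accuracy both come from the single
// forward pass that also produces the gradient, i.e. from the pre-update weights.
func trainStep(model: inout LinearClassifier, features: [[Double]],
               labels: [Int], learningRate: Double) -> BatchMetrics {
    var gradW = model.weights.map { row in row.map { _ in 0.0 } }
    var gradB = model.bias.map { _ in 0.0 }
    var totalLoss = 0.0
    var correct = 0

    for (x, y) in zip(features, labels) {
        let z = model.logits(x)            // the only forward pass for this example
        let p = softmax(z)
        totalLoss += -log(p[y])            // softmax cross-entropy
        if argmax(z) == y { correct += 1 } // accuracy from the same logits
        for c in p.indices {
            let d = p[c] - (c == y ? 1.0 : 0.0)  // dLoss/dLogit[c]
            gradB[c] += d
            for j in x.indices { gradW[c][j] += d * x[j] }
        }
    }

    let n = Double(features.count)
    // The weights change only after this batch's metrics have been recorded.
    for c in model.bias.indices {
        model.bias[c] -= learningRate * gradB[c] / n
        for j in model.weights[c].indices {
            model.weights[c][j] -= learningRate * gradW[c][j] / n
        }
    }
    return BatchMetrics(loss: totalLoss / n, accuracy: Double(correct) / n)
}

// Re-evaluates with the current weights, like the tutorial's second forward pass.
func evalAccuracy(_ model: LinearClassifier, _ features: [[Double]], _ labels: [Int]) -> Double {
    let correct = zip(features, labels).filter { argmax(model.logits($0.0)) == $0.1 }.count
    return Double(correct) / Double(features.count)
}

// Usage on a toy batch.
var model = LinearClassifier(weights: [[0, 0], [0, 0]], bias: [0, 0])
let xs: [[Double]] = [[1, 0], [0, 1], [2, 0], [0, 2]]
let ys = [0, 1, 0, 1]

let first = trainStep(model: &model, features: xs, labels: ys, learningRate: 0.5)
let postUpdate = evalAccuracy(model, xs, ys)
let second = trainStep(model: &model, features: xs, labels: ys, learningRate: 0.5)
print("step 1 (pre-update):", first.accuracy, "re-evaluated after update:", postUpdate)
print("step 2 (pre-update):", second.accuracy)
```

On this toy batch, the first step records an accuracy of 0.5 from the all-zero starting weights. Re-running the model after the update, as the tutorial does, would report 1.0 for the same batch. That is the optimistic bias the question describes, and recording metrics from the first pass also avoids the extra forward pass.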

Doing the forward pass twice is certainly an inefficient way to get metrics! :open_mouth:

@saeta I guess sometime before the MOOC comes out, it would be good to update the official tutorials with code that you think represents best practices?


I hadn’t even considered the additional overhead aspect. Duh.