Hopefully someone with some more Swift knowledge than I have can help me out with this one. The notebook is here, if you would like more context. This code is near the bottom.
I have this fit function:
func fit<Opt: Optimizer, Labels: TensorGroup>(
epochs: Int32,
model: inout Opt.Model,
opt: inout Opt,
trainDs: Dataset<Batch<Opt.Model.Input, Labels>>,
validDs: Dataset<Batch<Opt.Model.Input, Labels>>
) where Opt.Scalar: TensorFlowFloatingPoint {
let trainContext = Context(learningPhase: .training)
let validContext = Context(learningPhase: .inference)
for epoch in 0..<epochs {
for batch in trainDs{
var (_, šmodel) = model.valueWithGradient{ model -> Tensor<Opt.Scalar> in
let preds: Tensor = model.applied(to: batch.inputs, in: trainContext) as! Tensor<Opt.Scalar>
return softmaxCrossEntropy(logits: preds, labels: batch.labels as! Tensor<Int32>)}
opt.update(&model.allDifferentiableVariables, along: šmodel)
}
var totalLoss = Tensor<Opt.Scalar>(0)
var totalAccuracy = Tensor<Opt.Scalar>(0)
var it = Opt.Scalar(0)
for batch in validDs{
let preds = model.applied(to: batch.inputs, in: validContext) as! Tensor<Opt.Scalar>
let loss = softmaxCrossEntropy(logits: preds, labels: batch.labels as! Tensor<Int32>)
let acc = accuracy(preds, batch.labels as! Tensor<Int32>)
totalLoss += loss
totalAccuracy += acc
it += Opt.Scalar(1)
}
print(epoch, totalLoss/it, totalAccuracy/it)
}
}
When I run it, clearly the modelās weights are not being updated:
fit(epochs: 4,
model: &model,
opt: &opt,
trainDs: trainDataset,
validDs: validDataset)
0 3.2903287 0.07852309
1 3.2903287 0.07852309
2 3.2903287 0.07852309
3 3.2903287 0.07852309
If I run the same code outside of a function, it works just fine.
let trainContext = Context(learningPhase: .training)
let validContext = Context(learningPhase: .inference)
for epoch in 0..<epochs {
for batch in trainDataset{
var (_, šmodel) = model.valueWithGradient{ model -> Tensor<Float> in
let preds: Tensor = model.applied(to: batch.inputs, in: trainContext)
return softmaxCrossEntropy(logits: preds, labels: batch.labels)}
opt.update(&model.allDifferentiableVariables, along: šmodel)
}
var totalLoss = Tensor<Float>(0.0)
var totalAccuracy = Tensor<Float>(0.0)
var it: Float = 0
for batch in validDataset{
let preds = model.applied(to: batch.inputs, in: validContext)
let loss = softmaxCrossEntropy(logits: preds, labels: batch.labels)
let acc = accuracy(preds, batch.labels)
totalLoss += loss
totalAccuracy += acc
it += 1.0
}
print(epoch, totalLoss/it, totalAccuracy/it)
}
0 0.17956303 0.9463575
1 0.13751046 0.95820063
2 0.119824864 0.96307725
3 0.10941707 0.96596336