**Swift Question** Fit function does not update model's params

Hopefully someone with some more Swift knowledge than I have can help me out with this one. The notebook is here, if you would like more context. This code is near the bottom.

I have this fit function:

    func fit<Opt: Optimizer, Labels: TensorGroup>(
        epochs: Int32,
        model: inout Opt.Model,
        opt: inout Opt,
        trainDs: Dataset<Batch<Opt.Model.Input, Labels>>,
        validDs: Dataset<Batch<Opt.Model.Input, Labels>>
    ) where Opt.Scalar: TensorFlowFloatingPoint {
        let trainContext = Context(learningPhase: .training)
        let validContext = Context(learningPhase: .inference)

        for epoch in 0..<epochs {
            for batch in trainDs {
                let (_, 𝛁model) = model.valueWithGradient { model -> Tensor<Opt.Scalar> in
                    let preds: Tensor = model.applied(to: batch.inputs, in: trainContext) as! Tensor<Opt.Scalar>
                    return softmaxCrossEntropy(logits: preds, labels: batch.labels as! Tensor<Int32>)
                }

                opt.update(&model.allDifferentiableVariables, along: 𝛁model)
            }

            var totalLoss = Tensor<Opt.Scalar>(0)
            var totalAccuracy = Tensor<Opt.Scalar>(0)
            var it = Opt.Scalar(0)
            for batch in validDs {
                let preds = model.applied(to: batch.inputs, in: validContext) as! Tensor<Opt.Scalar>
                let loss = softmaxCrossEntropy(logits: preds, labels: batch.labels as! Tensor<Int32>)
                let acc = accuracy(preds, batch.labels as! Tensor<Int32>)
                totalLoss += loss
                totalAccuracy += acc
                it += Opt.Scalar(1)
            }
            print(epoch, totalLoss/it, totalAccuracy/it)
        }
    }

When I run it, clearly the model’s weights are not being updated:

    fit(epochs: 4,
        model: &model,
        opt: &opt,
        trainDs: trainDataset,
        validDs: validDataset)

    0 3.2903287 0.07852309
    1 3.2903287 0.07852309
    2 3.2903287 0.07852309
    3 3.2903287 0.07852309

If I run the same code outside of a function, it works just fine.

    let trainContext = Context(learningPhase: .training)
    let validContext = Context(learningPhase: .inference)

    for epoch in 0..<epochs {
        for batch in trainDataset {
            let (_, 𝛁model) = model.valueWithGradient { model -> Tensor<Float> in
                let preds: Tensor = model.applied(to: batch.inputs, in: trainContext)
                return softmaxCrossEntropy(logits: preds, labels: batch.labels)
            }
            opt.update(&model.allDifferentiableVariables, along: 𝛁model)
        }

        var totalLoss = Tensor<Float>(0.0)
        var totalAccuracy = Tensor<Float>(0.0)
        var it: Float = 0
        for batch in validDataset {
            let preds = model.applied(to: batch.inputs, in: validContext)
            let loss = softmaxCrossEntropy(logits: preds, labels: batch.labels)
            let acc = accuracy(preds, batch.labels)
            totalLoss += loss
            totalAccuracy += acc
            it += 1.0
        }
        print(epoch, totalLoss/it, totalAccuracy/it)
    }

    0 0.17956303 0.9463575
    1 0.13751046 0.95820063
    2 0.119824864 0.96307725
    3 0.10941707 0.96596336

Hi @metachi, I tried to run your notebook, but it doesn’t work with the latest toolchain I installed; there were some API changes between Thursday and Friday, and adapting it would have required quite a few changes. I don’t see anything that stands out as obviously wrong: the only differences seem to be the use of `Opt.Scalar` instead of `Float` and the function call itself. I would try to isolate the problem by printing logs, moving the accuracy calculation outside the function, and things like that. Sorry I can’t be of more help.

Thanks for taking a look. I updated to the latest nightly and, unsurprisingly, am still seeing the weird behavior. The gradients are all calculated as zero inside the fit function, but the same code outside calculates non-zero gradients. I’ll keep debugging it.

I suspect it might be the `as! Tensor<Opt.Scalar>`.

AutoDiff probably does not understand that `as!` returns a value that depends on its input, so AutoDiff concludes that the derivative is zero. Ideally, it would print an error or warning, but it seems like it isn’t doing that. I filed a bug: https://bugs.swift.org/browse/TF-455

Anyway, I think you can add another condition to `fit`'s where clause: `Opt.Model.Output == Tensor<Opt.Scalar>` (conditions are comma-separated). That should make it possible to remove the `as!`, which should deconfuse AutoDiff.
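For completeness, here is a sketch of what that change might look like against the training loop above (untested, and based purely on the API names already used in this thread):

```swift
// Sketch only: the extra `Opt.Model.Output == Tensor<Opt.Scalar>` constraint
// lets the force cast on the model output be dropped, so AutoDiff can trace
// the dependency from input to loss.
func fit<Opt: Optimizer, Labels: TensorGroup>(
    epochs: Int32,
    model: inout Opt.Model,
    opt: inout Opt,
    trainDs: Dataset<Batch<Opt.Model.Input, Labels>>,
    validDs: Dataset<Batch<Opt.Model.Input, Labels>>
) where Opt.Scalar: TensorFlowFloatingPoint,
        Opt.Model.Output == Tensor<Opt.Scalar> {
    let trainContext = Context(learningPhase: .training)
    for epoch in 0..<epochs {
        for batch in trainDs {
            let (_, 𝛁model) = model.valueWithGradient { model -> Tensor<Opt.Scalar> in
                // `model.applied` is now statically Tensor<Opt.Scalar>: no `as!` needed.
                let preds = model.applied(to: batch.inputs, in: trainContext)
                // The labels cast sits outside the differentiated path, so it is harmless.
                return softmaxCrossEntropy(logits: preds, labels: batch.labels as! Tensor<Int32>)
            }
            opt.update(&model.allDifferentiableVariables, along: 𝛁model)
        }
        // ... validation loop unchanged ...
    }
}
```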


Thanks @marcrasi! That was definitely it! I’ll know to be suspicious of things like that if I see gradients unexpectedly calculated as 0 again.
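For anyone landing here from a search: the mechanism behind the fix is plain Swift generics, nothing TensorFlow-specific. Here is a toy example (the `Producer` types are made up for illustration) of how a same-type constraint in a `where` clause gives the compiler the concrete type, making a force cast unnecessary:

```swift
// Toy illustration (hypothetical types, not from the notebook above).
protocol Producer {
    associatedtype Output
    func produce() -> Output
}

struct DoubleProducer: Producer {
    func produce() -> Double { 2.0 }
}

// Without `P.Output == Double`, `p.produce()` would have the opaque type
// `P.Output`, tempting us to write `p.produce() as! Double`. The same-type
// constraint removes that need.
func squared<P: Producer>(_ p: P) -> Double where P.Output == Double {
    let x = p.produce()  // statically known to be Double
    return x * x
}

print(squared(DoubleProducer()))  // prints 4.0
```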