I’m trying to write a basic training loop that takes a `Dataset`, a model, an `Optimizer`, and a loss function. After a bit of searching, I’ve managed to declare the first three, but how to tell S4TF that the last one is differentiable eludes me. My progress so far is:

```swift
func basic_training_loop<Model, Opt: Optimizer>(
    train_ds: Dataset<Batch>,
    model: inout Model,
    opt: inout Opt,
    loss_func: (Tensor<Float>, Tensor<Float>) -> Tensor<Float>
) where Opt.Model == Model, Opt.Scalar == Float,
    Model.Input == Tensor<Float>,
    Model.Output == Tensor<Float>
{
    for batch in train_ds {
        // Compute the loss and the gradients w.r.t. the model parameters
        let (loss, grads) = model.valueWithGradient { model -> Tensor<Float> in
            let preds = model.applied(to: batch.x, in: trainingContext)
            return loss_func(preds, batch.y)
        }
        print(loss)
        opt.update(&model.allDifferentiableVariables, along: grads)
    }
}
```

and it complains that `loss_func` isn’t differentiable (with reasons). I’ve tried adding an `@differentiable` in the declaration, but that doesn’t help either.
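For reference, the annotation I tried was along these lines (just a sketch of the attempt, so the exact spelling here may well be the thing that’s wrong):

```swift
// Attempted: mark the loss parameter's function type as differentiable
func basic_training_loop<Model, Opt: Optimizer>(
    train_ds: Dataset<Batch>,
    model: inout Model,
    opt: inout Opt,
    loss_func: @differentiable (Tensor<Float>, Tensor<Float>) -> Tensor<Float>
) where Opt.Model == Model, Opt.Scalar == Float,
    Model.Input == Tensor<Float>,
    Model.Output == Tensor<Float>
```

Is `@differentiable` on the parameter’s function type the right way to express this, or does it need to go somewhere else?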