Haven’t used S4TF at all, and I’m not very familiar with the maths, so I could be off here, but I have written a couple of custom backward kernels, so I have some familiarity with the mechanics.
Yeah, you need a scalar for backward(); I’ve used stuff like:
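(For concreteness, my_forward/my_backward below stand in for your custom kernel pair; the elementwise square is just an assumption of mine so the snippet runs end to end.)

import torch

def my_forward(x):
    # toy elementwise op standing in for the real forward kernel
    return x ** 2

def my_backward(grad_out, x):
    # chain rule for the toy op: d(x**2)/dx = 2*x, scaled by the incoming gradient
    return (2 * x * grad_out,)  # 1-tuple, to match the unpacking below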
x = torch.randn(10, requires_grad=True)  # Input; requires_grad so autograd tracks it
y = my_forward(x)
loss = y.mean()
grad_out, = torch.autograd.grad(loss, y, retain_graph=True)  # retain_graph needed to be able to call grad again
# grad_out is now the gradient of loss w.r.t. output of my_forward (y)
# my_backward calculates the input gradient from the input and grad_out, so I pass x
grad_inp, = my_backward(grad_out, x)
assert torch.allclose(grad_inp, torch.autograd.grad(loss, x, retain_graph=True)[0])  # grad returns a tuple; retain again for the check below
I think that you should find (for an elementwise op, at least) that after
y.backward(torch.ones_like(y))
you get x.grad == grad_inp / grad_out. Note y.backward(...) returns None and accumulates the result into x.grad, so you compare against x.grad rather than a return value.
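A quick check of that relation, continuing the toy snippet above (so still assuming the elementwise my_forward):

x.grad = None  # clear anything accumulated so far
y.backward(torch.ones_like(y))  # seed the backward pass with an all-ones upstream gradient
assert torch.allclose(x.grad, grad_inp / grad_out)  # holds here because the Jacobian of an elementwise op is diagonal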
I’m not entirely across the details (the code is largely copied from elsewhere), but it seems to work, though the above may not, as it’s just an adaptation of this code, which is a little harder to follow.
Doesn’t help with how to do it in S4TF, but it might at least help you understand what’s going on on the PyTorch side, to let you just do l.backward() or y.backward(torch.ones_like(y)). This post also looks to have some nice details on how y.backward(...) works in PyTorch.
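To make the scalar-versus-vector point concrete, here’s a small self-contained sketch (the names and the sum loss are my own choices) showing that l.backward() is just y.backward(g) with g = dl/dy filled in for you:

import torch

x = torch.randn(5, requires_grad=True)
y = x * 3

# scalar route: PyTorch implicitly seeds backward with d(l)/d(l) == 1.0
l = y.sum()
l.backward(retain_graph=True)
grad_via_scalar = x.grad.clone()

# vector route: supply dl/dy yourself; for l = y.sum() that's all ones
x.grad = None
y.backward(torch.ones_like(y))
assert torch.allclose(grad_via_scalar, x.grad)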
I gather part of your issue is that whatever you’re using in S4TF behaves more like
.backward() with no explicit gradient argument. I frequently encountered the only-scalars issue with