It works!!! Pushed a working version (still on the nasty-hack branch).
I think the main issue was that exp quickly overflows:
>>> import numpy as np
>>> for dtype in [np.float16, np.float32, np.float64]:
...     max_x = np.nonzero(np.exp(np.arange(0, 1000, 10, dtype=dtype)) == np.inf)[0][0] * 10
...     print(f"{dtype.__name__}: exp({max_x}) == inf")
float16: exp(20) == inf
float32: exp(90) == inf
float64: exp(710) == inf
It now uses the tricks from PyTorch’s softplus to maintain stability: computing the derivative as 1-exp(-Y) and just passing the input through when it’s above a threshold (20 by default). There’s a Derivatives notebook in extras that goes through it if anyone with more maths than me wants to verify.
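For anyone following along, a minimal NumPy sketch of that kind of trick (the threshold value and function names here are mine, not the actual kernel code):

import numpy as np

def stable_softplus(x, threshold=20.0):
    # Above the threshold, softplus(x) = log(1 + exp(x)) equals x to within float
    # precision, so just pass the input through instead of letting exp overflow.
    safe = np.minimum(x, threshold)  # keep exp() in range on the other branch
    return np.where(x > threshold, x, np.log1p(np.exp(safe)))

def stable_softplus_grad(y, x, threshold=20.0):
    # The derivative is sigmoid(x), rewritten as 1 - exp(-Y) in terms of the
    # output Y = softplus(x); exp(-Y) can't overflow because Y >= 0.
    return np.where(x > threshold, 1.0, 1.0 - np.exp(-y))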
Passes all tests except the second-derivative one, which still fails. I’m thinking I’m not following something needed for that to pass; the autograd-function version built from PyTorch ops passes.
A stability check that previously failed within the first iteration or two now runs 1000 iterations fine (random inputs each iteration).
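Roughly what those checks look like (my own sketch using torch.autograd, not the actual test file; the shapes, iteration count and helper names are placeholders):

import torch
from torch.autograd import gradcheck, gradgradcheck

def stability_check(fn, iters=1000, shape=(64, 128), device="cuda"):
    # Fresh random inputs each iteration; forward and backward must stay finite.
    for _ in range(iters):
        x = torch.randn(*shape, device=device, requires_grad=True)
        y = fn(x)
        y.sum().backward()
        assert torch.isfinite(y).all(), "forward produced inf/nan"
        assert torch.isfinite(x.grad).all(), "backward produced inf/nan"

def derivative_checks(fn, device="cuda"):
    # gradcheck exercises the first derivative, gradgradcheck the second
    # (the one still failing); both want double-precision inputs.
    x = torch.randn(4, 8, device=device, dtype=torch.float64, requires_grad=True)
    assert gradcheck(fn, (x,))
    assert gradgradcheck(fn, (x,))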
And as a bonus the performance is MUCH better:
relu_fwd: 248.6µs ± 1.617µs (234.4µs - 251.1µs)
relu_bwd: 421.8µs ± 44.19µs (416.2µs - 861.5µs)
softplus_fwd: 305.4µs ± 26.03µs (254.5µs - 321.4µs)
softplus_bwd: 427.0µs ± 4.278µs (419.1µs - 433.9µs)
mish_pt_fwd: 795.8µs ± 1.882µs (780.4µs - 801.0µs)
mish_pt_bwd: 1.691ms ± 808.9ns (1.689ms - 1.692ms)
mish_cuda_fwd: 281.2µs ± 2.849µs (260.0µs - 292.4µs)
mish_cuda_bwd: 494.3µs ± 1.470µs (491.4µs - 497.4µs)
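(For anyone wanting to reproduce numbers like these, here is a rough way to time a single op with CUDA events; this is my own sketch, not the benchmark script used above.)

import torch

def time_op(fn, x, iters=100, warmup=10):
    # Per-call mean in microseconds, measured with CUDA events and explicit
    # synchronisation so GPU kernel time is what actually gets measured.
    for _ in range(warmup):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # elapsed_time() is in ms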
Real-world performance on the network I was using before (7 layers of conv/actn with more features and less aggressive strides than typical to emphasise activation performance; a rough sketch is below the results) is the same as ReLU, just going by epoch time. And while I wasn’t at all designing that network for accuracy (notably no BN), it topped that too. One run, no SD etc., so definitely more tests needed, but final topk:
ReLU: 91.92%
Mish PyTorch: 94.40%
Mish CUDA: 95.03%
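Purely for illustration, the sort of network described above; the widths, strides and head here are my guesses, not the actual model:

import torch.nn as nn

def conv_act_net(act_cls, num_classes=10,
                 widths=(64, 128, 128, 256, 256, 512, 512)):
    # 7 conv/activation blocks, wider and with gentler strides than typical so
    # the activation dominates runtime; deliberately no BatchNorm.
    layers, in_c = [], 3
    for i, out_c in enumerate(widths):
        stride = 2 if i in (1, 3, 5) else 1
        layers += [nn.Conv2d(in_c, out_c, 3, stride=stride, padding=1), act_cls()]
        in_c = out_c
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(in_c, num_classes)]
    return nn.Sequential(*layers)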
That’s without any optimisation. It’s likely largely bound by memory access, so there may not be much scope for improvement, but we’ll see.
Challenge accepted
Musing on optimisation:
Currently on backward I recalculate the forward from the input. Many ops in PyTorch calculate based on the output, which may be better; or it may be better to stash an intermediate and calculate the gradient from that. One issue I have there is that I can’t see how to numerically stabilise the exp(inp)+1 that pops up in possible simplifications. I currently just use the stable 1-exp(-Softplus(inp)) trick from PyTorch, but avoiding calculating softplus again might help performance, e.g. taking that as my intermediate to calculate from, but then I don’t know how to get exp(inp)+1 stably while maintaining the derivative. Maybe someone who actually knows the maths can help here.
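For concreteness, here is the current recompute-from-input backward in NumPy terms (a sketch of the maths only, not the CUDA kernel; threshold and names are mine):

import numpy as np

THRESHOLD = 20.0

def softplus(x):
    # Stable softplus: pass the input through above the threshold.
    return np.where(x > THRESHOLD, x, np.log1p(np.exp(np.minimum(x, THRESHOLD))))

def mish_fwd(x):
    # mish(x) = x * tanh(softplus(x))
    return x * np.tanh(softplus(x))

def mish_bwd(grad_out, x):
    # Recompute everything from the input, as the kernel currently does.
    sp = softplus(x)
    tsp = np.tanh(sp)
    sig = np.where(x > THRESHOLD, 1.0, 1.0 - np.exp(-sp))  # sigmoid(x) = 1 - exp(-softplus(x))
    # d/dx [x * tanh(sp)] = tanh(sp) + x * (1 - tanh(sp)^2) * sigmoid(x)
    return grad_out * (tsp + x * sig * (1.0 - tsp * tsp))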