Thoughts on JAX?

I've been wondering where JAX fits among all of these ecosystems (e.g. S4TF, PyTorch, TF-eager). As far as I can tell, it combines automatic differentiation with XLA compilation, yet to the user it is just pure Python. As a PyTorch lover, I find this really exciting. Has Jeremy or anyone else looked into it? I'm curious what people think of it.
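To illustrate the "autodiff plus XLA, but pure Python" point, here is a minimal sketch. The function `f` is just a toy example I made up. `jax.grad` turns an ordinary Python function into one that computes its gradient, and `jax.jit` compiles a function with XLA:

```python
import jax
import jax.numpy as jnp

# An ordinary Python function written with jax.numpy (a NumPy-like API).
# f is a hypothetical toy example: f(x) = sum(x_i^2).
def f(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new Python function that computes df/dx via autodiff.
df = jax.grad(f)

# jax.jit traces the function and compiles it with XLA.
df_fast = jax.jit(df)

x = jnp.array([1.0, 2.0, 3.0])
print(df(x))       # [2. 4. 6.]  (the gradient of sum(x^2) is 2x)
print(df_fast(x))  # same values, from the compiled version
```

Both transformations take plain functions and return plain functions, so they compose freely, which is a big part of the appeal compared with a tape-based API.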

As an exercise, I implemented linear regression with JAX. It was good practice in the fundamentals, and I would like to go back and explore it some more…
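My original code isn't shown here, but a minimal linear regression in JAX might look roughly like the sketch below. It assumes synthetic data (y = 3x + 2 plus a little noise), a mean-squared-error loss, and plain gradient descent with a fixed learning rate. The names `predict`, `mse_loss`, `update`, and `LR` are just illustrative choices:

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data: y = 3x + 2 plus small Gaussian noise.
key = jax.random.PRNGKey(0)
key_x, key_noise = jax.random.split(key)
X = jax.random.uniform(key_x, (100,))
y = 3.0 * X + 2.0 + 0.01 * jax.random.normal(key_noise, (100,))

LR = 0.5  # assumed learning rate for plain gradient descent

def predict(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
    # jax.grad differentiates w.r.t. the first argument (the params tuple),
    # so grads is a tuple with the same structure: (dL/dw, dL/db).
    grads = jax.grad(mse_loss)(params, x, y)
    return tuple(p - LR * g for p, g in zip(params, grads))

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = update(params, X, y)

w, b = params
print(w, b)  # should end up close to 3 and 2
```

Unlike PyTorch, there is no `.backward()` or optimizer object: the parameters are just values passed in and returned, and the whole update step is a pure function that `jax.jit` can compile.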