A JAX port of fastai as an alternative to Swift?

The fastai team seems to have been looking for a way to implement the wonderful fastai library in a system that makes it possible to write new high-performance algorithms.
This was a big theme in last year’s course, where the idea of porting fastai to Swift was proposed.
I do understand the rationale for such a project. However, I am personally very sceptical about this particular route: moving to a compiled language. In my opinion, this concept has been tried time and again with poor results. Julia, for example, follows a similar route.
In my opinion, such endeavours end up being examples of “premature optimization”. I much prefer working with a mature language that is easy to code in and already has a rich ecosystem of libraries, and then adding a JIT compiler that optimizes precisely the parts of the code that have the most impact on performance.
Of course, Python plus a good JIT compiler would be a wonderful solution. I agree with the fastai staff that the PyTorch JIT is a hassle. A good alternative might be Numba, but it doesn’t seem to work well with PyTorch and other libraries.
I think a wonderful possibility would be to port fastai to JAX. JAX features a Numba-like JIT that seems very easy to use and that optimizes Python code, a way to vectorize Python functions, and automatic differentiation of Python code.
A fastai port to JAX could combine the power and ease of use of fastai with
the performance of compiled code for key functions. Plus, it should be an easy port, since the language stays the same.
What are your thoughts on this? Do you think it would be as big a leap forward as I think? My experience with JAX is still limited, and I may be wrong about this.
If you think it’s worthwhile, would anyone be interested in working on such a project? I would be interested in investing time in it, and it would also be a great learning experience for fastai students.
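A minimal toy sketch of those three features follows. It is not fastai code, and the linear-model names `loss`, `w`, `x`, `y` are purely illustrative. It uses `jax.grad` for automatic differentiation, `jax.jit` for compilation, and `jax.vmap` for vectorizing a per-example function over a batch:

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a toy linear model (illustrative only).
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# grad: builds a function returning d(loss)/d(w), i.e. w.r.t. the first argument.
grad_loss = jax.grad(loss)

# jit: traces the gradient function and compiles it with XLA.
fast_grad = jax.jit(grad_loss)

# vmap: maps grad_loss over the leading (batch) axis of x and y while
# sharing w (in_axes=None), which yields one gradient per example.
per_example_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)          # full-batch gradient, shape (2,)
pg = per_example_grad(w, x, y)  # per-example gradients, shape (3, 2)
```

Because the loss is a mean over the batch, the compiled full-batch gradient should match the average of the per-example gradients that `vmap` produces.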
Link to JAX repo: https://github.com/google/jax

Yes, it looks like JAX has a lot of potential, and in fact a similar idea has been in the back of my head for several months. But there are several reasons it may not be much better than traditional PyTorch+fastai:

  1. It’s a Google project. There’s no guarantee that JAX won’t turn out to be just like TensorFlow or lose momentum like Swift for TensorFlow.
  2. It still uses Python. Python as a language still has some limitations, which is why we have packages like fastcore that hack Python to support features it didn’t already have. Other languages natively support features such as multiple dispatch.
  3. There are already plenty of JAX wrappers for deep learning.

Also, JAX fundamentally works differently from PyTorch. For an example of this, I recommend checking out this post:
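One concrete difference can be sketched with a toy example (not taken from that post; the helper names `init_params`, `forward`, and `sgd_step` are hypothetical). In JAX, parameters and random-number state are passed around explicitly instead of living inside stateful module objects, and a training step returns new parameters rather than updating them in place:

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_out):
    # Parameters are a plain dict (a pytree), not attributes of a module object.
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (n_in, n_out)) * 0.1,
        "b": jnp.zeros(n_out),
    }


def forward(params, x):
    # A pure function: the output depends only on the arguments.
    return x @ params["w"] + params["b"]


def sgd_step(params, x, y, lr=0.1):
    # Hypothetical training step: returns NEW parameters, leaving `params` untouched.
    def loss(p):
        return jnp.mean((forward(p, x) - y) ** 2)

    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


key = jax.random.PRNGKey(0)
params = init_params(key, 3, 2)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

out1 = forward(params, x)
out2 = forward(params, x)

# Same key -> same random numbers: there is no hidden global RNG state.
same = bool(jnp.array_equal(init_params(key, 3, 2)["w"], params["w"]))

new_params = sgd_step(params, x, y)
```

In PyTorch, by contrast, a module owns its parameters and an optimizer step mutates them in place, which is one reason a fastai port to JAX would likely need more redesign than a line-by-line translation.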

JAX is more like Swift than PyTorch, so some of the ideas from SwiftAI might port over nicely. Worth a try for anyone interested!

That looks interesting actually, but what’s the performance like?

(Asking for a friend old enough to know that the only JIT compiler ever worth discussing was banned in a Sun vs MS lawsuit 20 years ago)

There are some benchmarks here:

JAX is a fresh take on deep learning and a really cool project, but the argument that a JIT hacked on top of Python is better than a compiled language makes no sense. Either you have a general-purpose compiler that produces fast code, or you end up with the same limitations currently present in Python. A simple example: try writing fast mutating code in JAX.

It follows the same principle stated here:

Any sufficiently complicated machine learning system contains an ad-hoc, informally-specified, bug-ridden, slow implementation of half of a programming language.
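To make the point about mutation concrete, here is a small sketch (the function name `running_sum` is hypothetical). JAX arrays are immutable, so in-place assignment raises an error, and "mutation" has to be written as functional updates such as `x.at[i].set(v)`, with loops expressed through `jax.lax.fori_loop` when they should be compiled. Under `jax.jit`, XLA may turn such updates into in-place writes, but that is an optimization the compiler may or may not apply, not something the code controls directly:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

# Direct item assignment is not allowed: JAX arrays are immutable.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# Functional update: returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)


@jax.jit
def running_sum(v):
    # A "mutating" loop written functionally: each iteration returns an
    # updated copy of the buffer plus the running total as the loop carry.
    def body(i, carry):
        buf, total = carry
        total = total + v[i]
        return buf.at[i].set(total), total

    init = (jnp.zeros_like(v), jnp.zeros((), v.dtype))
    buf, _ = jax.lax.fori_loop(0, v.shape[0], body, init)
    return buf


result = running_sum(jnp.array([1.0, 2.0, 3.0, 4.0]))  # expected values: 1, 3, 6, 10
```

Whether this style is an acceptable price for compilation, or exactly the limitation being criticized, is the crux of the disagreement here.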

Also, if Julia has poor results, I have no idea what can be considered good. Please go watch JuliaCon 2020 to see how many scientists have migrated their code and simulations, and how fast the language is growing.

I think that trying to build a fastai-like high-level interface on top of JAX is worthwhile because it explores a different path from the one both PyTorch and TensorFlow 2.0 have converged on, not because of some empty reasoning that tries to discredit the hard work that went into the alternatives.

:laughing:
Well, Numba has proven to work pretty well: https://murillogroupmsu.com/numba-versus-c/
JAX doesn’t seem to be on par with Numba yet, but it does give a big boost, and I hope it’ll improve, as Numba did:
https://github.com/google/jax/issues/1421

Anyway, a good automatic differentiation system (perhaps based on computation graphs, as in PyTorch) that is also compatible with Numba would be my preferred option.

It’s a good thing there are multiple paths & possibilities for the future of ml and scientific computing :slight_smile:

  • JAX and the many libraries built on it (Flax, Haiku, RLax, etc.)
  • Swift-based projects (Swift for TF, Fast.ai on Swift)
  • Julia-based projects (Flux, Knet, the MXNet Julia interface, FastAI.jl: a Fast.ai port to Julia)

I know some of them are still in progress and some aren’t 100% there yet, but the main thing is that there is always hope for good alternatives :wink:
