Fastai on TPU?

Recently, it was demonstrated how to use PyTorch on TPU over here

It seems there are three major changes to a typical PyTorch script:

  1. The model needs to be put onto the TPU
  2. The LR needs to be scaled
  3. The optimizer needs to be stepped using the XLA package

The first two things seem to be easy to do with a callback. How would I be able to do the last one with a callback?


The last one can also be done with a callback, since you can skip the normal step in the training loop by returning the proper flag, and do the step in the callback instead.
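Here is a toy, fastai-free illustration of that mechanism. In fastai v1 the real hook is `Callback.on_backward_end`, which can return `{'skip_step': True}` so the training loop skips `optimizer.step()`; the class and function names below are made up for the demo, and the `xm.optimizer_step` call is only indicated in a comment.

```python
# Toy illustration of "skip the normal step via a flag, step in the callback".
# All names here are invented for the demo; they are not the fastai API.

class FakeOptimizer:
    def __init__(self):
        self.normal_steps = 0

    def step(self):
        self.normal_steps += 1


class XLAStepCallback:
    """Stands in for a callback that would call xm.optimizer_step(opt) on TPU."""
    def __init__(self, opt):
        self.opt = opt
        self.xla_steps = 0

    def on_backward_end(self):
        # A real TPU callback would do something like:
        #   import torch_xla.core.xla_model as xm
        #   xm.optimizer_step(self.opt)
        self.xla_steps += 1           # pretend we stepped through XLA
        return {'skip_step': True}    # tell the loop not to call opt.step()


def train(n_batches, opt, cb):
    for _ in range(n_batches):
        # ... forward pass, loss, backward would happen here ...
        state = cb.on_backward_end() or {}
        if not state.get('skip_step'):
            opt.step()                # normal path, skipped when the flag is set


opt = FakeOptimizer()
cb = XLAStepCallback(opt)
train(3, opt, cb)
```

After training, the normal `step` was never called while the callback's step ran on every batch, which is exactly the behavior you want for handing the step over to XLA.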


@ilovescience do you have an example of where you implemented this? Or have you found some publicly available code for this?

With regards to PyTorch on TPU? Or for fastai?

@ilovescience TPU for fastai.

This is a thread for fastai v1:

Development for fastai v2 is slow but ongoing.


I see that there is a device parameter in the Learner. Should we implement the TPU code as a PR for fastai2 instead of doing it in a callback?

In order to properly use all 8 cores of the TPU, you need to use the multiprocessing API which PyTorch XLA provides. It uses its own DataLoader for putting the data on the 8 cores of the TPU, and its own optimizer step to sync everything across the 8 cores during training. Unfortunately, it is not as simple as just setting the device parameter.
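To make that concrete, the 8-core setup looks roughly like the sketch below. Again this cannot run without a TPU and torch_xla, and `MyModel`, `dataset`, `loss_fn`, and `n_epochs` are placeholders; the key pieces are `xmp.spawn` (one process per core), `DistributedSampler` plus `ParallelLoader` (their data loading), and `xm.optimizer_step` (their synced step).

```python
# Sketch only: requires a TPU runtime and the torch_xla package.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank):
    device = xm.xla_device()
    model = MyModel().to(device)              # placeholder model
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,                              # placeholder dataset
        num_replicas=xm.xrt_world_size(),     # 8 on a full TPU
        rank=xm.get_ordinal())                # this process's core index
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=1e-3 * xm.xrt_world_size())
    for epoch in range(n_epochs):
        para_loader = pl.ParallelLoader(loader, [device])  # their DataLoader wrapper
        for xb, yb in para_loader.per_device_loader(device):
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            xm.optimizer_step(optimizer)      # syncs gradients across the 8 cores

if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8)
```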