Fastai v2 TPU support

fastai v2 TPU support development thread

This is a thread documenting my efforts adding TPU support to fastai v2. This GitHub repository will be updated with the necessary code.


Sometime in October, I discovered PyTorch XLA (even before the public announcement at PyTorch DevCon 2019). Since then, I worked on adding TPU support to fastai v1. See here for the original discussion. Originally, I decided to work on fastai v1 first and then move to fastai v2. I documented my fastai v1 efforts over here. While I successfully developed code for single-core and multi-core TPU training with fastai v1, it was much slower than expected and no more efficient than a multi-GPU setup. I got a lot of help from @TomB, @sgugger, and people from the PyTorch XLA team.

After a while, I got busy with classes and research. At that point I decided to switch to fastai v2, since it was becoming much more popular and since everybody was likely going to migrate over anyway. Thankfully, much of the code was transferable. However, I ran into some issues due to changes in the PyTorch XLA API and differences between fastai v1 and fastai v2. If I remember correctly, the next thing I had to do was create a new type of DataLoader (similar to DistributedDL) that is compatible with PyTorch XLA. The last time I was able to work on this was in April, since I was busy with classes, research, and more.
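
To illustrate what a per-core DataLoader has to take care of, here is a minimal sketch in plain Python (no fastai or PyTorch XLA involved). The name `shard_indices` and the wrap-around padding are assumptions made for illustration, not the DistributedDL implementation: each of the `world_size` processes takes every `world_size`-th sample index, and the index list is padded so that every process sees the same number of samples.

```python
def shard_indices(n_samples, rank, world_size):
    """Return the sample indices handled by process `rank` out of `world_size`.

    Hypothetical helper for illustration only.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    indices = list(range(n_samples))
    # Pad by wrapping around so the total is divisible by world_size;
    # otherwise some processes would get fewer samples than the others.
    pad = (-n_samples) % world_size
    indices += [indices[i % n_samples] for i in range(pad)]
    # Strided slice: rank 0 gets items 0, 8, 16, ...; rank 1 gets 1, 9, 17, ...
    return indices[rank::world_size]


# Ten samples split across eight processes (one per TPU core).
shards = [shard_indices(10, rank, 8) for rank in range(8)]
```

Every shard has the same length here, at the cost of a few samples being seen twice per epoch. Whether padding or dropping the remainder is the better choice depends on the task.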

I had some discussions with @TomB, which unfortunately we kept private since we weren’t sure about the community’s interest in such a discussion and since Jeremy and Sylvain were busy with other work. But now that the community has shown much more interest (ex: some discussion here and recent discussion in the Discord channel), I figured I would open the discussion up again and document my efforts, as well as get help from the community and maybe discuss the best route (ex: a complicated callback vs. a different training loop) for including TPU support in fastai v2.

I look forward to working with the community in adding TPU support to fastai v2, in order to make it one of the very few deep learning libraries with such capabilities!

NOTE: Later today or tomorrow I will add details about the kinds of tasks that are needed and the next steps.


Hi there, I’m collaborating with @butchland on fastai_xla_extensions, which originated from the invitation to the Global PyTorch Hackathon; we got to know each other through the SF Study Group. After a lot of trial and error on how to do the optimizer step, we found that an optimizer that does just the required step makes it work on TPU. If you like, we can meet (over Google Meet) next week, or even now we could use something like the Discord channel, if needed and allowed, if you would like to exchange notes on the different approaches.

The other things are mostly just things that need to be done. But as Jeremy once said, something like “it should be easy”.

Currently it works on a single TPU, but we have found some “problems”, or slow parts, when running on TPU. So if anybody reading knows about TPUs and can link some optimization documents, or explain how to track down specific performance issues on TPU, that would be great. Then later we can start on distributed training.

We haven’t yet asked the people making fastai2 for help, but hopefully, now that we have more attention and Jeremy is back, we can start asking a lot of questions.


Thank you for sharing. @butchland also mentioned the project in the Discord channel. I reviewed your code and it seems to be only for a single core, while I am currently working on multiple cores. Additionally, I would like to point out that the desired approach is to keep everything in a callback unless it’s truly necessary to monkey-patch something or change the existing classes/functionality. That is how multi-GPU training and mixed precision training are implemented, and hopefully most of TPU training will be that way as well. See this, my fastai v1 single-core implementation, for inspiration…

Let me know if you have any questions or ideas!


So I started looking into stuff yesterday and today and found out a couple things.

First, there was a change in PyTorch 1.6 that would break fastai2, as discussed here. I think I just need to somehow add a generator attribute to DataLoader and _FakeLoader but I have to investigate this further.

In the meantime, I just decided to use torch-xla v1.5 (as opposed to the nightly version, which seems to require PyTorch 1.6) which is missing some features but the core features are there. I had a minor bug where the DataLoader is not put on the TPU device. So that was pretty easy to fix. But I cannot find where in the code fastai2 puts the DataLoader on the GPU if present.

Next, I discovered a problem with pickling. See the below error:

2020-07-26 01:20:48.054366: E    2888 tensorflow/compiler/xla/xla_client/] XLA tensors do not have 
Exception in thread Thread-2:                                                                                      
Traceback (most recent call last):                                                                                 
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/", line 916, in _bootstrap_inner                                                                                                                   
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/", line 864, in run                                
    self._target(*self._args, **self._kwargs)                                                                      
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/distributed/", line $
65, in _worker                                                                                                     
    batch = xm.send_cpu_data_to_device(batch, device)                                                              
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/core/", line 518, in send_$
    return ToXlaTensorArena(convert_fn, select_fn).transform(data)                                                 
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/core/", line 291, in trans$
    return self._replace_tensors(inputs)                                                                           
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/core/", line 285, in _repl$
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/utils/", line 167, in for_each$
    return _for_each_instance_rewrite(value, select_fn, fn, rwmap)                                                 
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/utils/", line 153, in _for_eac$
    result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))                                             
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/utils/", line 153, in _for_eac$
    result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))                                             
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/utils/", line 155, in _for_eac$
    result = copy.copy(value)                                                                                      
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/", line 96, in copy                                     
    rv = reductor(4)                                                                                               
  File "/home/tmabraham/fastai2/fastai2/", line 252, in __reduce_ex__
    args = (type(self),, self.storage_offset(), tuple(self.size()), self.stride())
  File "/home/tmabraham/fastai2/fastai2/", line 272, in _f
    res = getattr(super(TensorBase, self), fn)(*args, **kwargs)
RuntimeError: torch_xla/csrc/tensor_impl.cpp:142 : XLA tensors do not have storage

For multiprocessing, the training loop function needs to be pickled, and TensorBase implements the appropriate pickling function __reduce_ex__. However, it passes the tensor’s storage as an argument, and XLA tensors do not have storage. Looking through the PyTorch Tensor code (TensorBase code is similar), you can see here that there’s separate pickling functionality for XLA tensors. It looks like this might be needed for proper TPU functionality?
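
To make the failure mode concrete, here is a toy sketch in plain Python. `ToyTensor` and `_rebuild_on_device` are hypothetical names and this is not the fastai2 or PyTorch code; it only illustrates the idea of a `__reduce_ex__` that avoids touching storage when the data lives on an XLA device and serializes a host copy instead. `copy.copy` is used because, as in the traceback above, it goes through the same `__reduce_ex__` hook as pickle.

```python
import copy


class ToyTensor:
    """Stand-in for a tensor; hypothetical class, not a fastai or PyTorch type."""

    def __init__(self, values, device="cpu"):
        self.values = list(values)
        self.device = device

    def storage(self):
        # Mimics the error in the traceback: device-resident data has no
        # host-side storage object to hand to the pickler.
        if self.device == "xla":
            raise RuntimeError("XLA tensors do not have storage")
        return self.values

    def cpu(self):
        # Copy the data to the host so it can be serialized.
        return ToyTensor(self.values, device="cpu")

    def __reduce_ex__(self, protocol):
        if self.device == "xla":
            # Device branch: serialize a host copy and remember the device.
            return (_rebuild_on_device, (self.cpu().values, self.device))
        # Host branch: safe to use storage directly.
        return (ToyTensor, (self.storage(), self.device))


def _rebuild_on_device(values, device):
    # Recreate the tensor and place it back on its original device.
    return ToyTensor(values, device=device)


t = ToyTensor([1.0, 2.0, 3.0], device="xla")
clone = copy.copy(t)  # calls t.__reduce_ex__(4) under the hood
```

Without the device branch, `copy.copy(t)` would hit `storage()` and raise the same kind of RuntimeError shown in the log.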

So I will likely have to make such changes to the fastai2 codebase, but I have not contributed to fastai2 before. I have contributed to fastai, following the git guide. Given the nbdev approach to fastai2, what are the major differences in library development? I assume I again clone the repository and make changes in a different branch, but in the notebooks?

I will try to make the necessary changes tomorrow…

Hi @ilovescience,

Yes we’re currently focused on single TPU core and trying to see where the bottlenecks are before implementing multiple cores.

Thanks for your suggestions to reduce the amount of monkey-patching – I’ve since updated it to now use callbacks as well.

As for monkey-patching, I had to do some of that (especially to get default_device to return a TPU) because I don’t think Sylvain or Jeremy initially considered an environment where, if a GPU is not available, the default would be anything other than a CPU (which is the case when a TPU is available).
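
The fallback order being described can be sketched in a few lines of plain Python. `pick_device_kind` and `xla_installed` are hypothetical helpers, not fastai's default_device; the assumption here is that a GPU is preferred when present, then a TPU through XLA, then the CPU.

```python
import importlib.util


def pick_device_kind(cuda_available, xla_available):
    """Return which kind of device to use by default.

    Hypothetical helper: prefers a GPU, then a TPU (via XLA), then the CPU.
    """
    if cuda_available:
        return "cuda"
    if xla_available:
        return "xla"
    return "cpu"


def xla_installed():
    # Only checks that the torch_xla package can be found; it does not
    # verify that a TPU is actually attached.
    return importlib.util.find_spec("torch_xla") is not None


# No GPU, but torch_xla present: the default should be the TPU, not the CPU.
kind = pick_device_kind(cuda_available=False, xla_available=True)
```

Keeping the decision in one small function like this is what would let a patch (or a PR) change the default in a single place.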

We’ll probably have to provide a PR in the fastai2 codebase to handle this.

In any case, our goal is to make it so that using a TPU on fastai would require minimal changes to your existing fastai notebooks or code.

Best regards and keep us updated on your work with multiple TPU cores!


cc: @tyoc213

This is a very exciting thread thanks for creating it!

I personally experienced large speed improvements (roughly 10x to 20x) using TPUs (8 cores) in several Kaggle competitions so far. So it’s definitely a must-have in our native fastai2 code. I believe it will be widely used by the community, even surpassing GPU usage, if it has an easy-to-use interface similar to to_distributed().

Would it make sense to tackle this problem systematically and perhaps divide the workload? I would be more than happy to help. I previously attempted to create a similar Learner class here for fastai v1 to work with multi-core TPUs. A callback is definitely much better, as noted in this thread.

I think it would be more important to get multicore working since most TPU devices offered publicly (Kaggle, Colab) are of that kind and it would allow us to use TPUs for the main reason - speed.

You seem to be far ahead in terms of exploration done so far, so please let me know if there are any areas that I can help with.


Indeed, TPUs are little monsters, but we have found some performance caveats in particular places. We have tried to keep track of what we are testing in notebooks. For example, we used a callback doing only the required optimizer step, and that allowed it to run with nothing more, but we left that behind because we didn’t understand much at the time (still don’t :slight_smile: ).

I meet with Butch during the week, at about 8 to 10 CT or so. Maybe we will share the link for anyone who wants to hang out (and find out what we are stuck on), and now that the Discord server has respawned we can join a channel and just talk.

You can fork the 2 repos; you will get the hang of it, since you already know how to do this on your own. And yes, we could use some help :slight_smile:.


@kcturgutlu @tyoc213 @butchland

If we want to do this, I think it would be best to discuss with @jeremy and potentially even the PyTorch XLA team. I would be happy to lead such discussions.

Also tagging @TomB, who was involved in the preliminary work and has demonstrated great expertise in the field; I would love to have him join our discussions if he’s available.


You’ll probably have to ask Jeremy about that. Personally, I don’t think this is something that’s strictly necessary, but it’s likely a decision that Jeremy will need to make about whether or not to include TPUs as a default. But I guess I didn’t really have a problem with that and instead was talking about the separate optimizer classes you created, which is not necessary.

Yep, my goal is the same! :slight_smile:

I see you are using my kernel (developed with the help of the PyTorch XLA team and @abhi1thakur) :wink: .
Which version of the kernel is the working one? The latest one is just a quick save, and while there is an older working one, I am not sure if that’s your final fastai version or if there’s more to it?

Exactly my thoughts! I have tried single-core TPU training with very little benefit. Hence, I have been focusing on multi-core TPU training.

Anyway, I will work on it more today and keep you guys updated in this thread!

Yeah, right, thanks a lot, and sorry for not crediting you guys here lol :smile: Cool, @abhi1thakur is also here! My belated thanks to you for the great TPU kernels in many recent competitions. Let me edit my post to add the working version. I really liked @abhi1thakur’s approaches to using either multiple cores to speed up a single experiment or running multiple single-core experiments in parallel, e.g. different hyperparameters or parallel cross-validation. Kaggle is a great place for learning TPUs IMO. I agree on moving forward with the help of @jeremy and the PyTorch XLA team, at least with their guidance if not full support.


TPU support is my first priority after getting fastai2 and course-v4 out the door. I haven’t looked at it at all yet. Goal would be to try to have it working without changing the training loop if possible - i.e. make it a callback.


That sounds great!

I guess the best approach @butchland @tyoc213 @kcturgutlu is that we could work separately, and when fastai2/course-v4 is released, we could meet up and discuss design decisions and systematically approach the remaining tasks? Since @butchland and @tyoc213 are already working on single-core TPU while I am working on multi-core TPU, we could keep working on this for the next couple of weeks till the fastai2 release.

I agree this is optimal, but the multiprocessing approach to TPU training requires the training loop to be spawned 8 times, once per core. There may be a way to change the fit function under the hood without requiring much change to the user’s code. But the current approach I am working on is as follows:

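import torch_xla.distributed.xla_multiprocessing as xmp

def train_loop(index):
    # Runs once in each spawned process; `index` is the process ordinal.
    train_df = ...
    food = DataBlock(...)
    dls = food.dataloaders()
    learn = cnn_learner(dls, model, metrics).to_tpu_distributed()  # adds the TPU callback

if __name__ == "__main__":
    # Assumed launch call; the original snippet was cut off at this point.
    xmp.spawn(train_loop, args=(), nprocs=8)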

The full code is here. How we can make this even easier for the user is an example of what we may need to discuss.


I realize another thing that needs to be discussed and fixed later down the line is the progress bars, which are repeated 8 times for each process. Anyway I’ll look into that once I progress further.
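
One common way to deal with repeated output in multi-process training is to report only from the first process. The sketch below is plain Python with a hypothetical helper `make_reporter`; it is not how fastai's progress bars work, and in a real run the rank check would more likely rely on something like PyTorch XLA's `xm.is_master_ordinal()` than on a hand-passed rank.

```python
def make_reporter(rank, sink=print):
    """Return a reporting function that only emits output on process 0.

    Hypothetical helper for illustration.
    """
    def report(message):
        if rank == 0:
            sink(message)
    return report


# Simulate eight processes all reporting the same event.
lines = []
for rank in range(8):
    report = make_reporter(rank, sink=lines.append)
    report("epoch 0 finished")
```

Only one line ends up in `lines`, instead of eight copies of it.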

I fixed some pickling problems. I monkey-patched TensorBase and Optimizer so that they are pickled correctly and can be accessed correctly by PyTorch XLA. Now I get an error 4 batches into training, for which I raised an issue in the PyTorch XLA repository (since many of the errors are not very clear :man_facepalming:).


I see @butchland and @tyoc213 discovered an issue with batch transforms on the TPU:

Keep us updated on the progress of this issue!


TPU wiki (at Jeremy’s suggestion):


I just wanted to document a proposed workaround for the problem of slow batch transforms while we wait for the PyTorch XLA team to find a solution for the affine grid sample calls (affine_grid_generator and grid_sample2d). If @tyoc213’s interpretation of the debug output is correct, these generate an aten call, which I believe means the data is transferred to the CPU, the op is executed on the CPU, and the result is transferred back to the TPU…

The idea is that if the dataloader is running on a tpu, the dataloader should execute all the batch transforms on the CPU and move it to the TPU afterwards…

This is much faster than the current process, since my performance profiling shows that running the batch transforms on TPUs is even slower than running them on the CPU, most probably because of the aten calls.
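
The ordering proposed here can be sketched with a small generator in plain Python. `batches_for_device`, `cpu_transforms` and `to_device` are hypothetical names, not part of fastai or PyTorch XLA; the point is only that every batch transform runs on the host and the batch crosses to the device once, at the very end.

```python
def batches_for_device(batches, cpu_transforms, to_device):
    """Yield batches with all batch transforms applied on the host first.

    `cpu_transforms` is a list of callables run in order on the CPU;
    `to_device` is the single, final transfer step. Hypothetical names.
    """
    for batch in batches:
        for transform in cpu_transforms:
            batch = transform(batch)
        # One transfer per batch, after every transform has run.
        yield to_device(batch)


# Toy stand-ins that record the order of operations.
log = []

def double(batch):
    log.append("transform")
    return [2 * x for x in batch]

def fake_transfer(batch):
    log.append("transfer")
    return ("tpu", batch)

out = list(batches_for_device([[1, 2], [3]], [double], fake_transfer))
```

The log alternates transform, transfer, transform, transfer, so no transform ever runs on data that has already been moved to the device.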

I’ve made a github enhancement issue to track this implementation (in case anyone is interested)


Just wanted to let people here know that we are still working on this. We can now develop on our own computers, which makes debugging some things easier.


Little update of something missing


I think this is the right link? :
