Need help speeding up batch transforms on TPU

Hi all,

We (@butchland and @tyoc213) are building an extension library to enable fastai to run on TPUs.

We encountered a weird performance issue: the batch transforms seem to run slower on a single TPU core than on a GPU (which we kind of expected), but we also found that they run even slower than on the CPU! :face_with_raised_eyebrow:

Here are some notebook results for a single transform (Flip):

GPU (fastest) - avg time: 0.021 secs
CPU (middle) - avg time: 1.227 secs
TPU (slowest) - avg time: 7.341 secs
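
Roughly, the timing loop looks like the sketch below (illustrative only, not our exact notebook code; the batch shape and repeat count are placeholders). The explicit sync calls matter because XLA executes lazily and CUDA asynchronously, so a bare `time.time()` diff can under- or over-count:

```python
import time
import torch
from fastai.vision.all import Flip, TensorImage

try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()            # single TPU core
except ImportError:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

xb = TensorImage(torch.rand(64, 3, 224, 224, device=device))  # dummy batch
flip = Flip(p=1.)                       # always apply, so every run does work

def sync():
    # Force pending work to finish before reading the clock
    if device.type == 'xla': xm.mark_step(); xm.wait_device_ops()
    elif device.type == 'cuda': torch.cuda.synchronize()

times = []
for _ in range(10):
    sync(); start = time.time()
    out = flip(xb, split_idx=0)         # split_idx=0 -> training-time transform
    sync(); times.append(time.time() - start)
print(f'avg time: {sum(times)/len(times):.3f} secs')
```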

For the underlying torch.nn.functional.grid_sample call, the times were:
GPU - avg time: 0.000 secs (too fast to register in a time.time() diff)
CPU - avg time: 0.821 secs
TPU - avg time: 4.247 secs
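
For reference, this is essentially the call being timed (a minimal sketch with an arbitrary batch shape, using an identity affine grid built via F.affine_grid):

```python
import torch
import torch.nn.functional as F

x = torch.rand(64, 3, 224, 224)  # dummy batch (move to the target device to reproduce)
# Identity affine transform, one 2x3 matrix per batch element
theta = torch.eye(2, 3).unsqueeze(0).expand(x.size(0), -1, -1)
grid = F.affine_grid(theta, x.shape, align_corners=False)
out = F.grid_sample(x, grid, mode='bilinear', align_corners=False)
```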

None of this even involves gradients; it's pure parallel tensor computation…

If you have suggestions for speeding it up (e.g. alternative algorithms for the batch transforms used in augmentations), we'd appreciate it!
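
For instance, for a pure flip, would something as simple as reversing the width index (a hypothetical helper, sketched below, untested on TPU) be a sane replacement for the grid_sample path? The idea being that a straight index reversal avoids grid_sample's gather-and-interpolate work entirely.

```python
import torch

def flip_batch(x: torch.Tensor) -> torch.Tensor:
    "Horizontal flip by reversing the width dimension; no affine_grid/grid_sample."
    return torch.flip(x, dims=[-1])
```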

The notebooks (runnable on Colab) are here, if you want to validate the results:
