How to create a callback using torch.multiprocessing (TPU)

I am creating a callback for distributed TPU training based on the PyTorch XLA library. It requires the training loop to be passed into a multiprocessing spawn function that launches a training loop on each of the TPU cores. I have worked with callbacks before, but I am unsure whether a callback can change the training loop itself. Is this possible?

@sgugger any input?


I have not experimented with PyTorch XLA yet, so I don’t know the best way to handle this. The first step would be to look at a plain PyTorch training loop that does this kind of training, so we can assess how best to handle it. If it requires changes to the actual training loop, we will only support this feature in v2.

The sample code for usage is over here:

My current approach is:

  1. A callback to adapt the dataloader and optimizer for XLA
  2. A new function to_tpu_distributed() that redefines the fit function for TPU usage.

I am still working out the bugs, but I think I could get it to work soon.

I already have a working single-TPU-core callback, but I didn’t see any speedup on Colab compared to Colab’s K80 GPU.

I haven’t actually played with fastai v2 yet, but I will next week (mainly to use the medical imaging module for the RSNA comp) and will implement something for TPU usage then.

Oh it should fit in a Callback then. The steps to be done are:

  • replacing the dataloaders with ParallelLoader
  • switching everything to the right device
  • cancelling the step and doing the xm step (xm.optimizer_step) instead
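A rough illustration of how the last two steps map onto callback hooks. As I understand fastai v1, a callback can return a dict from a hook to update the CallbackHandler’s state: new last_input/last_target returned from on_batch_begin replace the batch, and skip_step=True returned from on_backward_end skips the default opt.step(). The mock below is framework-free and only mimics that contract; MockCallbackHandler, TPUStepCallback and train_one_batch are hypothetical names (not fastai classes), and the lambdas stand in for moving tensors to xm.xla_device() and for xm.optimizer_step(opt).

```python
from typing import Any, Callable, Dict, List


class MockCallbackHandler:
    """Hypothetical stand-in for fastai v1's CallbackHandler (not the real class)."""

    def __init__(self, callbacks: List[Any]):
        self.callbacks = callbacks
        self.state: Dict[str, Any] = {}

    def _call(self, hook: str) -> None:
        # each callback sees the current state as kwargs and may return updates
        for cb in self.callbacks:
            updates = getattr(cb, hook)(**self.state) or {}
            self.state.update(updates)

    def on_batch_begin(self, xb: Any, yb: Any):
        self.state.update(last_input=xb, last_target=yb, skip_step=False)
        self._call('on_batch_begin')
        return self.state['last_input'], self.state['last_target']

    def on_backward_end(self) -> bool:
        self._call('on_backward_end')
        return self.state['skip_step']


class TPUStepCallback:
    """Hypothetical callback covering 'right device' and 'xm step instead'."""

    def __init__(self, to_device: Callable[[Any], Any], custom_step: Callable[[], None]):
        self.to_device = to_device      # stands in for t.to(xm.xla_device())
        self.custom_step = custom_step  # stands in for xm.optimizer_step(opt)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        # replace the batch with its "device" copies
        return {'last_input': self.to_device(last_input),
                'last_target': self.to_device(last_target)}

    def on_backward_end(self, **kwargs):
        # run our own step, then tell the loop to skip the default opt.step()
        self.custom_step()
        return {'skip_step': True}


def train_one_batch(handler, xb, yb, default_step, log):
    """Toy loop: batch hook, forward/backward placeholder, then maybe the default step."""
    xb, yb = handler.on_batch_begin(xb, yb)
    log.append(('forward/backward', xb, yb))
    if not handler.on_backward_end():
        default_step()
```

With the real library the callback would do the actual device moves and the XLA step; the mock only shows the control flow, i.e. why no change to the training loop itself is needed for these two steps.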

Then if you define everything in a function named main_train, your script should have

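import torch_xla.distributed.xla_multiprocessing as xmp
if __name__ == '__main__':
  # calls main_train(index) once per TPU core; index identifies the process
  xmp.spawn(main_train, args=())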

at the end.

Hmmm, I guess this would work, but I was hoping for something self-contained, e.g. just calling learn.to_tpu_distributed() and having it subsequently train on the TPU.

I actually already have a callback that does the steps:

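import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
class TPUDistributed(LearnerCallback):
  def __init__(self, learn:Learner):
    super().__init__(learn)  # LearnerCallback needs this to set self.learn
    self.device = xm.xla_device()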
  def on_train_begin(self, **kwargs:Any)->None:
    self.learn.model = = pl.ParallelLoader(, [self.device]).per_device_loader(self.device)
  def on_step_end(self, **kwargs:Any)->None:

@sgugger ok I just found out a potential problem…

The API guide is slightly wrong.

This is accurate code for the multiprocessing situation:

Iterating through the TPU DataLoader actually returns both the index and a tuple of xb and yb.
Like this:
for x, (data, target) in loader:

So the training loop actually needs to be changed. I could look into changing the method definitions of the TPU parallel dataloader, but I hope that doesn’t break anything in XLA.


I don’t think you’d need to call spawn from inside an if __name__ == '__main__' block, but I think you would need to do it fairly early. It looks like xmp.spawn relies on initialising torch.multiprocessing before anything else does in order to start XLA-enabled multiprocessing. So I think if torch.multiprocessing had been initialised beforehand it wouldn’t work (or would end up in a weird state). I think that initialisation happens automatically when you create a dataloader, which will happen if you create a databunch. But it seems like you could have a setup_xla_distrib() along the lines of the current setup_distrib().

Couldn’t you just wrap it in something that throws away the index? They don’t actually seem to use it except for logging. You might then want to pass it on to some things for logging, but that should be possible: you could probably just pass the CallbackHandler’s state_dict to the wrapper and update it at will.
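A minimal sketch of such a wrapper, assuming the per-device loader yields (index, batch) pairs as described above. drop_index and the 'xla_step' key are made-up names; the optional dict stands in for the CallbackHandler’s state_dict.

```python
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple


def drop_index(loader: Iterable[Tuple[int, Any]],
               state: Optional[Dict[str, Any]] = None) -> Iterator[Any]:
    """Yield only the batch from (index, batch) pairs.

    If a state dict is given, the most recent index is stored under the
    hypothetical key 'xla_step' so logging code can still read it.
    """
    for idx, batch in loader:
        if state is not None:
            state['xla_step'] = idx
        yield batch
```

A training loop could then iterate `for xb, yb in drop_index(loader, state):` and keep its usual unpacking.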

I’d think it’s probably worth hacking together some quick tests of performance with fastai and with standard XLA code before spending too much time making it nice. You might also want to verify against a proper GCP TPU, as the docs note some pretty big performance issues on Colab. With a tight test script it shouldn’t cost too much (2.25c per min for 8 TPU cores).
It may be that too much of fastai assumes tensors are on the GPU for there to be much speed advantage. It looks like there’s an even heavier penalty for moving XLA tensors back to the CPU than with a GPU (TPUs being over USB or the network). So it may be that fastai v1 just won’t provide good performance without extensive changes here. Things may be a little better with v2, or perhaps changes could be made. It’s likely such changes would also help performance on GPUs, as I suspect this is one of the key speed limits in fastai.

But you still need to pass the fit function to xmp.spawn.

I am not understanding. It does not seem that setup_distrib() spawns processes like torch.multiprocessing.spawn does. So how can we have a similar function?

You might be right, I might be able to use on_batch_begin

Again, I do have a callback for using a single TPU core, and it seemed like it worked, but it was slower than or equivalent to the K80 when I compared both on Colab. It could be due to some of the issues you mentioned, but I hope distributed training might help alleviate those problems.

Ah yes, setup_distrib just creates the process group handler and later uses that handler to launch actual processes. Whereas xmp.spawn does both in one step. So not sure you can (or would really want to) avoid following the approach they use where you’d run xmp.spawn with a function that does all your fastai setup.
Note that they actually create one dataloader per TPU core, each in a separate process, with each dataloader having multiple worker processes. So the equivalent would be creating a separate DataBunch and Learner per TPU core, which you can’t do with a learn.to_distributed()-style API.
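To make the "one dataloader per core" point concrete: each spawned process reads only its own shard of the dataset. As far as I remember, the torch_xla multiprocessing examples do this by giving each per-core DataLoader a torch.utils.data.distributed.DistributedSampler with num_replicas=xm.xrt_world_size() and rank=xm.get_ordinal(). The plain-Python sketch below reproduces what I believe is that sampler’s non-shuffled split; shard_indices is a hypothetical helper, not a library function.

```python
import math
from typing import List


def shard_indices(n_items: int, world_size: int, rank: int) -> List[int]:
    """Dataset indices read by process `rank` out of `world_size` processes.

    Like the non-shuffled path of DistributedSampler: pad the index list to a
    multiple of world_size by wrapping around to the start, then take every
    world_size-th index beginning at rank. Every shard has the same length.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    per_replica = math.ceil(n_items / world_size)
    total = per_replica * world_size
    indices = list(range(n_items))
    # wrap around so that every process gets exactly per_replica items
    indices += [i % n_items for i in range(total - n_items)]
    return indices[rank:total:world_size]
```

For example, with 10 items on 8 cores, shard_indices(10, 8, 0) gives [0, 8] and shard_indices(10, 8, 7) gives [7, 5]: a few items are read twice so every core runs the same number of steps.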

To be clear, I meant wrap the torch_xla.distributed.ParallelLoader. The current distributed code does this in fastai.distributed.DistributedCallback from its on_train_begin, so you could borrow its code.

Sorry this isn’t completely clear to me…

Where is this done in distributed?

Also, it seems using on_batch_begin works. Now I am getting an error that I am not sure how to fix, but I will probably work on it tomorrow.

It seemed like it was training, but the errors were mainly leaked semaphores and this:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 7 7, but got 0-dimensional input of size [] instead

It doesn’t actually use its own DataLoader wrapper, but here’s the code where it creates a new dataloader and replaces the current one.

Ah ok I see what you mean. But I think using on_batch_begin is easier when I already have a callback.

Can’t tell without seeing your code but given the semaphore and size mismatches it sounds like you might be sharing things between processes you shouldn’t. Or, for the semaphore thing, failing to initialise TPU stuff early enough.

So I got a fastai DataBunch working on TPU with the training loop from an example in this notebook.
The main issue was a weird bug caused by fastai’s monkeypatching of Tensor.__array__, which took a fair while to track down. It only triggers because the example calls np.array on the results, so you wouldn’t generally encounter it. Still, I might look at fixing it in fastai (it took me long enough to track it down).
Otherwise, the only thing I changed from the standard setup was disabling memory pinning in the dataloader (a param to .databunch that gets passed through). I was getting some crashes before that, which haven’t reappeared, but I’ve found it a little flaky in general, so that could be unrelated. Still, this option makes no sense for XLA and will just slow things down.

Will now look to add in other fastai bits. Decided to build it up from the bottom this way rather than start from a callback (not to dissuade you, we could likely happily meet in the middle).


The monkey-patching was there for earlier versions of PyTorch when the __array__ wasn’t there. Not sure it’s necessary anymore.

Yeah, from at least 1.0, PyTorch supports __array__ (I didn’t check earlier versions). The only difference is that fastai moves a CUDA tensor to the CPU, whereas the PyTorch version raises an error, though one that clearly says to move it to the CPU. So I’m not sure whether you’d want to maintain that. This alternative seems to work in both, maintaining the full current behaviour (and with no particular XLA stuff):

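import torch

# keep a reference to PyTorch's own implementation
orig_Tensor__array__ = torch.Tensor.__array__
def Tensor__array__(self, dtype=None):
    # fastai's behaviour: copy CUDA tensors to the CPU instead of raising
    if self.device.type == 'cuda': self = self.cpu()
    return orig_Tensor__array__(self, dtype)
torch.Tensor.__array__ = Tensor__array__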

It would also pick up any enhanced logic in PyTorch (currently it’s the same as fastai bar the .cpu()). That will fail prior to PyTorch 1.0, but fastai’s requirement seems to be >=1.0.

Thanks for doing this. However, it is mentioned in the API guide that “Running on multiple XLA devices using processes (see above) is preferred to using threads.” Therefore, I was planning to get the multiprocessing version working first before moving onto the threading version. The only problem with the MP version is that it has to be run as a script. So it’s still helpful to have a threading version to run in colab notebooks.

BTW, after I mentioned the error in the API guide, they said they might remove the enumerate-like behavior: