I am creating a callback for distributed TPU training based on the PyTorch XLA library. It requires the training loop to be passed into a multiprocessing spawn function so that training loops run on each of the TPU cores. I have worked with callbacks before, but I am unsure whether it is possible to change the training loop with a callback. Is this possible?
I have not experimented with PyTorch XLA yet, so I don’t know the best way to handle this. The first step would be to see a PyTorch training loop that handles TPU training so we can assess the best approach. If it requires changes to the actual training loop, we will only support this feature in v2.
I don’t think you’d need to call the spawn from an `if __name__ == '__main__'` guard, but I think you would need to do it fairly early. It looks like xmp.spawn relies on initialising torch.multiprocessing before anything else does in order to start XLA-enabled multiprocessing. So if torch.multiprocessing had already been initialised beforehand, it wouldn’t work (or would end up in a weird state). That will happen automatically when you create a dataloader, which happens when you create a databunch. But it seems like you could have a setup_xla_distrib() along the lines of the current setup_distrib().
Couldn’t you just wrap it in something that throws away the index? They don’t actually seem to use it except for logging. You might then want to pass it on to some things for logging, but that should be possible. You could probably just pass the CallbackHandler’s state_dict to the wrapper and update it at will.
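A minimal sketch of that index-dropping wrapper (pure Python; the only thing taken from the XLA docs is that xmp.spawn calls its target as fn(index, *args) — the function and parameter names here are made up):

```python
def drop_index(fn, state_dict=None):
    """Adapt a function that takes no core index for xmp.spawn,
    which calls its target as fn(index, *args). The index is kept
    only so it can be forwarded for logging (e.g. stashed in a
    CallbackHandler state_dict); otherwise it is thrown away."""
    def spawn_target(index, *args):
        if state_dict is not None:
            state_dict['rank'] = index  # hypothetical logging hook
        return fn(*args)
    return spawn_target
```

You would then call something like `xmp.spawn(drop_index(train_fn, cb_state), args=(...), nprocs=8)`, where `train_fn` is the existing fit logic with no index parameter.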
I’d think it’s probably worth hacking together some quick performance tests with fastai and with standard XLA code before spending too much time making it nice. You might also want to verify against a proper GCP TPU, as the docs note there are some pretty big performance issues on Colab. With a tight test script it shouldn’t cost too much (2.25c per minute for 8 TPU cores).
It may be that too much of fastai assumes tensors are on the GPU for there to be much speed advantage. It looks like there’s an even heavier penalty for moving XLA tensors back to the CPU than there is with a GPU (TPUs being attached over USB or the network). So it may be that fastai v1 just won’t provide good performance without extensive changes here. Things may be a little better with v2, or perhaps changes could be made. Such changes would likely also help performance on GPUs, as I suspect this is one of the key speed limits in fastai.
But you still need to pass the fit function to xmp.spawn.
I am not understanding. It does not seem that setup_distrib() spawns processes the way torch.multiprocessing.spawn does, so how can we have a similar function?
You might be right, I might be able to use on_batch_begin
Again, I do have a callback for using a single TPU core, and it seemed like it worked, but it was slower than or equivalent to a K80 when I compared both on Colab. It could be due to some of the issues you mentioned, but I hope that distributed training might help alleviate those problems.
Ah yes, setup_distrib just creates the process group handler and later uses that handler to launch the actual processes, whereas xmp.spawn does both in one step. So I’m not sure you can (or would really want to) avoid following their approach, where you run xmp.spawn with a function that does all your fastai setup.
Note that they are actually creating one dataloader per TPU core, each in a separate process, with each dataloader having multiple worker processes. So the equivalent would be creating a separate DataBunch and Learner per TPU core, which you can’t do with a learn.to_distributed() style API.
To be clear, I meant wrap the torch_xla.distributed.ParallelLoader. The current distributed stuff does this in fastai.distributed.DistributedCallback from its on_train_begin, so you could borrow its code.
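To illustrate the pattern (a hypothetical sketch, not the real fastai callback API; XLALoaderCallback and wrap_fn are invented names): swap the learner's train_dl for the wrapped loader at train begin and restore it at train end, the same swap-in/swap-out shape the existing distributed callback uses.

```python
class XLALoaderCallback:
    """Hypothetical sketch of a callback that wraps the training
    dataloader (e.g. in a torch_xla ParallelLoader) for the duration
    of training, mirroring how the existing distributed callback
    swaps in a DistributedSampler-backed loader on_train_begin."""
    def __init__(self, learn, wrap_fn):
        self.learn, self.wrap_fn = learn, wrap_fn
        self._orig_dl = None

    def on_train_begin(self, **kwargs):
        # Stash the original loader and substitute the wrapped one.
        self._orig_dl = self.learn.data.train_dl
        self.learn.data.train_dl = self.wrap_fn(self._orig_dl)

    def on_train_end(self, **kwargs):
        # Restore the original loader so the learner is left unchanged.
        if self._orig_dl is not None:
            self.learn.data.train_dl = self._orig_dl
            self._orig_dl = None
```

In real fastai v1 code this would subclass LearnerCallback and be registered on the learner; the sketch only shows the swap mechanics.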
It seemed like it was training, but the errors were mainly leaked semaphores and this: RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 7 7, but got 0-dimensional input of size  instead
Can’t tell without seeing your code, but given the semaphore and size mismatches, it sounds like you might be sharing things between processes that you shouldn’t, or, for the semaphore issue, failing to initialise the TPU stuff early enough.
So I got a fastai DataBunch working on TPU with the training loop from an example in this notebook.
The main issue was a weird bug caused by fastai’s monkeypatching of Tensor.__array__, which took a fair while to track down. Though it only surfaces because the example calls np.array on the results, so you wouldn’t generally encounter it. Still, I might look at fixing it in fastai (it took me long enough to track down).
Otherwise the only thing I changed from standard was disabling memory pinning in the dataloader (a param to .databunch that’s passed through). I was getting some crashes before that which haven’t reappeared, but I’ve found it a little flaky in general, so they could be unrelated. Still, this option makes no sense for XLA and will just slow things down.
Will now look to add in the other fastai bits. I decided to build it up from the bottom this way rather than start from a callback (not to dissuade you, we could likely happily meet in the middle).
Yeah, from at least 1.0 PyTorch supports __array__ (I hadn’t checked before). The only difference is that fastai moves a CUDA tensor to the CPU, while the PyTorch version errors, though with a clear message saying to move it to the CPU. So I’m not sure you’d want to maintain that. This seems to work in both as an alternative, maintaining the full current behaviour (and nothing XLA-specific):
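Something along these lines (a sketch, assuming the intent is just to detach and move to CPU before converting; the dtype parameter follows NumPy's __array__ protocol and the helper name is my own):

```python
import torch

def _tensor_array(self, dtype=None):
    # Move off the accelerator (CUDA or XLA alike) before converting,
    # keeping the convenience of np.array(t) just working.
    arr = self.detach().cpu().numpy()
    return arr if dtype is None else arr.astype(dtype, copy=False)

# Monkeypatch in place of both fastai's and PyTorch's versions.
torch.Tensor.__array__ = _tensor_array
```

Since it only ever goes through .cpu(), it behaves identically for CPU, CUDA, and XLA tensors rather than special-casing any device.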
Thanks for doing this. However, it is mentioned in the API guide that “Running on multiple XLA devices using processes (see above) is preferred to using threads.” Therefore, I was planning to get the multiprocessing version working first before moving on to the threading version. The only problem with the MP version is that it has to be run as a script, so it’s still helpful to have a threading version to run in Colab notebooks.
BTW, after I mentioned the error in the API guide, they said they might remove the enumerate-like behavior: