How to create a callback using torch.multiprocessing (TPU)

Pity they didn’t want to change it, though I did see it was just re-opened by someone else, so they might reconsider. I did a quick check and unfortunately Lightning (another PyTorch framework) didn’t seem to use len(loader), which would’ve bolstered the case.

The xla:0 thing sounds like it might be a configuration issue in the spawning (or similar). The device should be determined internally based on which other devices are participating in the distributed cluster. I believe this is based on environment variables, so they may not be getting passed properly to the child processes for some reason. You might want to print out the XLA environment variables in the parent and child processes. The names are defined in torch_xla.core.xla_env_vars, in particular:

WORKERS = 'XRT_WORKERS'
ORDINAL = 'XRT_SHARD_ORDINAL'
WORLD_SIZE = 'XRT_SHARD_WORLD_SIZE'
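Note that printing the constants themselves only shows the variable names; the values have to be looked up in os.environ, both in the parent and in each spawned child. A minimal sketch follows. report_xla_env is a hypothetical helper, and the three names are hardcoded from the list above rather than imported from torch_xla, so it runs without a TPU:

```python
import os

# Names copied from the list above (normally found in torch_xla.core.xla_env_vars);
# hardcoded here so the sketch runs without torch_xla installed.
WORKERS = 'XRT_WORKERS'
ORDINAL = 'XRT_SHARD_ORDINAL'
WORLD_SIZE = 'XRT_SHARD_WORLD_SIZE'


def report_xla_env(tag, environ=None):
    """Print and return the *values* of the XRT variables, not their names.

    `tag` is just a label such as 'parent' or 'child 3'. Unset variables
    show up as None, which distinguishes "not set" from "set".
    """
    environ = os.environ if environ is None else environ
    values = {name: environ.get(name) for name in (WORKERS, ORDINAL, WORLD_SIZE)}
    print(tag, values)
    return values


if __name__ == "__main__":
    # print(WORLD_SIZE) would only show 'XRT_SHARD_WORLD_SIZE';
    # this shows what the current process actually sees.
    report_xla_env('parent')
```

In a script like the one below, calling report_xla_env('parent') before xmp.spawn and report_xla_env(f'child {index}') at the top of train_loop would let you compare the two.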

No values were printed from this, just the same strings you listed above. Does that mean nothing was set yet? I don’t think this is the issue though. I will share my code if you have any insights.

@TomB Here is the code. I am done for the night, but if you find something out let me know and I will look into it tomorrow. Thanks in advance!

import torch_xla
import torch_xla.distributed.data_parallel as dp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch

from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.basic_train import *

# Monkey-patch a length onto PerDeviceLoader so fastai can call len() on it
def len_parallelloader(self):
  print('hello length from',self._device)
  return len(self._loader._loader)
pl.PerDeviceLoader.__len__ = len_parallelloader
  

class TPUDistributed(LearnerCallback):
  def __init__(self, learn:Learner):
    super().__init__(learn)
    self.device = xm.xla_device()
    print('callback ',self.device)

  def _change_dl(self,dl, shuffle):
    old_dl = dl
    sampler = torch.utils.data.distributed.DistributedSampler(
      dl.dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=shuffle
    )
    new_dl = dl.new(shuffle=False, sampler=sampler)
    return old_dl,new_dl,sampler


  def on_train_begin(self, **kwargs:Any)->None:
    self.learn.model = self.learn.model.to(self.device)

    shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True
    self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)
    if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:
      self.old_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, False)  # validation is not shuffled
    
    self.learn.data.train_dl = pl.ParallelLoader(self.data.train_dl, [self.device]).per_device_loader(self.device)
    self.learn.data.train_dl.dataset = None #self.old_train_dl.dataset
    self.learn.data.valid_dl = pl.ParallelLoader(self.data.valid_dl, [self.device]).per_device_loader(self.device)
    self.learn.data.valid_dl.dataset = None #self.old_train_dl.dataset

  def on_batch_begin(self, last_input, last_target, train, **kwargs):
    return {'last_input': last_target[0], 'last_target': last_target[1]}
  def on_step_end(self, **kwargs:Any)->None:
    xm.optimizer_step(self.learn.opt.opt)


def _to_tpu_distributed(learn:Learner) -> Learner:
  learn.callback_fns.append(TPUDistributed)
  return learn
  

Learner.to_tpu_distributed = _to_tpu_distributed
  

path = untar_data(URLs.MNIST_SAMPLE)
def train_loop(index):
  data = ImageDataBunch.from_folder(path)
  learn = cnn_learner(data, models.resnet50, metrics=accuracy).to_tpu_distributed()
  print('hello')
  learn.fit(1)

if __name__ == "__main__":
  xmp.spawn(train_loop,args=())

Note that when I print xm.xrt_world_size() it prints 8, which makes sense.

Ah, if xrt_world_size is correct then it’s fine; it’s just reading xenv.WORLD_SIZE, so it should match os.environ['XRT_SHARD_WORLD_SIZE'] (or os.environ[xenv.WORLD_SIZE]).

Shouldn’t:

def on_step_end(self, **kwargs:Any)->None:
    xm.optimizer_step(self.learn.opt.opt)

be:

def on_backward_end(self, **kwargs):
    xm.optimizer_step(self.learn.opt.opt)
    return {'skip_step': True}

to avoid double stepping.
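To see why the return value matters: in fastai v1 the training loop calls opt.step() after on_backward_end unless a callback returns {'skip_step': True}, and on_step_end only runs after that step. The toy loop below is a simplified, hypothetical model of that ordering (not fastai’s actual code) that counts optimizer steps per batch; opt.step() stands in for xm.optimizer_step.

```python
class CountingOptimizer:
    """Stand-in optimizer that only counts how often step() is called."""
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        pass


def run_batch(opt, on_backward_end, on_step_end):
    """Simplified model of the hook ordering in fastai v1's loss_batch."""
    # loss.backward() would happen here
    state = on_backward_end(opt) or {}
    if not state.get('skip_step', False):
        opt.step()
    on_step_end(opt)
    opt.zero_grad()


def noop(opt):
    return {}


def step_in_on_step_end(opt):
    # Original callback: steps again after the loop has already stepped.
    opt.step()  # stands in for xm.optimizer_step(...)
    return {}


def step_in_on_backward_end(opt):
    # Suggested fix: step here and tell the loop to skip its own step.
    opt.step()  # stands in for xm.optimizer_step(...)
    return {'skip_step': True}


if __name__ == "__main__":
    double = CountingOptimizer()
    run_batch(double, noop, step_in_on_step_end)
    single = CountingOptimizer()
    run_batch(single, step_in_on_backward_end, noop)
    print(double.steps, single.steps)  # 2 1
```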

Also, shouldn’t it just be self.learn.opt, so that you step the wrapper (which handles the layer/parameter groups) rather than the inner optimiser? It could even be just self.opt, since LearnerCallback proxies attributes to self.learn. There might be complications there that need to be addressed, though.
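The proxying works through __getattr__, which Python only calls when normal attribute lookup fails. A minimal sketch with hypothetical ToyLearner/ToyCallback classes (not fastai’s own) showing why self.opt inside a callback resolves to self.learn.opt:

```python
class ToyLearner:
    def __init__(self, opt):
        self.opt = opt


class ToyCallback:
    """Forwards unknown attributes to the learner, similar in spirit to
    the delegation described for fastai v1's LearnerCallback."""
    def __init__(self, learn):
        self.learn = learn

    def __getattr__(self, name):
        # Only reached when `name` is not found on the callback itself.
        return getattr(self.learn, name)


if __name__ == "__main__":
    wrapper = object()  # stands in for the OptimWrapper
    cb = ToyCallback(ToyLearner(wrapper))
    print(cb.opt is cb.learn.opt)  # True
```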

1 Like

Sorry about this error. The on_step_end error was a dumb error I made which you pointed out earlier, but I lost my fix to it because I had two versions of the TPU callback code. I will go and check that I have the correct version with the correct fixes.

Regarding the stepping of the optimizer, I did this because I was worried that xm.optimizer_step would only work with a torch.optim.Optimizer and not with OptimWrapper. However, both have a step method, so if it only relies on that, it should be fine.

Thanks for pointing out the skip_step. I completely forgot about that.


@TomB Do you know what this means? I thought the data is evenly divided between the devices?:

@TomB a couple updates.

So I did something I should have done a long time ago. I ran the MNIST multiprocessing example.
I noticed two things.

  1. There is the same semaphore warning being printed out.
  2. Their code also seems to use only xla:0: that is what gets printed alongside the loss, and when I printed which device is used in the training loop I also only got xla:0.

I think I know the fix to using all XLA devices. However, I will raise an issue to start out.

Here are the issues:

Going through the code I think it will be OK.
From what I could see looking through the torch_xla code, the only things it expects are opt.step(), opt.zero_grad() and opt.param_groups. The first two should be fine; that’s how the fastai OptimWrapper expects to be used. The last one I think is OK too, but I’m less clear there.
The param_groups are used by _fetch_gradients. They’re actually accessed slightly oddly, via optimizer.__getstate__()['param_groups'] (I had to search to remind myself that __getstate__ is for pickling, so I guess this allows different internal storage as long as a common import/export format is respected). In fastai v1, OptimWrapper overrides __setstate__ but not __getstate__ (there’s a getstate() that’s used instead), so __getattr__ delegates this to the wrapped optimiser’s __getstate__ and it should be fine (v2 might be different here). The purpose of this is to get a list of gradients on which to do a cross-replica sum before they are applied with step(). So it just needs to collect all the gradients, and no gradient should be returned twice, or else it might be summed twice.
The one possible issue I saw was if OptimWrapper.layer_groups and the inner Optimizer.param_groups got out of sync. But this should only be an issue if new parameters are added that end up in layer_groups but not in param_groups. Looking at some (but not all) of the fastai code here, I think it mostly just rearranges groups and removes layers that shouldn’t update. That would be fine, as all updated params would still be in param_groups and the updating all happens in step (you might just needlessly accumulate non-learnable params).
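A rough, pure-Python sketch of that idea (not torch_xla’s actual _fetch_gradients): walk param_groups in the list-of-dicts format that optimizer.__getstate__()['param_groups'] exposes, collect every gradient that exists, and separately check that no parameter appears twice, since a duplicate would be summed twice in the cross-replica reduction. ToyParam is a hypothetical stand-in for a tensor parameter.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(eq=False)  # identity semantics, like tensors used as parameters
class ToyParam:
    grad: Optional[Any] = None


def fetch_gradients(param_groups):
    """Collect the gradient of every parameter that has one."""
    grads = []
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                grads.append(p.grad)
    return grads


def duplicated_params(param_groups):
    """Parameters seen more than once; their gradients would be reduced twice."""
    seen, dups = set(), []
    for group in param_groups:
        for p in group['params']:
            if id(p) in seen:
                dups.append(p)
            seen.add(id(p))
    return dups


if __name__ == "__main__":
    a, b, frozen = ToyParam(1.0), ToyParam(2.0), ToyParam(None)
    groups = [{'params': [a, frozen]}, {'params': [b]}]
    print(fetch_gradients(groups))         # [1.0, 2.0]
    print(len(duplicated_params(groups)))  # 0
```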
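The out-of-sync case can be checked directly: every trainable parameter reachable from the layer groups should also appear in the optimizer’s param_groups. A sketch under simplifying assumptions: plain lists stand in for the layer groups (in fastai v1 these are modules), and ToyParam and missing_from_optimizer are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # identity semantics, like tensors used as parameters
class ToyParam:
    requires_grad: bool = True


def missing_from_optimizer(layer_groups, param_groups):
    """Trainable params present in layer_groups but absent from param_groups."""
    in_opt = {id(p) for g in param_groups for p in g['params']}
    return [p for layer in layer_groups for p in layer
            if p.requires_grad and id(p) not in in_opt]


if __name__ == "__main__":
    w1, w2, frozen = ToyParam(), ToyParam(), ToyParam(requires_grad=False)
    layer_groups = [[w1, frozen], [w2]]
    param_groups = [{'params': [w1]}]  # w2 was added later and never registered
    print(len(missing_from_optimizer(layer_groups, param_groups)))  # 1
```

Frozen parameters are deliberately ignored, matching the point above that non-learnable params are harmless here.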
The one possible exception was the MixedPrecision stuff, but I think there’s special handling for this in torch_xla so the standard to_fp16 should not be used (in which case the XLA callback should probably error in on_train_begin if it finds that stuff on the learner).
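A hedged sketch of such a guard. It assumes (check against your fastai version) that to_fp16 registers a callback whose class is named MixedPrecision, either as an instance in learn.callbacks or as a class or functools.partial in learn.callback_fns. find_callbacks_named and check_no_mixed_precision are hypothetical helpers, meant to be called from the XLA callback’s on_train_begin with those two lists.

```python
from functools import partial


def _callback_name(cb):
    """Name of a callback given as a class, a partial of a class, or an instance."""
    target = cb.func if isinstance(cb, partial) else cb
    return target.__name__ if isinstance(target, type) else type(target).__name__


def find_callbacks_named(name, callbacks=(), callback_fns=()):
    """All registered callbacks (instances or factories) whose class has `name`."""
    return [cb for cb in list(callbacks) + list(callback_fns)
            if _callback_name(cb) == name]


def check_no_mixed_precision(callbacks, callback_fns):
    # Assumption: the fp16 callback's class is called 'MixedPrecision'.
    if find_callbacks_named('MixedPrecision', callbacks, callback_fns):
        raise RuntimeError("to_fp16 should not be combined with the TPU callback; "
                           "remove it before training on TPU.")
```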

So I don’t think fastai’s layer groups should interfere here. But perhaps @sgugger will quickly spot something I’ve missed.

Yeah, was a bit hard to follow the replies (or the code), but from some digging I think that the basic logic is:

batches = []
for batch in base_loader:
  batches.append(batch)
  if len(batches) == num_devices:
    process_batches(batches)
    batches.clear()

I also might be missing some details in the methods it’s calling that introduce extra subtleties (and this is only with the default fixed_batch_size=False).
So the actual length will depend on whether the number of batches is divisible by the number of devices, and the loader doesn’t rely on being able to call len(base_loader). They’d have to introduce extra logic and assumptions to report a length, so I can kinda see their reasons for not wanting to add this.

It looks like defining ParallelLoader.PerDeviceLoader.__len__ as len(self._loader._loader) // len(self._loader._devices) should always be correct. Note, though, that this is not the same as the length of the DistributedSampler, so your current logic is not correct (my previous assumption that the distributed sampler was doing the dividing was wrong). The distributed sampler is another level of parallelism. As noted in that issue on the xla:0 thing, on Colab there is only one device, with that one device having 8 cores (so you get xla:0, xla:1, etc.). In this case the multiprocessing samples don’t use a DistributedSampler, as its creation is gated by if xm.xrt_world_size() > 1, but the world size (i.e. the number of devices) will always be 1 on Colab.
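A quick way to convince yourself of the floor-division length is to run the grouping logic sketched earlier on a stand-in list of batches and compare the number of groups with len(base) // num_devices. group_batches and per_device_len are hypothetical names; this covers only the simplified loop above (default fixed_batch_size=False), not every detail of torch_xla.

```python
def group_batches(base_loader, num_devices):
    """Yield lists of num_devices consecutive batches; a final partial group is dropped."""
    batches = []
    for batch in base_loader:
        batches.append(batch)
        if len(batches) == num_devices:
            yield list(batches)  # copy before clearing
            batches.clear()


def per_device_len(num_base_batches, num_devices):
    # What a PerDeviceLoader.__len__ would report under the logic above.
    return num_base_batches // num_devices


if __name__ == "__main__":
    for n in range(20):
        assert len(list(group_batches(range(n), 8))) == per_device_len(n, 8)
    print("ok")
```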

Thanks for investigating. Yeah v2 will probably need to rewrite the xla optimizer to take the changes into account since we wrote something different. Or change the optimizer to align more with PyTorch.
For v1, this seems all good. We’ll see for mixed precision in a second step, let’s get training working first.

Actually it will still use DistributedSampler, as xm.xrt_world_size() = 8 (I printed it out from each process). However, xm.get_xla_supported_devices() only returns ['xla:0'], indicating each process is being run on just its own XLA device. I still think the DistributedSampler is correct and my length code is correct. The DistributedSampler also compensates if the number of samples is not evenly divisible by the number of processes.
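For reference, this is roughly how the default DistributedSampler evens things out (shuffle=False, drop_last=False): it pads the index list by repeating indices from the start until the total divides evenly, then each rank takes every num_replicas-th index starting at its own rank. A pure-Python sketch of that logic, written without torch (distributed_indices is a hypothetical name); note that it works on samples, not batches.

```python
import math


def distributed_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see, mimicking DistributedSampler's default
    padding (no shuffling, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    if padding and indices:
        # Repeat indices from the start so every rank gets num_samples items.
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(8):
        print(r, distributed_indices(10, 8, r))
```

With 10 samples and 8 processes every rank gets 2 indices, so a few samples are seen twice per epoch.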

Right now I think everything is working except I am still getting the following error (during validation):

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 7 7, but got 0-dimensional input of size [] instead

By the way, I made all the changes you mentioned above, including changing the XLA stepping to be on the OptimWrapper rather than the underlying Optimizer object. And apart from the above error (which was anyway there) it seems to work fine.

Regarding the semaphore issue, I think all we can do is wait and see if the PyTorch XLA developers can find out the cause.

Ah, OK my mistake, was misreading the printing, one device, 8 ordinals. And yeah, if world size is 8 and so the distributed sampler is being used then your previous code should be fine.

On the error: the weight has 3 input channels, so I gather this is the first layer. So maybe the per-device batching isn’t being done properly for some reason. You might try grabbing a batch and printing the input shape from the initial loader after applying the sampler, and then again after the parallel loader (i.e. xb,yb = next(iter(self.data.train_dl)) and then print xb.shape). Narrow it down.
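Since a fastai batch is usually an (x, y) tuple, .shape has to be taken on its elements rather than on the tuple itself. A small hypothetical helper that walks a (possibly nested) batch and reports the shape of anything that has one, and the type name of anything that doesn’t:

```python
def batch_shapes(batch):
    """Map a nested tuple/list batch to the shapes of its tensors.

    Items without a .shape (e.g. an int index) are reported by type name,
    which makes an unexpected scalar in the batch easy to spot.
    """
    if isinstance(batch, tuple):
        return tuple(batch_shapes(b) for b in batch)
    if isinstance(batch, list):
        return [batch_shapes(b) for b in batch]
    shape = getattr(batch, 'shape', None)
    return tuple(shape) if shape is not None else type(batch).__name__
```

Usage: print(batch_shapes(next(iter(self.data.valid_dl)))), once on the sampler-wrapped loader and once on the per-device loader.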

Yes, I at least know that this is an issue with validation, not training, so I will check the shapes in next(iter(self.data.valid_dl)). However, I am surprised there would be a problem with data.valid_dl, as I do the same preprocessing (changing the sampler to a distributed one, wrapping in ParallelLoader) as I did with the training dataloader.

I’ll look into it further. Thanks again for all your help so far! :slight_smile:

I think I found the error. I had assumed the enumerate behavior (the loader yielding (index, batch) pairs) was still there, but it was recently removed in the version on Colab. So I removed the on_batch_begin() that compensated for that.

However, I am dealing with an error that is unfortunately going to be tougher to fix. I get the following error from the validation loop:

  File "/content/tpu_distributed_fastai.py", line 118, in train_loop
    learn.fit(1)
  File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 200, in fit
    fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
  File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 104, in fit
    if not cb_handler.skip_validate and not learn.data.empty_val:
  File "/usr/local/lib/python3.6/dist-packages/fastai/basic_data.py", line 122, in __getattr__
    def __getattr__(self,k:int)->Any: return getattr(self.train_dl, k)
AttributeError: 'PerDeviceLoader' object has no attribute 'empty_val'

Here are the problems with fixing this.

For one, empty_val is a property of learn.data and should be returned as such. However, the lookup is instead going to DataBunch.__getattr__, which tells it to assume it’s an attribute of self.train_dl, which I have redefined to be the PerDeviceLoader iterator. I have no idea why it isn’t using the empty_val(self) method, which has a @property decorator.

Even if I do fix it, there is still one more challenge, which I have a better idea of how to fix. Here is the code for the empty_val property:

    @property
    def empty_val(self)->bool:
        if not hasattr(self, 'valid_dl') or self.valid_dl is None:            return True
        if hasattr(self.valid_ds, 'items') and len(self.valid_ds.items) == 0: return True
        return (len(self.valid_ds) == 0)  

So it’s going to look for self.valid_ds:

    @property
    def valid_ds(self)->Dataset: return self._grab_dataset(self.valid_dl)

Neither of the above attributes can be monkey-patched; otherwise, I would try to change them directly. Here is the definition of _grab_dataset:

    def _grab_dataset(self, dl:DataLoader):
        ds = dl.dl.dataset
        while hasattr(ds, 'dataset'): ds = ds.dataset
        return ds

So I think the solution here is to set in the callback self.learn.data.valid_dl.dataset = self.old_valid_dl.dataset.

If you have any ideas to solve the getattr problem, please let me know. I will probably come up with another hackish way to fix it :joy:

OK, it turns out the __getattr__ problem isn’t actually a true problem. I did some digging, and it turns out that if a property’s getter raises an AttributeError, Python falls back to __getattr__, and the exception raised inside the property is lost. So I just have to fix the one problem by setting self.learn.data.valid_dl.dataset.
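A minimal reproduction of that behavior in plain Python, with hypothetical class names: the property’s getter raises AttributeError internally, Python then calls __getattr__, and the error you see comes from __getattr__ rather than from the real bug inside the property.

```python
class Inner:
    pass


class Bunch:
    def __init__(self):
        self.train_dl = Inner()

    @property
    def empty_val(self):
        # Bug inside the getter: Inner has no 'dl' attribute.
        return self.train_dl.dl is None

    def __getattr__(self, k):
        # Same delegation pattern as the DataBunch.__getattr__ in the traceback.
        return getattr(self.train_dl, k)


if __name__ == "__main__":
    try:
        Bunch().empty_val
    except AttributeError as e:
        print(e)  # 'Inner' object has no attribute 'empty_val'
```

When debugging, catching AttributeError inside the getter and re-raising it as a different exception type (e.g. RuntimeError) stops the fallback, so the original message survives.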

I guess I learned something new about Python!
Here are some references on this issue:



@sgugger you guys probably already know about this, but is this something to consider when it comes to the fastai v2 codebase?

OK I fixed this and now the entire code runs with no error. I will clean up the code a little bit and share.

Oh, just noticed that PerDeviceLoader is an iterator rather than a re-iterable: its __iter__ just returns self. So unlike standard iterables, it can’t be iterated over multiple times. Note that the loop in their samples looks like:

for e in range(NUM_EPOCHS):
    loader = pl.ParallelLoader(...).per_device_loader(...)
    for batch in loader:
        ...

So unless I’m missing something, I can’t see how the current callback code can work across multiple epochs. After the first epoch it will always immediately raise StopIteration, since it’s at the end of the underlying iterator (which is created on init of the ParallelLoader in a worker thread). Probably worth raising an issue, but I guess they have reasons for this.
Also, thinking about it, even ignoring the above it might not be such a good idea to permanently replace train_dl with a PerDeviceLoader. Iterating it (which will need to create a new PerDeviceLoader anyway) is going to create a new background thread and start shifting batches to the TPU. So it might be better to create the ParallelLoader only for the duration of the training loop and replace it with the standard one in on_train_end.
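The difference between an iterator whose __iter__ returns self and a re-iterable object is easy to demonstrate without torch_xla. OneShotLoader is a hypothetical stand-in for the PerDeviceLoader behavior described above:

```python
class OneShotLoader:
    """__iter__ returns self, so every loop shares one underlying iterator."""
    def __init__(self, batches):
        self._it = iter(batches)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)


if __name__ == "__main__":
    loader = OneShotLoader([1, 2, 3])
    epoch1 = [b for b in loader]
    epoch2 = [b for b in loader]
    print(epoch1, epoch2)  # [1, 2, 3] []
```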

The two ways I can see to deal with the above are either to make a wrapper that properly implements an iterable, creating a new ParallelLoader and PerDeviceLoader in __iter__, or to replace them every epoch in on_epoch_begin. The latter means iteration outside of the standard training loop won’t work, while the former would allow this, but that may cause issues; I’m not sure it’s intended to have two of them in play at once.
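A minimal sketch of the wrapper option, assuming only that the one-shot loader can be rebuilt from a zero-argument factory. ReiterableLoader is a hypothetical name; with torch_xla the factory would presumably be something like lambda: pl.ParallelLoader(dl, [device]).per_device_loader(device), the same call used in on_train_begin earlier in the thread.

```python
class ReiterableLoader:
    """Builds a fresh one-shot loader each time iteration starts."""
    def __init__(self, make_loader, length=None):
        self.make_loader = make_loader  # zero-argument factory
        self.length = length

    def __iter__(self):
        # A new underlying loader (and, with torch_xla, a new background
        # thread) per loop, so every epoch starts from the beginning.
        return iter(self.make_loader())

    def __len__(self):
        if self.length is None:
            raise TypeError("length unknown")
        return self.length
```

Because each __iter__ call builds a new loader, two concurrent loops would start two background loaders, which is exactly the open question about having two of them in play at once.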

This is a good point. I never thought to try more than 1 epoch, which is dumb on my part.

Indeed creating a wrapper was something suggested by the PyTorch XLA developers over here.

Eventually I will probably create a wrapper. However, for now I am just going to move the reassigning code to an on_epoch_begin function, just to get a quick proof of concept, because I still think we need to discuss whether the current interface is really the best and how we would fit this in with fastai v2.
