How to create a callback using torch.multiprocessing (TPU)

I am unsure if it is actually running on the TPU or just on the CPU. It takes longer, or sometimes about the same time, on the TPU as on the GPU. This was true even while increasing model and dataset complexity (larger ResNet models, MNIST, CIFAR10, CIFAR100).

If you do find an error, please let me know!

For GPU or CPU?

I do not see this option in either case though. Probably because I use Colab so much :joy:
I am 90% sure there is a deep learning algorithm deciding which resources to make available! I haven’t gotten access to a Tesla T4 for like 2 weeks now :frowning:

Also, for monkey patching __len__, it should be just fine to return the total length evenly divided over all the TPUs, right?

Also, for monkey patching dl.dataset, I think it makes sense to just return the entire dataset.

These should really be added to the XLA library though; I will raise an issue tomorrow about these methods/attributes for ParallelLoader.
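For the dataset one, I mean something like this (just a sketch; the _loader attributes come from torch_xla's internals and the helper name is arbitrary, so this could easily break):

import torch_xla.distributed.parallel_loader as pl

def _per_device_dataset(self):
  # hand back the full dataset of the DataLoader wrapped by the ParallelLoader
  return self._loader._loader.dataset

pl.PerDeviceLoader.dataset = property(_per_device_dataset)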

That was a TPU instance. And yeah, I wouldn’t be surprised if they limit access; I presume, like Kaggle, they get a lot of use from a subset of users.

OK, not sure why I got those crashes before; I just ran a cnn_learner fine, with otherwise everything from the cifar10 colab sample (so a PyTorch in-memory dataset and 8-core threading). Though it didn’t train properly. The first epoch was OK:

[xla:7](0) Loss=2.47026 Rate=32.41 GlobalRate=32.41

But then:

[xla:2](40) Loss=nan Rate=295.93 GlobalRate=312.05

And once it got the nan it killed the model. It looks like it got a nan loss in the first epoch and then the step killed the whole model (I’m familiar with this from making the CUDA version of Mish).
That was using:

import torch.nn as nn
import torch.nn.functional as F
from fastai.vision import *   # provides models and create_cnn_model

class ResNet18(nn.Module):
    def __init__(self, nc:int=10):
        super().__init__()
        # fastai cnn body + head for resnet18, without pretrained weights
        self.resnet = create_cnn_model(base_arch=models.resnet18, pretrained=False, nc=nc)

    def forward(self, x):
        x = self.resnet(x)
        return F.log_softmax(x, dim=1)

as the sample had a log-softmax at the end of the model. So not sure what happened there. I think the log-softmax should mean the output ranges are OK, and the shapes should both be (batch_size, n_classes). The only thing I can think of is that fastai might be using some layers that aren’t numerically stable on the TPU. No idea how you debug a TPU model, or whether you can even hook it. Though if it’s just because of the weird mixing then it’s not a big issue.
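Just to illustrate what I mean about the ranges and shapes, a tiny standalone check (made up for illustration, not from the notebook):

import torch
import torch.nn.functional as F

out = F.log_softmax(torch.randn(4, 10), dim=1)  # (batch_size, n_classes)
assert out.shape == (4, 10)                     # shape matches what the loss expects
assert (out <= 0).all()                         # log-probabilities are never positive
assert torch.allclose(out.exp().sum(dim=1), torch.ones(4))  # probabilities sum to 1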

In terms of performance: Completed 20 epochs in 455.39s, 22.77s/epoch.
Notebook is here.

EDIT: Hmm, replacing from fastai.vision import * with just importing cnn_learner and models seemed to make it better; maybe luck, but I’d tried a couple of times before and it went to nan on the second epoch (so one step). It got through a few epochs this time, with the loss rapidly increasing until it overflowed to nan.

ParallelLoader.__len__ should probably be len(self._loader), i.e. the total and not divided by devices, while PerDeviceLoader’s would be len(self._loader)/len(self._loader._devices) (though probably exposed as self._loader.per_device_len or some such). The PerDeviceLoader should be what the Learner uses.
But yeah, there’s probably a good chance of getting them included upstream, unless there was some reason for leaving them out, which I can’t see.

Yeah, I think this should be OK generally. TensorFlow 2.0 has an eager mode mirroring the PyTorch way (though I’m not sure if this is used for TPU or just GPU). This isn’t actually as different from PyTorch as it appears, because while PyTorch is eager, the results still aren’t immediate on GPU; operations get started immediately but are basically just added to a queue (in the GPU drivers/runtime). So this is more an issue the torch_xla stuff has to deal with.

I’d think the main thing is potential performance issues, because it takes longer to get results from a TPU than from a GPU. So the pause from calling .cpu() on a GPU tensor may be less than on a TPU tensor, and the access patterns in fastai may be particularly non-optimal for TPU. In fact I have a bit of a hunch that the pretty big performance gap between fastai and plain PyTorch training is due to fastai accessing results too quickly. In particular, after calculating batch losses fastai immediately accesses the item for smoothing and callbacks, making everything block until the forward pass is complete (and probably similarly with gradients). I plan to have a look at this, though it’s hard to know what can be done (and it’s almost certainly at best a v2 thing, if not v3 or so).
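To show the sort of blocking I mean, a made-up example (not fastai’s actual code):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(64, 128, device=device)
w = torch.randn(128, 10, device=device, requires_grad=True)

losses = []
for _ in range(10):
  loss = (x @ w).pow(2).mean()    # queued on the device, returns without waiting
  losses.append(loss.detach())    # keeping it as a tensor stays non-blocking
  # losses.append(loss.item())    # .item() here would block until the kernels finish

total = torch.stack(losses).mean().item()  # one sync at the end instead of one per batch
print(total)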

Turns out my current monkey-patched code will now work fine, because it takes the length from the sampler, which is now the distributed sampler.

Also, I raised an issue

@TomB
They did not want to fix this issue unfortunately. I will keep the monkey-patched solution for now.

Interestingly, I have noticed that if I print out which XLA device is being used, I always get xla:0, so there’s probably still a bug even though I added the DistributedSampler.

Pity they didn’t want to alter it. Though I did see it just got re-opened by someone else, so they might reconsider. I did a quick check and unfortunately Lightning (another PyTorch framework) didn’t seem to use len(loader), which would’ve bolstered the case.

The xla:0 thing sounds like maybe a configuration issue in the spawning (or similar). Which device you get should be determined internally based on what other devices are participating in the distributed cluster. I believe this is based on environment variables, so they may not be getting properly passed to the child processes for some reason. You might want to print out the XLA environment variables in the parent/child processes. They are in torch_xla.core.xla_env_vars. In particular:

WORKERS = 'XRT_WORKERS'
ORDINAL = 'XRT_SHARD_ORDINAL'
WORLD_SIZE = 'XRT_SHARD_WORLD_SIZE'

No values were printed from this, just the same strings you listed above. Does that mean nothing has been set yet? I don’t think this is the issue though. I will share my code in case you have any insights.

@TomB Here is the code. I am done for the night, but if you find something out let me know and I will look into it tomorrow. Thanks in advance!

import torch_xla
import torch_xla.distributed.data_parallel as dp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch

from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.basic_train import *

def len_parallelloader(self):
  print('hello length from',self._device)
  return len(self._loader._loader)
pl.PerDeviceLoader.__len__ = len_parallelloader
  

class TPUDistributed(LearnerCallback):
  def __init__(self, learn:Learner):
    super().__init__(learn)
    self.device = xm.xla_device()
    print('callback ',self.device)

  def _change_dl(self,dl, shuffle):
    old_dl = dl
    sampler = torch.utils.data.distributed.DistributedSampler(
      dl.dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=shuffle
    )
    new_dl = dl.new(shuffle=False, sampler=sampler)
    return old_dl,new_dl,sampler


  def on_train_begin(self, **kwargs:Any)->None:
    self.learn.model = self.learn.model.to(self.device)

    shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True
    self.old_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle)
    if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None:
      self.old_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, shuffle)
    
    self.learn.data.train_dl = pl.ParallelLoader(self.data.train_dl, [self.device]).per_device_loader(self.device)
    self.learn.data.train_dl.dataset = None #self.old_train_dl.dataset
    self.learn.data.valid_dl = pl.ParallelLoader(self.data.valid_dl, [self.device]).per_device_loader(self.device)
    self.learn.data.valid_dl.dataset = None #self.old_train_dl.dataset

  def on_batch_begin(self, last_input, last_target, train, **kwargs):
    return {'last_input': last_target[0], 'last_target': last_target[1]}
  def on_step_end(self, **kwargs:Any)->None:
    xm.optimizer_step(self.learn.opt.opt)


def _to_tpu_distributed(learn:Learner) -> Learner:
  learn.callback_fns.append(TPUDistributed)
  return learn
  

Learner.to_tpu_distributed = _to_tpu_distributed
  

path = untar_data(URLs.MNIST_SAMPLE)
def train_loop(index):
  data = ImageDataBunch.from_folder(path)
  learn = cnn_learner(data, models.resnet50, metrics=accuracy).to_tpu_distributed()
  print('hello')
  learn.fit(1)

if __name__ == "__main__":
  xmp.spawn(train_loop,args=())

Note that when I print xm.xrt_world_size() it prints 8 which makes sense.

Ah, if xrt_world_size is correct then it’s fine; it’s just reading xenv.WORLD_SIZE, so it should match os.environ['XRT_SHARD_WORLD_SIZE'] (or os.environ[xenv.WORLD_SIZE]).
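i.e. something like this quick check (just a sketch):

import os
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm

# these two should agree if the XRT environment made it into this process
print(xm.xrt_world_size(), os.environ.get(xenv.WORLD_SIZE))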

Shouldn’t the on_step_end override above be:

def on_backward_end(self, **kwargs):
    xm.optimizer_step(self.learn.opt.opt)
    return {'skip_step': True}

to avoid double stepping.

Also, shouldn’t it just be self.learn.opt, so you step the wrapper, which handles layer (/parameter) groups, rather than the inner optimiser (or it could just be self.opt, as LearnerCallback proxies attributes to self.learn)? Though there might be complications there that need to be addressed.
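So something along these lines, assuming OptimWrapper works with xm.optimizer_step (just a sketch):

def on_backward_end(self, **kwargs):
  xm.optimizer_step(self.learn.opt)  # step the wrapper so layer/parameter groups are respected
  return {'skip_step': True}         # tell fastai not to step the optimizer again itself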

Sorry about this error. The on_step_end mistake was a dumb one I made which you pointed out earlier, but I lost my fix to it because I had two versions of the TPU callback code. I will go and check that I have the correct version with the correct fixes.

Regarding the stepping of the optimizer, I did this because I was worried that xm.optimizer_step would only work with a torch.optim.Optimizer class and not with OptimWrapper. However, both have a step method, so if it only relies on that, it should be fine.

Thanks for pointing out the skip_step. I completely forgot about that.

@TomB Do you know what this means? I thought the data was evenly divided between the devices?

@TomB A couple of updates.

So I did something I should have done a long time ago. I ran the MNIST multiprocessing example.
I noticed two things.

  1. There is the same semaphore warning being printed out.
  2. Their code also seems to only use xla:0: it is printed out with the loss, and when I printed out which device is being used in the training loop I also only got xla:0.

I think I know the fix to using all XLA devices. However, I will raise an issue to start out.

Here are the issues:

Going through the code, I think it will be OK.
From what I could see looking through the torch_xla code, the only things it expects are opt.step(), opt.zero_grad() and opt.param_groups. The first two should be fine; that’s how the fastai OptimWrapper expects to be used. As for the last, I think it’s OK, but I am less clear here.
The param_groups are used by _fetch_gradients. They’re actually accessed slightly oddly, via optimizer.__getstate__()['param_groups'] (I had to search to remind myself that __getstate__ is for pickling, so I guess this allows different internal storage as long as a common import/export format is respected). In fastai v1, OptimWrapper.__setstate__ is overridden but not __getstate__ (there’s a getstate() that’s used instead). So __getattr__ delegates this to the wrapped optimiser’s __getstate__ and it should be fine (v2 might be different here). The purpose of this is to get a list of gradients in order to do a cross-replica sum before they are applied with step() (roughly the pattern sketched below). So it just needs to collect all the gradients, and I think no gradient should be returned twice or else it might get summed twice.
The one possible issue I saw was if OptimWrapper.layer_groups and the inner Optimizer.param_groups got out of sync. But this should only be an issue if new parameters are added that end up in layer_groups but not param_groups. Looking at some (but not all) of the fastai stuff here, I think they are mostly just rearranging groups and removing layers that shouldn’t update. So this would be fine, as all updated params would still be in param_groups and updating all happens in step() (you might just needlessly accumulate non-learnable params).
The one possible exception was the MixedPrecision stuff, but I think there’s special handling for this in torch_xla, so the standard to_fp16 should not be used (in which case the XLA callback should probably error in on_train_begin if it finds that stuff on the learner).
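Roughly, the access pattern is this (my paraphrase from reading the source, not the actual torch_xla code):

import torch

def fetch_gradients(optimizer):
  # walk __getstate__()['param_groups'] and collect every parameter's gradient,
  # which is then what gets cross-replica summed before step() applies it
  gradients = []
  for param_group in optimizer.__getstate__()['param_groups']:
    for p in param_group['params']:
      if isinstance(p, torch.Tensor) and p.grad is not None:
        gradients.append(p.grad.data)
  return gradients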

So I don’t think fastai’s layer groups should interfere here. But perhaps @sgugger will quickly spot something I’ve missed.

Yeah, it was a bit hard to follow the replies (or the code), but from some digging I think the basic logic is:

batches = []
for batch in base_loader:
  batches.append(batch)
  if len(batches) == num_devices:
    process_batches(batches)
    batches.clear()

I also might be missing some details in the methods it’s calling that introduce extra subtleties (and this is only with the default fixed_batch_size=False).
So the actual length will depend on whether the number of batches is divisible by the number of devices, and it doesn’t rely on being able to do len(base_loader). They’d have to introduce extra logic and assumptions to report a length, so I can kind of see their reasons for not wanting to add this.

It looks like having PerDeviceLoader.__len__ be len(self._loader._loader) // len(self._loader._devices) should always be correct. Note though that this is not the same as the length of the DistributedSampler, so your current logic is not correct (my previous assumption that the distributed sampler was doing the dividing was wrong). The distributed sampler is another level of parallelism. As noted in that issue on the xla:0 thing, on Colab there is only one device, that one device having 8 cores (so you get xla:0/1 etc.). In this case the multiprocessing samples don’t use a DistributedSampler, as its creation is gated by if xm.xrt_world_size() > 1, but the world size (i.e. number of devices) will always be 1 on Colab.
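So the earlier monkey patch would become something like this (same caveat about relying on private attributes):

import torch_xla.distributed.parallel_loader as pl

def len_parallelloader(self):
  # total batches in the wrapped DataLoader, split across the per-device loaders
  return len(self._loader._loader) // len(self._loader._devices)

pl.PerDeviceLoader.__len__ = len_parallelloader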

Thanks for investigating. Yeah, v2 will probably need to rewrite the xla optimizer step to take the changes into account, since we wrote something different. Or change the optimizer to align more with PyTorch.
For v1, this all seems good. We’ll look at mixed precision in a second step; let’s get training working first.