How to use fastai with 2 T4 GPUs on Kaggle?

I’m trying to reproduce the distributed training example from the fastai documentation in a Kaggle notebook with 2 T4 GPUs. Each --- below marks a separate notebook cell:

from accelerate.utils import write_basic_config
write_basic_config()  # write a default accelerate config file (run once, in its own cell)
---
from fastai.vision.all import *
from fastai.distributed import *


def train():
    set_seed(99, True)
    path = untar_data(URLs.PETS)/'images'
    dls = ImageDataLoaders.from_name_func(
        path, get_image_files(path), valid_pct=0.2,
        label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
    
    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()
    # distrib_ctx wraps the model in DistributedDataParallel for the duration
    # of the block; in_notebook=True is needed when using notebook_launcher
    with learn.distrib_ctx(in_notebook=True):
        learn.fine_tune(1)
---
from accelerate import notebook_launcher
notebook_launcher(train, num_processes=2)
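
For context, my understanding is that notebook_launcher forks the kernel into num_processes worker processes (one per GPU) and calls train() in each of them, with the usual distributed environment variables set. A small probe like the one below is how I'd check what each worker sees; the exact env var names (RANK, LOCAL_RANK, WORLD_SIZE) are my assumption about what the accelerate launcher exports:

import os
from accelerate import notebook_launcher

def show_env():
    # each forked worker should report a distinct RANK / LOCAL_RANK
    print(os.environ.get("RANK"), os.environ.get("LOCAL_RANK"),
          os.environ.get("WORLD_SIZE"))

notebook_launcher(show_env, num_processes=2)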

Running the last cell raises this error:

ProcessRaisedException                    Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/accelerate/launchers.py:201, in notebook_launcher(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes)
    200 try:
--> 201     start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
    202 except ProcessRaisedException as e:

File /opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:202, in start_processes(fn, args, nprocs, join, daemon, start_method)
    201 # Loop on join until it returns True or raises an exception.
--> 202 while not context.join():
    203     pass

File /opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:163, in ProcessContext.join(self, timeout)
    162 msg += original_trace
--> 163 raise ProcessRaisedException(msg, error_index, failed_process.pid)

ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/launch.py", line 624, in __call__
    self.launcher(*args)
  File "/tmp/ipykernel_157/4048901777.py", line 14, in train
    learn.fine_tune(1)
  File "/opt/conda/lib/python3.10/site-packages/fastai/callback/schedule.py", line 165, in fine_tune
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/fastai/callback/schedule.py", line 119, in fit_one_cycle
    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 264, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 199, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 253, in _do_fit
    self._with_events(self._do_epoch, 'epoch', CancelEpochException)
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 199, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 247, in _do_epoch
    self._do_epoch_train()
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 239, in _do_epoch_train
    self._with_events(self.all_batches, 'train', CancelTrainException)
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 199, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/opt/conda/lib/python3.10/site-packages/fastai/learner.py", line 205, in all_batches
    for o in enumerate(self.dl): self.one_batch(*o)
  File "/opt/conda/lib/python3.10/site-packages/fastai/data/load.py", line 131, in __iter__
    yield self.after_batch(b)
  File "/opt/conda/lib/python3.10/site-packages/fastai/distributed.py", line 120, in after_batch
    return self.dl.after_batch(b)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 208, in __call__
    def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 158, in compose_tfms
    x = f(x, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 81, in __call__
    def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 91, in _call
    return self._do_call(getattr(self, fn), x, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 98, in _do_call
    res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 98, in <genexpr>
    res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/transform.py", line 97, in _do_call
    return retain_type(f(x, **kwargs), x, ret)
  File "/opt/conda/lib/python3.10/site-packages/fastcore/dispatch.py", line 120, in __call__
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/fastai/data/transforms.py", line 377, in encodes
    def encodes(self, x:TensorImage): return (x-self.mean) / self.std
  File "/opt/conda/lib/python3.10/site-packages/fastai/torch_core.py", line 382, in __torch_function__
    res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!


The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 2
      1 from accelerate import notebook_launcher
----> 2 notebook_launcher(train, num_processes=2)

File /opt/conda/lib/python3.10/site-packages/accelerate/launchers.py:211, in notebook_launcher(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes)
    204                 raise RuntimeError(
    205                     "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
    206                     "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
    207                     "Please review your imports and test them when running the `notebook_launcher()` to identify "
    208                     "which one is problematic and causing CUDA to be initialized."
    209                 ) from e
    210             else:
--> 211                 raise RuntimeError(f"An issue was found when launching the training: {e}") from e
    213 else:
    214     # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
    215     if is_mps_available():

RuntimeError: An issue was found when launching the training: 

-- Process 1 terminated with the following error:
[... same traceback as above ...]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

What might be the problem, and how can I fix it?
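
From the traceback, my guess is that the normalization stats are the culprit: Normalize.encodes computes (x - self.mean) / self.std, and mean/std appear to live on cuda:0 in both workers while worker 1's batch is on cuda:1. If that's right, one workaround might be to pin each forked process to its own GPU before the DataLoaders (and the pretrained model's Normalize stats) are created. A sketch of what I mean; the LOCAL_RANK lookup is my assumption about what the launcher sets:

import os
import torch
from fastai.vision.all import *
from fastai.distributed import *

def train():
    # assumption: bind this worker to its own GPU *before* any tensors
    # (e.g. Normalize's mean/std) are created, so they land on that device
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    set_seed(99, True)
    path = untar_data(URLs.PETS)/'images'
    dls = ImageDataLoaders.from_name_func(
        path, get_image_files(path), valid_pct=0.2,
        label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))

    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()
    with learn.distrib_ctx(in_notebook=True):
        learn.fine_tune(1)

Is something along these lines the intended usage, or is there a fastai/accelerate option I'm missing?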