Distributed Training for Segmentation

Hi, I am relatively new to fastai and machine learning. I am trying to scale up the unet_learner image segmentation models I am fine-tuning (larger batch size or larger image size) in the hope of getting higher accuracy. The images in the dataset are large and are currently being scaled down from roughly 6000x3000 to 800x800. The mask annotations are small, so a larger input image should give finer edges between segmentation classes.

Scaling Up Models

Scaling up made the models too large for a single GPU. I have access to 2 GPUs and have been trying to use both to train one model. I have tried the parallel_ctx and distrib_ctx context managers to see how they use my PC’s GPUs. From the documentation, distrib_ctx looks like the right context manager for my situation, since both GPUs work together to train a single model rather than training two separate models, one per GPU, and merging them.
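
My understanding from the fastai distributed docs is that the same distrib_ctx pattern can also live in a plain script launched with accelerate, rather than going through notebook_launcher. A minimal sketch of what I mean (train_seg.py is just a placeholder name, and the DataLoaders setup is abbreviated from my notebook code below):

# train_seg.py -- run from the shell with: accelerate launch train_seg.py
from fastai.vision.all import *
from fastai.distributed import *

path = Path('./data')
dls = SegmentationDataLoaders.from_label_func(
    path, bs=16,
    fnames=get_image_files(path/"images"),
    label_func=lambda o: path/'masks'/o.name,
    codes=np.loadtxt('codes.txt', dtype=str),
)
learn = unet_learner(dls, resnet34).to_fp16()

# distrib_ctx wraps the model in DistributedDataParallel; each launched
# process trains on its own GPU with its own shard of the data
with learn.distrib_ctx(sync_bn=False):
    learn.fine_tune(20)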

Errors in Training

Using distrib_ctx, a common error I run into is: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

I ran the following code in a Jupyter notebook:

from accelerate.utils import write_basic_config
write_basic_config()

from fastai.vision.all import *
from fastai.distributed import *

from accelerate import notebook_launcher

def train():
    path = Path('./data')

    dls = SegmentationDataLoaders.from_label_func(
        path,
        bs=16,
        fnames=get_image_files(path/"images"),
        label_func=lambda o: path/'masks'/o.name,  # mask has the same filename as its image
        codes=np.loadtxt('codes.txt', dtype=str),
        valid_pct=0.2,
        item_tfms=Resize(1_200, method=ResizeMethod.Squish),
        batch_tfms=[*aug_transforms(), IntToFloatTensor(div=255)],
        num_workers=0,
    )

    # Metrics to track
    metrics = [error_rate, foreground_acc, DiceMulti(), JaccardCoeffMulti()]

    # Unet Learner
    learn = unet_learner(dls, resnet34, metrics=metrics).to_fp16()

    # Callback Functions
    early_stopping = EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0001, patience=3)
    csv_logger = CSVLogger(Path('history/history.csv'), append=True)

    with learn.distrib_ctx(in_notebook=True, sync_bn=False):
        learn.fine_tune(20, cbs=[early_stopping, csv_logger])

notebook_launcher(train, num_processes=2)

Full stack trace:

W0224 14:19:41.864000 3210761 site-packages/torch/multiprocessing/spawn.py:160] Terminating process 3210931 via signal SIGTERM
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] failed (exitcode: 1) local_rank: 1 (pid: 3210932) of fn: train (start_method: fork)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Traceback (most recent call last):
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 687, in _poll
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self._pc.join(-1)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 203, in join
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     raise ProcessRaisedException(msg, error_index, failed_process.pid)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] torch.multiprocessing.spawn.ProcessRaisedException: 
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] 
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] -- Process 1 terminated with the following error:
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Traceback (most recent call last):
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     fn(i, *args)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 611, in _wrap
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     ret = record(fn)(*args_)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]           ^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     return f(*args, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]            ^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/tmp/ipykernel_3210761/3018298287.py", line 27, in train
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     learn.fine_tune(20, cbs=[early_stopping, csv_logger])
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/callback/schedule.py", line 167, in fine_tune
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/callback/schedule.py", line 121, in fit_one_cycle
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 266, in fit
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 201, in _with_events
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     try: self(f'before_{event_type}');  f()
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                         ^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 255, in _do_fit
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self._with_events(self._do_epoch, 'epoch', CancelEpochException)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 201, in _with_events
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     try: self(f'before_{event_type}');  f()
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                         ^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 249, in _do_epoch
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self._do_epoch_train()
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 241, in _do_epoch_train
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     self._with_events(self.all_batches, 'train', CancelTrainException)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 201, in _with_events
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     try: self(f'before_{event_type}');  f()
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                         ^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/learner.py", line 207, in all_batches
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     for o in enumerate(self.dl): self.one_batch(*o)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]              ^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/data/load.py", line 133, in __iter__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     yield self.after_batch(b)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]           ^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/distributed.py", line 122, in after_batch
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     return self.dl.after_batch(b)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]            ^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 210, in __call__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 160, in compose_tfms
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     x = f(x, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]         ^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 83, in __call__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 93, in _call
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     return self._do_call(getattr(self, fn), x, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 100, in _do_call
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 100, in <genexpr>
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/transform.py", line 99, in _do_call
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     return retain_type(f(x, **kwargs), x, ret)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                        ^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/fastcore/dispatch.py", line 122, in __call__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     return f(*args, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]            ^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/data/transforms.py", line 379, in encodes
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     def encodes(self, x:TensorImage): return (x-self.mean) / self.std
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]                                               ~^~~~~~~~~~
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/Desktop/karsten/fastai/fastai/fastai/torch_core.py", line 384, in __torch_function__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/henry/miniconda3/envs/fastai/lib/python3.12/site-packages/torch/_tensor.py", line 1512, in __torch_function__
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]     ret = func(*args, **kwargs)
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]           ^^^^^^^^^^^^^^^^^^^^^
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
E0224 14:19:42.789000 3210761 site-packages/torch/distributed/elastic/multiprocessing/api.py:732]
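
Reading the trace, the failure is inside the Normalize transform (added automatically by unet_learner for the pretrained backbone): the batch and the normalization stats end up on different devices (cuda:1 vs cuda:0). One thing I was planning to try is pinning each worker to its own GPU at the very top of train(), before the DataLoaders and Learner are built, assuming the launcher sets LOCAL_RANK (or RANK) for each process. A rough, untested sketch:

import os
import torch

def train():
    # pin this process to its own GPU before any CUDA tensors are created
    local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('RANK', '0')))
    torch.cuda.set_device(local_rank)

    # ... then build dls, learn, and callbacks exactly as above and train
    # inside learn.distrib_ctx(in_notebook=True, sync_bn=False)

Is something like this a reasonable direction, or is there a supported way to tell the DataLoaders which device each process should use?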

I haven’t found many examples of using unet_learner to train a large model in a distributed setup. I appreciate any help, and please let me know whether training a model that is larger than the memory of a single GPU is even viable this way.
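
If distributed training turns out not to be the right tool here, a fallback I have been considering is fastai's GradientAccumulation callback, to get a larger effective batch size on a single GPU without holding the whole batch in memory at once. A sketch with placeholder numbers (bs=4 in the DataLoaders, accumulating to an effective batch of 16):

# gradients are accumulated until ~16 samples have been processed
# (4 mini-batches of bs=4) before each optimizer step
learn.fine_tune(20, cbs=[GradientAccumulation(n_acc=16), early_stopping, csv_logger])

As I understand it, though, that only helps with batch size, not with the activation memory needed for a larger input image.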
