Distributed get_preds CUDA error

Good morning,
I am running a distributed pipeline which essentially does the following:

  1. Loop over data folds
  2. Create dataloaders (dls) + learner
  3. Train learner with fine_tune with distrib_ctx
  4. Compute test predictions with get_preds in distrib_ctx on a new dataloader created with dls.test_dl

This raises a CUDA error for which I was not able to find an equivalent on the web yet, except for an unresolved GitHub issue in the Lightning library.

The simplest failing example I came up with uses fastai/fastai/launch.py and fastai/nbs/examples/train_imagenette.py.
The only modifications you need to make to these files are as follows:

  1. Replace from fastscript import * with from fastcore.script import *
  2. Add preds = learn.get_preds() right after learn.fit_flat_cos within the same context

Then you can reproduce this error by launching python -m fastai.launch fastai/nbs/examples/train_imagenette.py --epochs 1 --size 64 --bs 128 --runs 12. On my machine (2x GeForce GTX 1080 Ti) it runs fine until the 7th run, after which the following CUDA error is raised:

terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: initialization error
Exception raised from insert_events at /opt/conda/conda-bld/pytorch_1603729062494/work/c10/cuda/CUDACachingAllocator.cpp:717 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f15ddea58b2 in /home/navarinilab/miniconda3/envs/fastai/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1070 (0x7f15de0f7f20 in /home/navarinilab/miniconda3/envs/fastai/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::TensorImpl::release_resources() + 0x4d (0x7f15dde90b7d in /home/navarinilab/miniconda3/envs/fastai/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: + 0x5f65b2 (0x7f161f65a5b2 in /home/navarinilab/miniconda3/envs/fastai/lib/python3.8/site-packages/torch/lib/libtorch_python.so)

Do you have any ideas how to solve this error?
Thanks in advance.

I found a workaround: instead of using get_preds, I used bare PyTorch code to get the predictions.
I suspect the issue comes from get_preds keeping some data on the GPUs, which causes the memory error.
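For reference, the bare-PyTorch replacement looks roughly like this (a sketch; `manual_preds` is my own helper name, and each batch of predictions is moved back to the CPU immediately so nothing accumulates on the GPUs):

```python
import torch

@torch.no_grad()
def manual_preds(model, dl, device="cuda"):
    """Run the model over a dataloader, keeping all outputs on the CPU."""
    model.eval()
    out = []
    for xb, *_ in dl:                           # handles (x,) or (x, y) batches
        out.append(model(xb.to(device)).cpu())  # move results off the GPU right away
    return torch.cat(out)
```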

The solution was to follow the instructions from this post, mainly:

  • Delete dls and learner
  • Empty cuda cache
  • Call garbage collector
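The cleanup steps above can be sketched as follows (a hypothetical helper; note that `del` inside a function only drops local names, so the caller has to delete or reassign its own learn/dls references first):

```python
import gc
import torch

def end_of_fold_cleanup():
    # NB: the caller must `del learn, dls` (or reassign them) beforehand;
    # this helper only runs the collector and empties the CUDA cache.
    gc.collect()               # collect reference cycles still holding tensors
    torch.cuda.empty_cache()   # release cached blocks back to the driver
```

Usage at the end of each fold: `del learn, dls` followed by `end_of_fold_cleanup()`.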