Why can’t I just wrap CUDA memory errors in an exception?
I’ve been digging into this for a couple of days. First I went to the torch.cuda source code and found this:
def synchronize(self):
    r"""Wait for all the kernels in this stream to complete.

    .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
       `CUDA documentation`_ for more info.

    .. _CUDA documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/
    """
    check_error(cudart().cudaStreamSynchronize(self))
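For what it’s worth, check_error on its own is not magic — it just turns a nonzero cudaError_t into a Python exception. A minimal sketch of the idea (not PyTorch’s actual code; libcudart.so and the CudaError class here are my own names for illustration):

import ctypes

# Sketch of a check_error-style helper: turn a nonzero cudaError_t into a
# Python exception. "libcudart.so" and CudaError are assumptions, not PyTorch's
# real names.
_cudart = ctypes.CDLL("libcudart.so")
_cudart.cudaGetErrorString.restype = ctypes.c_char_p  # returns const char*

class CudaError(RuntimeError):
    def __init__(self, code):
        msg = _cudart.cudaGetErrorString(code).decode("utf-8")
        super().__init__(f"{msg} ({code})")

def check_error(code):
    if code != 0:  # 0 == cudaSuccess
        raise CudaError(code)

# Usage, mirroring the snippet above (NULL == the legacy default stream):
check_error(_cudart.cudaStreamSynchronize(None))

So the return code does get surfaced as an exception you can catch; the question is what state the context is in once you’ve caught it.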
Ah ha, so PyTorch uses the CUDA runtime API. After signing on to the CUDA forums and searching for “memory errors”, “exceptions”, etc., I found exactly nothing relevant. Then I decided to read the runtime API notes and found this:
Context management
Context management can be done through the driver API, but is not exposed in the runtime API. Instead, the runtime API decides itself which context to use for a thread: if a context has been made current to the calling thread through the driver API, the runtime will use that, but if there is no such context, it uses a “primary context.” Primary contexts are created as needed, one per device per process, are reference-counted, and are then destroyed when there are no more references to them. Within one process, all users of the runtime API will share the primary context, unless a context has been made current to each thread. The context that the runtime uses, i.e., either the current context or primary context, can be synchronized with cudaDeviceSynchronize(), and destroyed with cudaDeviceReset().
Using the runtime API with primary contexts has its tradeoffs, however. It can cause trouble for users writing plug-ins for larger software packages, for example, because if all plug-ins run in the same process, they will all share a context but will likely have no way to communicate with each other. So, if one of them calls cudaDeviceReset() after finishing all its CUDA work, the other plug-ins will fail because the context they were using was destroyed without their knowledge. To avoid this issue, CUDA clients can use the driver API to create and set the current context, and then use the runtime API to work with it. However, contexts may consume significant resources, such as device memory, extra host threads, and performance costs of context switching on the device. This runtime-driver context sharing is important when using the driver API in conjunction with libraries built on the runtime API, such as cuBLAS or cuFFT.
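To make the quoted behaviour concrete, here is a rough sketch of driving the runtime API directly from Python with ctypes (the library name libcudart.so is an assumption for Linux; error handling is kept minimal):

import ctypes

# Rough illustration of the quoted behaviour, driving the runtime API directly.
# "libcudart.so" is an assumed library name for Linux.
cudart = ctypes.CDLL("libcudart.so")
cudart.cudaGetErrorString.restype = ctypes.c_char_p

def status(code):
    return cudart.cudaGetErrorString(code).decode()

# Synchronize whatever context the runtime chose for this thread -- the
# primary context, unless a driver-API context was made current first.
print("sync :", status(cudart.cudaDeviceSynchronize()))

# Tear that context down. Every other runtime-API user in the same process
# (PyTorch included) is now holding a dead context -- the plug-in problem
# the documentation warns about.
print("reset:", status(cudart.cudaDeviceReset()))

cudaDeviceReset() really does give you a fresh primary context, but only by pulling it out from under everyone else in the process — which, I think, is why PyTorch can’t quietly do it for you behind a try/except.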
My guess is that on Windows you’re really using a DLL bound through COM, and on Linux a shared object with a wrapper. I think that Python and PyTorch never really have the level of control necessary to recover from the memory error. Once the “context” is hosed, you’re hosed until you do a cudaDeviceReset() (Kernel → Restart, to us).
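You can see the same thing from PyTorch itself. A sketch, not a definitive test — the exact messages and the point where the error surfaces vary by PyTorch/CUDA version, and the out-of-bounds index is just a convenient way to provoke a device-side assert:

import torch

# Sketch: provoke a sticky device-side error and show that catching the
# Python exception does not repair the context.
x = torch.arange(4, device="cuda")
bad = torch.tensor([99], device="cuda")

try:
    y = x[bad]                # kernel launch is asynchronous, so this may "work"
    torch.cuda.synchronize()  # the device-side assert surfaces by this point
except RuntimeError as e:
    print("caught:", e)

try:
    torch.zeros(1, device="cuda")   # unrelated work on the same context
except RuntimeError as e:
    print("still broken:", e)       # sticky error: the context stays poisoned

Catching the RuntimeError is the easy part; what you can’t do from inside the process is un-poison the primary context afterwards.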