Move torch.cuda's atfork handler into C++ (#29101)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/23401

We cannot rely on `multiprocessing.util.register_after_fork` since it is only
called for processes created by the `multiprocessing` module, not for children created directly with `os.fork()`.

A `pthread_atfork` handler, on the other hand, is always called. However, I don't think it's safe to call Python functions inside the `atfork` handler, so the Python code has to be a bit more careful when checking `_initialized`.
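
To illustrate the approach, here is a minimal, self-contained sketch. The names mirror the `poison_fork()`/`forked_child()` helpers added to torch/csrc/cuda/Module.cpp in this diff, but the `main()` driver is purely illustrative and not part of the change:

#include <mutex>
#include <cstdio>
#ifndef WIN32
#include <pthread.h>
#include <unistd.h>
#endif

static bool in_bad_fork = false;  // set in children forked after CUDA init

#ifndef WIN32
// Runs in the forked child. Calling back into Python here is not safe,
// so we only flip a flag; the Python layer checks it and raises later.
static void forked_child() {
  in_bad_fork = true;
}
#endif

// Called before the first CUDA call; the handler is registered only once.
static void poison_fork() {
#ifndef WIN32
  static std::once_flag flag;
  std::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}

int main() {
#ifndef WIN32
  poison_fork();
  if (fork() == 0) {
    // The atfork child handler has already run by this point.
    std::printf("child: in_bad_fork=%d\n", in_bad_fork);
    _exit(0);
  }
#endif
  std::printf("parent: in_bad_fork=%d\n", in_bad_fork);
  return 0;
}

On the Python side (see the torch/cuda/__init__.py hunks below), `is_initialized()` returns False once `_cuda_isInBadFork` reports True, and `_lazy_init()` raises a clear error instead of trying to re-initialize CUDA in the forked child.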
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29101

Differential Revision: D18355451

Pulled By: ezyang

fbshipit-source-id: 4d4253a3669796212c099dad4e5bdfdb0df40469
Author: Peter Bell (2019-11-11 07:32:28 -08:00), committed by Facebook Github Bot
Commit: bb119d957e (parent be757957ba)
4 changed files with 52 additions and 43 deletions


@@ -1,8 +1,8 @@
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 8450f27812..1de27a5b0d 100644
index 21227ba9c0..994c778c74 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -144,8 +144,6 @@ def _lazy_call(callable):
@@ -148,8 +148,6 @@ def _lazy_call(callable):
# Don't store the actual traceback to avoid memory cycle
_queued_calls.append((callable, traceback.format_stack()))
@@ -11,13 +11,13 @@ index 8450f27812..1de27a5b0d 100644
class DeferredCudaCallError(Exception):
pass
@@ -191,9 +189,6 @@ def _lazy_init():
@@ -195,9 +193,6 @@ def _lazy_init():
"Cannot re-initialize CUDA in forked subprocess. " + msg)
_check_driver()
torch._C._cuda_init()
- _cudart = _load_cudart()
- _cudart.cudaGetErrorName.restype = ctypes.c_char_p
- _cudart.cudaGetErrorString.restype = ctypes.c_char_p
_original_pid = os.getpid()
# Some of the queued calls may reentrantly call _lazy_init();
# we need to just return without initializing in that case.
# However, we must not let any *other* threads in!


@@ -22,9 +22,33 @@
#include <torch/csrc/Generator.h>
#include <torch/csrc/python_headers.h>
#ifndef WIN32
#include <pthread.h>
#endif
using namespace torch;
THCState *state;
THCState *state = nullptr;
static bool in_bad_fork = false; // True for children forked after cuda init
#ifndef WIN32
// Called in the forked child if cuda has already been initialized
static void forked_child() {
in_bad_fork = true;
utils::set_run_yet_variable_to_false();
state = nullptr;
}
#endif
// Should be called before the first cuda call.
// Note: This is distinct from initExtension because a stub cuda implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
static std::once_flag flag;
std::call_once(flag, []{ pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}
////////////////////////////////////////////////////////////////////////////////
// CUDA management methods
@@ -61,16 +85,14 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs)
PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs)
{
HANDLE_TH_ERRORS
//torch::utils::cuda_lazy_init();
poison_fork();
return PyLong_FromLong(at::cuda::device_count());
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self, PyObject *noargs)
{
static PyObject * THCPModule_isInBadFork(PyObject *self, PyObject *noargs) {
HANDLE_TH_ERRORS
torch::utils::set_run_yet_variable_to_false();
Py_RETURN_NONE;
return PyBool_FromLong(in_bad_fork);
END_HANDLE_TH_ERRORS
}
@@ -373,6 +395,8 @@ static void bindCudaDeviceProperties(PyObject* module) {
static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
{
HANDLE_TH_ERRORS
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
poison_fork();
state = at::globalContext().lazyInitCUDA();
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
@@ -446,8 +470,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr},
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
{"_cuda_set_run_yet_variable_to_false",
(PyCFunction)THCPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
{"_cuda_isInBadFork", (PyCFunction)THCPModule_isInBadFork, METH_NOARGS, nullptr},
{"_cuda_getCurrentStream",
(PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
{"_cuda_getDefaultStream",


@@ -19,15 +19,14 @@ import warnings
import threading
from torch._six import raise_from
from subprocess import Popen, PIPE
from multiprocessing.util import register_after_fork as _register_after_fork
from ._utils import _get_device_index
import torch._C
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = [] # don't invoke these until initialization occurs
_in_bad_fork = False # this global is also used in torch.manual_seed
_original_pid = False
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
_cudart = None
@@ -137,8 +136,13 @@ def _check_capability():
warnings.warn(incorrect_binary_warn % (d, name, 10000, CUDA_VERSION))
def is_initialized():
r"""Returns whether PyTorch's CUDA state has been initialized."""
return _initialized and not _is_in_bad_fork()
def _lazy_call(callable):
if _initialized:
if is_initialized():
callable()
else:
# Don't store the actual traceback to avoid memory cycle
@@ -151,11 +155,6 @@ class DeferredCudaCallError(Exception):
pass
def is_initialized():
r"""Returns whether PyTorch's CUDA state has been initialized."""
return _initialized
def init():
r"""Initialize PyTorch's CUDA state. You may need to call
this explicitly if you are interacting with PyTorch via
@@ -170,8 +169,8 @@ def init():
def _lazy_init():
global _initialized, _cudart, _original_pid, _queued_calls
if _initialized or hasattr(_tls, 'is_initializing'):
global _initialized, _cudart, _queued_calls
if is_initialized() or hasattr(_tls, 'is_initializing'):
return
with _initialization_lock:
# We be double-checked locking, boys! This is OK because
@@ -179,12 +178,12 @@ def _lazy_init():
# is for when a thread blocked on some other thread which was
# doing the initialization; when they get the lock, they will
# find there is nothing left to do.
if _initialized:
if is_initialized():
return
# It is important to prevent other threads from entering _lazy_init
# immediately, while we are still guaranteed to have the GIL, because some
# of the C calls we make below will release the GIL
if _in_bad_fork:
if _is_in_bad_fork():
from sys import version_info
if version_info < (3, 4):
msg = ("To use CUDA with multiprocessing, you must use Python "
@@ -199,7 +198,6 @@ def _lazy_init():
_cudart = _load_cudart()
_cudart.cudaGetErrorName.restype = ctypes.c_char_p
_cudart.cudaGetErrorString.restype = ctypes.c_char_p
_original_pid = os.getpid()
# Some of the queued calls may reentrantly call _lazy_init();
# we need to just return without initializing in that case.
# However, we must not let any *other* threads in!
@@ -217,17 +215,6 @@ def _lazy_init():
_initialized = True
def _after_fork(arg):
global _initialized, _in_bad_fork
if _initialized and _original_pid != os.getpid():
_initialized = False
_in_bad_fork = True
_CudaBase.__new__ = _lazy_new
torch._C._cuda_set_run_yet_variable_to_false()
_register_after_fork(_after_fork, _after_fork)
def cudart():
_lazy_init()
return _cudart
@@ -335,8 +322,7 @@ def get_device_capability(device=None):
def get_device_properties(device):
if not _initialized:
init() # will define _get_device_properties and _CudaDeviceProperties
_lazy_init() # will define _get_device_properties and _CudaDeviceProperties
device = _get_device_index(device, optional=True)
if device < 0 or device >= device_count():
raise AssertionError("Invalid device id")
@@ -489,8 +475,8 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
@staticmethod
def _lazy_new(cls, *args, **kwargs):
_lazy_init()
# We need this method only for lazy init, so we can remove it
del _CudaBase.__new__
# We may need to call lazy init again if we are a forked child
# del _CudaBase.__new__
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)


@@ -28,7 +28,7 @@ def manual_seed(seed):
seed = int(seed)
import torch.cuda
if not torch.cuda._in_bad_fork:
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
return default_generator.manual_seed(seed)
@@ -41,7 +41,7 @@ def seed():
seed = default_generator.seed()
import torch.cuda
if not torch.cuda._in_bad_fork:
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
return seed