mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Move torch.cuda's atfork handler into C++ (#29101)
Summary: Fixes https://github.com/pytorch/pytorch/issues/23401 We cannot rely on `multiprocessing.util.register_after_fork` since it is only called for processes created by the `multiprocessing` module and not `os.fork()`. Moving to `pthread_atfork` does always get called. However, I don't think it's safe to call python functions inside of the `atfork` handler so the python code has to be a bit more careful when checking `_initialized`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/29101 Differential Revision: D18355451 Pulled By: ezyang fbshipit-source-id: 4d4253a3669796212c099dad4e5bdfdb0df40469
This commit is contained in:
parent
be757957ba
commit
bb119d957e
4 changed files with 52 additions and 43 deletions
|
|
@ -1,8 +1,8 @@
|
|||
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
|
||||
index 8450f27812..1de27a5b0d 100644
|
||||
index 21227ba9c0..994c778c74 100644
|
||||
--- a/torch/cuda/__init__.py
|
||||
+++ b/torch/cuda/__init__.py
|
||||
@@ -144,8 +144,6 @@ def _lazy_call(callable):
|
||||
@@ -148,8 +148,6 @@ def _lazy_call(callable):
|
||||
# Don't store the actual traceback to avoid memory cycle
|
||||
_queued_calls.append((callable, traceback.format_stack()))
|
||||
|
||||
|
|
@ -11,13 +11,13 @@ index 8450f27812..1de27a5b0d 100644
|
|||
|
||||
class DeferredCudaCallError(Exception):
|
||||
pass
|
||||
@@ -191,9 +189,6 @@ def _lazy_init():
|
||||
@@ -195,9 +193,6 @@ def _lazy_init():
|
||||
"Cannot re-initialize CUDA in forked subprocess. " + msg)
|
||||
_check_driver()
|
||||
torch._C._cuda_init()
|
||||
- _cudart = _load_cudart()
|
||||
- _cudart.cudaGetErrorName.restype = ctypes.c_char_p
|
||||
- _cudart.cudaGetErrorString.restype = ctypes.c_char_p
|
||||
_original_pid = os.getpid()
|
||||
# Some of the queued calls may reentrantly call _lazy_init();
|
||||
# we need to just return without initializing in that case.
|
||||
# However, we must not let any *other* threads in!
|
||||
|
|
|
|||
|
|
@ -22,9 +22,33 @@
|
|||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#ifndef WIN32
|
||||
#include <pthread.h>
|
||||
#endif
|
||||
|
||||
using namespace torch;
|
||||
|
||||
THCState *state;
|
||||
THCState *state = nullptr;
|
||||
static bool in_bad_fork = false; // True for children forked after cuda init
|
||||
|
||||
#ifndef WIN32
|
||||
// Called in the forked child if cuda has already been initialized
|
||||
static void forked_child() {
|
||||
in_bad_fork = true;
|
||||
utils::set_run_yet_variable_to_false();
|
||||
state = nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Should be called before the first cuda call.
|
||||
// Note: This is distinct from initExtension because a stub cuda implementation
|
||||
// has some working functions (e.g. device_count) but cannot fully initialize.
|
||||
static void poison_fork() {
|
||||
#ifndef WIN32
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, []{ pthread_atfork(nullptr, nullptr, forked_child); });
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA management methods
|
||||
|
|
@ -61,16 +85,14 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs)
|
|||
PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
//torch::utils::cuda_lazy_init();
|
||||
poison_fork();
|
||||
return PyLong_FromLong(at::cuda::device_count());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self, PyObject *noargs)
|
||||
{
|
||||
static PyObject * THCPModule_isInBadFork(PyObject *self, PyObject *noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
torch::utils::set_run_yet_variable_to_false();
|
||||
Py_RETURN_NONE;
|
||||
return PyBool_FromLong(in_bad_fork);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
@ -373,6 +395,8 @@ static void bindCudaDeviceProperties(PyObject* module) {
|
|||
static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
|
||||
poison_fork();
|
||||
state = at::globalContext().lazyInitCUDA();
|
||||
|
||||
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
|
||||
|
|
@ -446,8 +470,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
|
|||
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr},
|
||||
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
|
||||
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
|
||||
{"_cuda_set_run_yet_variable_to_false",
|
||||
(PyCFunction)THCPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
|
||||
{"_cuda_isInBadFork", (PyCFunction)THCPModule_isInBadFork, METH_NOARGS, nullptr},
|
||||
{"_cuda_getCurrentStream",
|
||||
(PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
|
||||
{"_cuda_getDefaultStream",
|
||||
|
|
|
|||
|
|
@ -19,15 +19,14 @@ import warnings
|
|||
import threading
|
||||
from torch._six import raise_from
|
||||
from subprocess import Popen, PIPE
|
||||
from multiprocessing.util import register_after_fork as _register_after_fork
|
||||
from ._utils import _get_device_index
|
||||
import torch._C
|
||||
|
||||
_initialized = False
|
||||
_tls = threading.local()
|
||||
_initialization_lock = threading.Lock()
|
||||
_queued_calls = [] # don't invoke these until initialization occurs
|
||||
_in_bad_fork = False # this global is also used in torch.manual_seed
|
||||
_original_pid = False
|
||||
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
|
||||
_cudart = None
|
||||
|
||||
|
||||
|
|
@ -137,8 +136,13 @@ def _check_capability():
|
|||
warnings.warn(incorrect_binary_warn % (d, name, 10000, CUDA_VERSION))
|
||||
|
||||
|
||||
def is_initialized():
|
||||
r"""Returns whether PyTorch's CUDA state has been initialized."""
|
||||
return _initialized and not _is_in_bad_fork()
|
||||
|
||||
|
||||
def _lazy_call(callable):
|
||||
if _initialized:
|
||||
if is_initialized():
|
||||
callable()
|
||||
else:
|
||||
# Don't store the actual traceback to avoid memory cycle
|
||||
|
|
@ -151,11 +155,6 @@ class DeferredCudaCallError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def is_initialized():
|
||||
r"""Returns whether PyTorch's CUDA state has been initialized."""
|
||||
return _initialized
|
||||
|
||||
|
||||
def init():
|
||||
r"""Initialize PyTorch's CUDA state. You may need to call
|
||||
this explicitly if you are interacting with PyTorch via
|
||||
|
|
@ -170,8 +169,8 @@ def init():
|
|||
|
||||
|
||||
def _lazy_init():
|
||||
global _initialized, _cudart, _original_pid, _queued_calls
|
||||
if _initialized or hasattr(_tls, 'is_initializing'):
|
||||
global _initialized, _cudart, _queued_calls
|
||||
if is_initialized() or hasattr(_tls, 'is_initializing'):
|
||||
return
|
||||
with _initialization_lock:
|
||||
# We be double-checked locking, boys! This is OK because
|
||||
|
|
@ -179,12 +178,12 @@ def _lazy_init():
|
|||
# is for when a thread blocked on some other thread which was
|
||||
# doing the initialization; when they get the lock, they will
|
||||
# find there is nothing left to do.
|
||||
if _initialized:
|
||||
if is_initialized():
|
||||
return
|
||||
# It is important to prevent other threads from entering _lazy_init
|
||||
# immediately, while we are still guaranteed to have the GIL, because some
|
||||
# of the C calls we make below will release the GIL
|
||||
if _in_bad_fork:
|
||||
if _is_in_bad_fork():
|
||||
from sys import version_info
|
||||
if version_info < (3, 4):
|
||||
msg = ("To use CUDA with multiprocessing, you must use Python "
|
||||
|
|
@ -199,7 +198,6 @@ def _lazy_init():
|
|||
_cudart = _load_cudart()
|
||||
_cudart.cudaGetErrorName.restype = ctypes.c_char_p
|
||||
_cudart.cudaGetErrorString.restype = ctypes.c_char_p
|
||||
_original_pid = os.getpid()
|
||||
# Some of the queued calls may reentrantly call _lazy_init();
|
||||
# we need to just return without initializing in that case.
|
||||
# However, we must not let any *other* threads in!
|
||||
|
|
@ -217,17 +215,6 @@ def _lazy_init():
|
|||
_initialized = True
|
||||
|
||||
|
||||
def _after_fork(arg):
|
||||
global _initialized, _in_bad_fork
|
||||
if _initialized and _original_pid != os.getpid():
|
||||
_initialized = False
|
||||
_in_bad_fork = True
|
||||
_CudaBase.__new__ = _lazy_new
|
||||
torch._C._cuda_set_run_yet_variable_to_false()
|
||||
|
||||
_register_after_fork(_after_fork, _after_fork)
|
||||
|
||||
|
||||
def cudart():
|
||||
_lazy_init()
|
||||
return _cudart
|
||||
|
|
@ -335,8 +322,7 @@ def get_device_capability(device=None):
|
|||
|
||||
|
||||
def get_device_properties(device):
|
||||
if not _initialized:
|
||||
init() # will define _get_device_properties and _CudaDeviceProperties
|
||||
_lazy_init() # will define _get_device_properties and _CudaDeviceProperties
|
||||
device = _get_device_index(device, optional=True)
|
||||
if device < 0 or device >= device_count():
|
||||
raise AssertionError("Invalid device id")
|
||||
|
|
@ -489,8 +475,8 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
|||
@staticmethod
|
||||
def _lazy_new(cls, *args, **kwargs):
|
||||
_lazy_init()
|
||||
# We need this method only for lazy init, so we can remove it
|
||||
del _CudaBase.__new__
|
||||
# We may need to call lazy init again if we are a forked child
|
||||
# del _CudaBase.__new__
|
||||
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def manual_seed(seed):
|
|||
seed = int(seed)
|
||||
import torch.cuda
|
||||
|
||||
if not torch.cuda._in_bad_fork:
|
||||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
return default_generator.manual_seed(seed)
|
||||
|
|
@ -41,7 +41,7 @@ def seed():
|
|||
seed = default_generator.seed()
|
||||
import torch.cuda
|
||||
|
||||
if not torch.cuda._in_bad_fork:
|
||||
if not torch.cuda._is_in_bad_fork():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
return seed
|
||||
|
|
|
|||
Loading…
Reference in a new issue