From 08ff11e9d04aa6861ae685ce4d9b3e6853db2a73 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Thu, 30 Jan 2025 21:37:29 +0000
Subject: [PATCH] initialize device when pinning memory on this device, short
 circuit i… (#145752)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…s_pinned if device is not initialized

Do not land, RFC: potential fix for #144687.

Now `.is_pinned(device="cuda")` does not initialize the device and thus
doesn't poison the fork (though it complains about the `device` arg being
deprecated). To avoid needing the `device=` arg, we'd need to fix
get_accelerator to not initialize the device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145752
Approved by: https://github.com/albanD

Co-authored-by: albanD
---
 aten/src/ATen/Context.h       |  4 +++
 aten/src/ATen/EmptyTensor.cpp |  5 ++++
 c10/util/CallOnce.h           |  7 +++---
 test/test_cuda.py             | 46 +++++++++++++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index f986774e3fa..6edd8dad69f 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -101,6 +101,10 @@ class TORCH_API Context {
             opt_device_type.value())) { // passed device not an accelerator
       return false;
     }
+    if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
+      // If the device is not initialized, no pointer can be pinned for it
+      return false;
+    }
     return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
   }
 
diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp
index 3f1871086ee..e29f738e87a 100644
--- a/aten/src/ATen/EmptyTensor.cpp
+++ b/aten/src/ATen/EmptyTensor.cpp
@@ -17,14 +17,19 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
     // in the same PyTorch build, you would ONLY ever get the CUDA pinned memory allocator.
    // To properly support this, see https://github.com/pytorch/pytorch/issues/14560
    if (at::globalContext().hasCUDA()) {
+      at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
      return at::detail::getCUDAHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasMTIA()) {
+      at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
      return at::detail::getMTIAHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasXPU()) {
+      at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
      return at::detail::getXPUHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasHPU()) {
+      at::globalContext().lazyInitDevice(c10::DeviceType::HPU);
      return at::detail::getHPUHooks().getPinnedMemoryAllocator();
    } else if(at::isPrivateUse1HooksRegistered()) {
+      at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
      return at::detail::getPrivateUse1Hooks().getPinnedMemoryAllocator();
    } else {
      TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.")
diff --git a/c10/util/CallOnce.h b/c10/util/CallOnce.h
index c42436e39c8..2d8c9dc5b11 100644
--- a/c10/util/CallOnce.h
+++ b/c10/util/CallOnce.h
@@ -40,6 +40,9 @@ class once_flag {
   once_flag(once_flag&&) = delete;
   once_flag& operator=(once_flag&&) = delete;
   ~once_flag() = default;
+  bool test_once() {
+    return init_.load(std::memory_order_acquire);
+  }
 
  private:
   template <typename Flag, typename F, typename... Args>
@@ -55,10 +58,6 @@ class once_flag {
     init_.store(true, std::memory_order_release);
   }
 
-  bool test_once() {
-    return init_.load(std::memory_order_acquire);
-  }
-
   void reset_once() {
     init_.store(false, std::memory_order_release);
   }
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 1910a9687bb..8377265ede4 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3428,6 +3428,52 @@ print(f"{{r1}}, {{r2}}")
         with self.assertRaisesRegex(RuntimeError, error_msg):
             torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
 
+    def test_is_pinned_no_context(self):
+        test_script = """\
+import torch
+import multiprocessing
+
+
+def fork_and_check_is_pinned():
+    # Create a pipe to communicate between parent and child processes
+    parent_conn, child_conn = multiprocessing.Pipe()
+
+    def worker(conn):
+        try:
+            x = torch.randn(10)
+            x.is_pinned(device="cuda")
+            x = torch.ones(10, device="cuda")[0].item()
+            conn.send(x)
+        except Exception as e:
+            conn.send(str(e))
+        finally:
+            conn.close()
+    # Fork a new process
+    p = multiprocessing.Process(target=worker, args=(child_conn,))
+    p.start()
+    # Receive the result from the child process
+    result = parent_conn.recv()
+    parent_conn.close()
+    # Wait for the child process to finish
+    p.join()
+    if isinstance(result, str) and result.startswith("Error"):
+        raise RuntimeError(result)
+    return result
+
+x = torch.randn(10)
+# check that is_pinned won't poison future fork
+x.is_pinned(device="cuda")
+ret = fork_and_check_is_pinned()
+print(ret)
+
+"""
+        r = (
+            subprocess.check_output([sys.executable, "-c", test_script])
+            .decode("ascii")
+            .strip()
+        )
+        self.assertEqual(r, "1.0")
+
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
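
Note (illustration only, not part of the patch): a minimal standalone sketch
of the fork-poisoning scenario this change addresses, assuming a CUDA-enabled
build on a POSIX system where the "fork" start method is available. With the
short circuit in Context::isPinnedPtr, querying pinned status before any CUDA
work returns False without creating a CUDA context, so a forked child can
still initialize CUDA on its own.

    import multiprocessing

    import torch


    def child():
        # Succeeds only if the parent did not initialize CUDA before forking.
        print(torch.ones(2, device="cuda").sum().item())  # prints 2.0


    if __name__ == "__main__":
        x = torch.randn(10)
        # With this patch, the query short-circuits to False because the
        # device is not initialized, so no CUDA context is created here
        # (per the commit message, it still warns that device= is deprecated).
        print(x.is_pinned(device="cuda"))  # False
        p = multiprocessing.get_context("fork").Process(target=child)
        p.start()
        p.join()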