initialize device when pinning memory on this device, short circuit is_pinned if device is not initialized (#145752)

Do not land
RFC
potential fix for #144687

Now `.is_pinned(device="cuda")` does not initialize the device and thus doesn't poison the fork (though it warns that the `device` arg is deprecated). To drop the need for the `device=` arg we'd have to fix `get_accelerator` to not initialize the device.
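
As a minimal sketch of the intended behavior (not part of this PR; it assumes a CUDA build and uses `torch.cuda.is_initialized()` only to observe whether a context got created in this process):

```python
import torch

x = torch.randn(10)

# With this change, the pin-status query short-circuits when the device has not
# been initialized yet, so no CUDA context is created here and a later fork can
# still use CUDA.
x.is_pinned(device="cuda")  # warns that the `device` argument is deprecated
assert not torch.cuda.is_initialized()
```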

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145752
Approved by: https://github.com/albanD

Co-authored-by: albanD <albandes@fb.com>
Natalia Gimelshein, 2025-01-30 21:37:29 +00:00; committed by PyTorch MergeBot
parent 1252c1933d
commit 08ff11e9d0
4 changed files with 58 additions and 4 deletions

View file

@@ -101,6 +101,10 @@ class TORCH_API Context {
        opt_device_type.value())) { // passed device not an accelerator
      return false;
    }
    if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
      // If the device is not initialized, no pointer can be pinned for it
      return false;
    }
    return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
  }

View file

@@ -17,14 +17,19 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
    // in the same PyTorch build, you would ONLY ever get the CUDA pinned memory allocator.
    // To properly support this, see https://github.com/pytorch/pytorch/issues/14560
    if (at::globalContext().hasCUDA()) {
      at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
      return at::detail::getCUDAHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasMTIA()) {
      at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
      return at::detail::getMTIAHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasXPU()) {
      at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
      return at::detail::getXPUHooks().getPinnedMemoryAllocator();
    } else if (at::globalContext().hasHPU()) {
      at::globalContext().lazyInitDevice(c10::DeviceType::HPU);
      return at::detail::getHPUHooks().getPinnedMemoryAllocator();
    } else if(at::isPrivateUse1HooksRegistered()) {
      at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
      return at::detail::getPrivateUse1Hooks().getPinnedMemoryAllocator();
    } else {
      TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.")
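
For contrast, a usage-level sketch of the allocator path above (again not part of this diff, assuming a CUDA build): actually pinning memory now initializes the accelerator via `lazyInitDevice`, so the subsequent pin-status query passes the `test_once()` check and reports the pointer as pinned.

```python
import torch

# Allocating pinned memory goes through GetCPUAllocatorMaybePinned, which now
# initializes the detected accelerator before handing out the pinned allocator.
t = torch.empty(1024, dtype=torch.uint8, pin_memory=True)
assert torch.cuda.is_initialized()

# Because the device is initialized, Context::isPinnedPtr does not short-circuit
# and correctly reports the pointer as pinned.
assert t.is_pinned()
```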

View file

@@ -40,6 +40,9 @@ class once_flag {
  once_flag(once_flag&&) = delete;
  once_flag& operator=(once_flag&&) = delete;
  ~once_flag() = default;
  bool test_once() {
    return init_.load(std::memory_order_acquire);
  }

 private:
  template <typename Flag, typename F, typename... Args>
@@ -55,10 +58,6 @@
    init_.store(true, std::memory_order_release);
  }

  bool test_once() {
    return init_.load(std::memory_order_acquire);
  }

  void reset_once() {
    init_.store(false, std::memory_order_release);
  }

View file

@@ -3428,6 +3428,52 @@ print(f"{{r1}}, {{r2}}")
        with self.assertRaisesRegex(RuntimeError, error_msg):
            torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)

    def test_is_pinned_no_context(self):
        test_script = """\
import torch
import multiprocessing


def fork_and_check_is_pinned():
    # Create a pipe to communicate between parent and child processes
    parent_conn, child_conn = multiprocessing.Pipe()

    def worker(conn):
        try:
            x = torch.randn(10)
            x.is_pinned(device="cuda")
            x = torch.ones(10, device="cuda")[0].item()
            conn.send(x)
        except Exception as e:
            conn.send(str(e))
        finally:
            conn.close()

    # Fork a new process
    p = multiprocessing.Process(target=worker, args=(child_conn,))
    p.start()
    # Receive the result from the child process
    result = parent_conn.recv()
    parent_conn.close()
    # Wait for the child process to finish
    p.join()
    if isinstance(result, str) and result.startswith("Error"):
        raise RuntimeError(result)
    return result


x = torch.randn(10)
# check that is_pinned won't poison future fork
x.is_pinned(device="cuda")
ret = fork_and_check_is_pinned()
print(ret)
"""
        r = (
            subprocess.check_output([sys.executable, "-c", test_script])
            .decode("ascii")
            .strip()
        )
        self.assertEqual(r, "1.0")

@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest