initialize device when pinning memory on this device, short circuit is_pinned if device is not initialized (#145752)

Do not land, RFC: potential fix for #144687. Now `.is_pinned(device="cuda")` does not initialize the device and thus doesn't poison the fork (but it complains about the `device` arg being deprecated). To avoid needing the `device=` arg, we'd need to fix get_accelerator to not initialize the device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145752
Approved by: https://github.com/albanD

Co-authored-by: albanD <albandes@fb.com>
parent 1252c1933d
commit 08ff11e9d0

4 changed files with 58 additions and 4 deletions
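For context, a minimal sketch of the scenario described in the commit message above (illustrative only, not part of the diff; it assumes a Linux CUDA build where multiprocessing uses fork). Before this change, the `.is_pinned(device="cuda")` query initialized the CUDA context in the parent, so CUDA work in a forked child would fail; after it, the query short-circuits and the fork stays usable:

    import torch
    import multiprocessing as mp

    def child():
        # Only works if the parent has not initialized the CUDA context before forking.
        print(torch.ones(1, device="cuda").item())

    if __name__ == "__main__":
        x = torch.randn(10)
        # With this change, the pinned-memory query below no longer initializes CUDA,
        # so the forked child can still set up its own CUDA context.
        x.is_pinned(device="cuda")
        p = mp.get_context("fork").Process(target=child)
        p.start()
        p.join()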
@@ -101,6 +101,10 @@ class TORCH_API Context {
             opt_device_type.value())) { // passed device not an accelerator
       return false;
     }
+    if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
+      // If the device is not initialized, no pointer can be pinned for it
+      return false;
+    }
     return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
   }
@@ -17,14 +17,19 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
   // in the same PyTorch build, you would ONLY ever get the CUDA pinned memory allocator.
   // To properly support this, see https://github.com/pytorch/pytorch/issues/14560
   if (at::globalContext().hasCUDA()) {
+    at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
     return at::detail::getCUDAHooks().getPinnedMemoryAllocator();
   } else if (at::globalContext().hasMTIA()) {
+    at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
     return at::detail::getMTIAHooks().getPinnedMemoryAllocator();
   } else if (at::globalContext().hasXPU()) {
+    at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
     return at::detail::getXPUHooks().getPinnedMemoryAllocator();
   } else if (at::globalContext().hasHPU()) {
+    at::globalContext().lazyInitDevice(c10::DeviceType::HPU);
     return at::detail::getHPUHooks().getPinnedMemoryAllocator();
   } else if(at::isPrivateUse1HooksRegistered()) {
+    at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
     return at::detail::getPrivateUse1Hooks().getPinnedMemoryAllocator();
   } else {
     TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.")
@@ -40,6 +40,9 @@ class once_flag {
   once_flag(once_flag&&) = delete;
   once_flag& operator=(once_flag&&) = delete;
   ~once_flag() = default;
+  bool test_once() {
+    return init_.load(std::memory_order_acquire);
+  }
 
  private:
   template <typename Flag, typename F, typename... Args>
@@ -55,10 +58,6 @@ class once_flag {
     init_.store(true, std::memory_order_release);
   }
 
-  bool test_once() {
-    return init_.load(std::memory_order_acquire);
-  }
-
   void reset_once() {
     init_.store(false, std::memory_order_release);
   }
@@ -3428,6 +3428,52 @@ print(f"{{r1}}, {{r2}}")
         with self.assertRaisesRegex(RuntimeError, error_msg):
             torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
 
+    def test_is_pinned_no_context(self):
+        test_script = """\
+import torch
+import multiprocessing
+
+
+def fork_and_check_is_pinned():
+    # Create a pipe to communicate between parent and child processes
+    parent_conn, child_conn = multiprocessing.Pipe()
+
+    def worker(conn):
+        try:
+            x = torch.randn(10)
+            x.is_pinned(device="cuda")
+            x = torch.ones(10, device="cuda")[0].item()
+            conn.send(x)
+        except Exception as e:
+            conn.send(str(e))
+        finally:
+            conn.close()
+    # Fork a new process
+    p = multiprocessing.Process(target=worker, args=(child_conn,))
+    p.start()
+    # Receive the result from the child process
+    result = parent_conn.recv()
+    parent_conn.close()
+    # Wait for the child process to finish
+    p.join()
+    if isinstance(result, str) and result.startswith("Error"):
+        raise RuntimeError(result)
+    return result
+
+
+x = torch.randn(10)
+# check that is_pinned won't poison future fork
+x.is_pinned(device="cuda")
+ret = fork_and_check_is_pinned()
+print(ret)
+
+"""
+        r = (
+            subprocess.check_output([sys.executable, "-c", test_script])
+            .decode("ascii")
+            .strip()
+        )
+        self.assertEqual(r, "1.0")
+
 
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest