From 07fa6e2c8b003319f85a469307f1b1dd73f6026c Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 19 Dec 2024 13:44:10 +0000 Subject: [PATCH] Fix torch.accelerator api abort when passing invaild device (#143550) # Motivation Fix https://github.com/pytorch/pytorch/issues/143543 # Solution We should raise python exception instead of aborting... # Additional Context without this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) terminate called after throwing an instance of 'c10::Error' what(): device is out of range, device is 2, total number of device is 2. Exception raised from check_device_index at /home/dvrogozh/git/pytorch/pytorch/c10/xpu/XPUFunctions.h:36 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0xac (0x7f30707eb95c in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string, std::allocator > const&) + 0xf3 (0x7f307078fc57 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #2: + 0x19a3e (0x7f3070c2ba3e in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #3: c10::xpu::getCurrentXPUStream(signed char) + 0x2f (0x7f3070c2c83f in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #4: + 0x1ca35 (0x7f3070c2ea35 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #5: + 0x653f15 (0x7f3083391f15 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) frame #6: + 0x39e5f2 (0x7f30830dc5f2 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) frame #20: + 0x29d90 (0x7f308b19bd90 in /lib/x86_64-linux-gnu/libc.so.6) frame #21: __libc_start_main + 0x80 (0x7f308b19be40 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` with this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) Traceback (most recent call last): File "", line 1, in File "/home/pt-gpu/4T-4652/guangyey/stock-pytorch/torch/accelerator/__init__.py", line 123, in current_stream return torch._C._accelerator_getStream(device_index) RuntimeError: The device index is out of range. It must be in [0, 2), but got 2. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143550 Approved by: https://github.com/EikanWang, https://github.com/dvrogozh, https://github.com/albanD --- .../src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h | 4 ++-- aten/src/ATen/mps/MPSGuardImpl.h | 4 ++-- c10/core/impl/DeviceGuardImplInterface.h | 4 ++-- c10/core/impl/VirtualGuardImpl.h | 4 ++-- c10/cuda/impl/CUDAGuardImpl.h | 4 ++-- c10/xpu/XPUFunctions.h | 10 +++++----- c10/xpu/impl/XPUGuardImpl.h | 4 ++-- test/test_cuda.py | 4 ++++ test/test_xpu.py | 2 ++ 9 files changed, 23 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index ade5d76b0bd..93b998a8f7f 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -82,7 +82,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI void uncheckedSetDevice(Device d) const noexcept override { C10_HIP_CHECK_WARN(hipSetDevice(d.index())); } - Stream getStream(Device d) const noexcept override { + Stream getStream(Device d) const override { return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); } Stream getDefaultStream(Device d) const override { @@ -94,7 +94,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override { return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index()); } - Stream exchangeStream(Stream s) const noexcept override { + Stream exchangeStream(Stream s) const override { HIPStreamMasqueradingAsCUDA cs(s); auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); setCurrentHIPStreamMasqueradingAsCUDA(cs); diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h index 2d58f9d29c9..7ff2d13ceef 100644 --- a/aten/src/ATen/mps/MPSGuardImpl.h +++ b/aten/src/ATen/mps/MPSGuardImpl.h @@ -64,7 +64,7 @@ struct TORCH_API MPSGuardImpl final // TODO: Currently setting only device 0 } - Stream getStream(Device d) const noexcept override { + Stream getStream(Device d) const override { return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); } @@ -78,7 +78,7 @@ struct TORCH_API MPSGuardImpl final } // NB: These do NOT set the current device - Stream exchangeStream(Stream s) const noexcept override { + Stream exchangeStream(Stream s) const override { return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); } DeviceIndex deviceCount() const noexcept override { diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 29aa9bc8038..523e9ad9f45 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -105,7 +105,7 @@ struct C10_API DeviceGuardImplInterface { /** * Get the current stream for a given device. */ - virtual Stream getStream(Device) const noexcept = 0; + virtual Stream getStream(Device) const = 0; /** * Get the default stream for a given device. @@ -138,7 +138,7 @@ struct C10_API DeviceGuardImplInterface { * Return the previous stream for that device. You are NOT required * to set the current device to match the device of this stream. */ - virtual Stream exchangeStream(Stream) const noexcept = 0; + virtual Stream exchangeStream(Stream) const = 0; /** * Destroys the given event. diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index b5e4ab3e01b..badcb623291 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -37,7 +37,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { void uncheckedSetDevice(Device d) const noexcept override { impl_->uncheckedSetDevice(d); } - Stream getStream(Device d) const noexcept override { + Stream getStream(Device d) const override { return impl_->getStream(d); } Stream getNewStream(Device d, int priority = 0) const override { @@ -50,7 +50,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { const override { return impl_->getStreamFromGlobalPool(d, isHighPriority); } - Stream exchangeStream(Stream s) const noexcept override { + Stream exchangeStream(Stream s) const override { return impl_->exchangeStream(s); } DeviceIndex deviceCount() const noexcept override { diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index dd81dcf51fd..244c012dcb3 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -56,7 +56,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { void uncheckedSetDevice(Device d) const noexcept override { C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); } - Stream getStream(Device d) const noexcept override { + Stream getStream(Device d) const override { return getCurrentCUDAStream(d.index()).unwrap(); } Stream getDefaultStream(Device d) const override { @@ -70,7 +70,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { return getStreamFromPool(isHighPriority, d.index()); } // NB: These do NOT set the current device - Stream exchangeStream(Stream s) const noexcept override { + Stream exchangeStream(Stream s) const override { CUDAStream cs(s); auto old_stream = getCurrentCUDAStream(s.device().index()); setCurrentCUDAStream(cs); diff --git a/c10/xpu/XPUFunctions.h b/c10/xpu/XPUFunctions.h index a205db0d5eb..99f305c1e1b 100644 --- a/c10/xpu/XPUFunctions.h +++ b/c10/xpu/XPUFunctions.h @@ -32,13 +32,13 @@ C10_XPU_API void get_device_properties( C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr); -static inline void check_device_index(DeviceIndex device) { +static inline void check_device_index(DeviceIndex device_index) { TORCH_CHECK( - device >= 0 && device < c10::xpu::device_count(), - "device is out of range, device is ", - static_cast(device), - ", total number of device is ", + device_index >= 0 && device_index < c10::xpu::device_count(), + "The device index is out of range. It must be in [0, ", static_cast(c10::xpu::device_count()), + "), but got ", + static_cast(device_index), "."); } diff --git a/c10/xpu/impl/XPUGuardImpl.h b/c10/xpu/impl/XPUGuardImpl.h index b646b21d99f..e7a6b3de9e1 100644 --- a/c10/xpu/impl/XPUGuardImpl.h +++ b/c10/xpu/impl/XPUGuardImpl.h @@ -44,7 +44,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { c10::xpu::set_device(d.index()); } - Stream getStream(Device d) const noexcept override { + Stream getStream(Device d) const override { return getCurrentXPUStream(d.index()).unwrap(); } @@ -58,7 +58,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { } // NB: These do NOT set the current device - Stream exchangeStream(Stream s) const noexcept override { + Stream exchangeStream(Stream s) const override { const XPUStream stream(s); const auto old_stream = getCurrentXPUStream(s.device().index()); setCurrentXPUStream(stream); diff --git a/test/test_cuda.py b/test/test_cuda.py index 0f761971457..655f63133ac 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -766,6 +766,10 @@ class TestCuda(TestCase): self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id) torch.accelerator.set_stream(s2) self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id) + with self.assertRaisesRegex( + RuntimeError, "device_index >= 0 && device_index < num_gpus" + ): + torch.accelerator.current_stream(torch.accelerator.device_count()) def test_record_stream(self): cycles_per_ms = get_cycles_per_ms() diff --git a/test/test_xpu.py b/test/test_xpu.py index 19cda5abae2..57929e4af7f 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -306,6 +306,8 @@ print(torch.xpu.device_count()) self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id) torch.accelerator.set_stream(s2) self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id) + with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): + torch.accelerator.current_stream(torch.accelerator.device_count()) def test_generator(self): torch.manual_seed(2024)