pytorch/c10/xpu/XPUFunctions.h
Yu, Guangye 07fa6e2c8b Fix torch.accelerator API abort when passing invalid device (#143550)
# Motivation
Fix https://github.com/pytorch/pytorch/issues/143543

# Solution
We should raise a Python exception instead of aborting.

# Additional Context
without this PR:
```python
>>> import torch
>>> torch.accelerator.current_stream(torch.accelerator.device_count())
terminate called after throwing an instance of 'c10::Error'
  what():  device is out of range, device is 2, total number of device is 2.
Exception raised from check_device_index at /home/dvrogozh/git/pytorch/pytorch/c10/xpu/XPUFunctions.h:36 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xac (0x7f30707eb95c in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f307078fc57 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so)
frame #2: <unknown function> + 0x19a3e (0x7f3070c2ba3e in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #3: c10::xpu::getCurrentXPUStream(signed char) + 0x2f (0x7f3070c2c83f in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #4: <unknown function> + 0x1ca35 (0x7f3070c2ea35 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so)
frame #5: <unknown function> + 0x653f15 (0x7f3083391f15 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x39e5f2 (0x7f30830dc5f2 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so)
<omitting python frames>
frame #20: <unknown function> + 0x29d90 (0x7f308b19bd90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #21: __libc_start_main + 0x80 (0x7f308b19be40 in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)
```
with this PR:
```python
>>> import torch
>>> torch.accelerator.current_stream(torch.accelerator.device_count())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pt-gpu/4T-4652/guangyey/stock-pytorch/torch/accelerator/__init__.py", line 123, in current_stream
    return torch._C._accelerator_getStream(device_index)
RuntimeError: The device index is out of range. It must be in [0, 2), but got 2.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143550
Approved by: https://github.com/EikanWang, https://github.com/dvrogozh, https://github.com/albanD
2024-12-23 03:44:22 +00:00

45 lines
1.2 KiB
C++

#pragma once
#include <c10/core/Device.h>
#include <c10/xpu/XPUDeviceProp.h>
#include <c10/xpu/XPUMacros.h>
// The naming convention used here matches the naming convention of torch.xpu
namespace c10::xpu {
// Device-management entry points exported from the XPU library
// (C10_XPU_API). Only the declarations live in this header; the behavior notes
// below that are not taken from the original comments are hedged and should be
// confirmed against the implementation file.

// Log a warning only once if no devices are detected.
// The returned value is used by check_device_index() as the exclusive upper
// bound of valid device indices.
C10_XPU_API DeviceIndex device_count();
// Throws an error if no devices are detected.
C10_XPU_API DeviceIndex device_count_ensure_non_zero();
// NOTE(review): presumably returns the index of the currently selected device
// -- confirm against the implementation.
C10_XPU_API DeviceIndex current_device();
// NOTE(review): presumably makes `device` the current device -- confirm.
C10_XPU_API void set_device(DeviceIndex device);
// NOTE(review): presumably switches to `device` and returns the previously
// current index -- confirm.
C10_XPU_API DeviceIndex exchange_device(DeviceIndex device);
// NOTE(review): looks like a conditional variant of exchange_device(); the
// exact condition is not visible from this header -- confirm.
C10_XPU_API DeviceIndex maybe_exchange_device(DeviceIndex to_device);
// Returns a mutable reference to the sycl::device associated with `device`.
C10_XPU_API sycl::device& get_raw_device(DeviceIndex device);
// Returns a mutable reference to a sycl::context. Takes no device argument.
C10_XPU_API sycl::context& get_device_context();
// Takes an output pointer and a device index.
// NOTE(review): presumably fills `*device_prop` with the properties of
// `device` -- confirm, including whether a null pointer is tolerated.
C10_XPU_API void get_device_properties(
DeviceProp* device_prop,
DeviceIndex device);
// NOTE(review): presumably maps a device allocation pointer back to the index
// of the device that owns it -- confirm.
C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr);
// Validates that `device_index` refers to an existing XPU device.
//
// A valid index lies in the half-open range [0, device_count()). When the
// index falls outside that range, TORCH_CHECK fails with a message reporting
// both the permitted range and the offending value.
static inline void check_device_index(DeviceIndex device_index) {
  // Widen to int up front so the values are formatted as numbers in the error
  // message, as in the original static_cast<int> arguments.
  const int total_devices = static_cast<int>(c10::xpu::device_count());
  const int requested = static_cast<int>(device_index);
  TORCH_CHECK(
      0 <= requested && requested < total_devices,
      "The device index is out of range. It must be in [0, ",
      total_devices,
      "), but got ",
      requested,
      ".");
}
} // namespace c10::xpu