mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
# Motivation Fix https://github.com/pytorch/pytorch/issues/143543 # Solution We should raise python exception instead of aborting... # Additional Context without this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) terminate called after throwing an instance of 'c10::Error' what(): device is out of range, device is 2, total number of device is 2. Exception raised from check_device_index at /home/dvrogozh/git/pytorch/pytorch/c10/xpu/XPUFunctions.h:36 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xac (0x7f30707eb95c in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f307078fc57 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #2: <unknown function> + 0x19a3e (0x7f3070c2ba3e in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #3: c10::xpu::getCurrentXPUStream(signed char) + 0x2f (0x7f3070c2c83f in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #4: <unknown function> + 0x1ca35 (0x7f3070c2ea35 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #5: <unknown function> + 0x653f15 (0x7f3083391f15 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) frame #6: <unknown function> + 0x39e5f2 (0x7f30830dc5f2 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) <omitting python frames> frame #20: <unknown function> + 0x29d90 (0x7f308b19bd90 in /lib/x86_64-linux-gnu/libc.so.6) frame #21: __libc_start_main + 0x80 (0x7f308b19be40 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` with this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/pt-gpu/4T-4652/guangyey/stock-pytorch/torch/accelerator/__init__.py", line 123, in current_stream return torch._C._accelerator_getStream(device_index) RuntimeError: The device index is out of range. It must be in [0, 2), but got 2. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143550 Approved by: https://github.com/EikanWang, https://github.com/dvrogozh, https://github.com/albanD
108 lines
3.2 KiB
C++
108 lines
3.2 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
|
|
namespace c10::impl {
|
|
|
|
/**
|
|
* An implementation of DeviceGuardImplInterface which delegates
|
|
* to virtual dispatch on the DeviceGuardImpl registry.
|
|
*/
|
|
class VirtualGuardImpl final : public DeviceGuardImplInterface {
|
|
public:
|
|
VirtualGuardImpl(DeviceType device_type)
|
|
: impl_(getDeviceGuardImpl(device_type)) {}
|
|
// This constructor exists purely for testing
|
|
VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}
|
|
|
|
// Copying and moving is OK!
|
|
VirtualGuardImpl(const VirtualGuardImpl&) = default;
|
|
VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default;
|
|
VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default;
|
|
VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default;
|
|
~VirtualGuardImpl() override = default;
|
|
|
|
DeviceType type() const override {
|
|
return impl_->type();
|
|
}
|
|
Device exchangeDevice(Device d) const override {
|
|
return impl_->exchangeDevice(d);
|
|
}
|
|
Device getDevice() const override {
|
|
return impl_->getDevice();
|
|
}
|
|
void setDevice(Device d) const override {
|
|
impl_->setDevice(d);
|
|
}
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
impl_->uncheckedSetDevice(d);
|
|
}
|
|
Stream getStream(Device d) const override {
|
|
return impl_->getStream(d);
|
|
}
|
|
Stream getNewStream(Device d, int priority = 0) const override {
|
|
return impl_->getNewStream(d, priority);
|
|
}
|
|
Stream getDefaultStream(Device d) const override {
|
|
return impl_->getDefaultStream(d);
|
|
}
|
|
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
|
|
const override {
|
|
return impl_->getStreamFromGlobalPool(d, isHighPriority);
|
|
}
|
|
Stream exchangeStream(Stream s) const override {
|
|
return impl_->exchangeStream(s);
|
|
}
|
|
DeviceIndex deviceCount() const noexcept override {
|
|
return impl_->deviceCount();
|
|
}
|
|
|
|
// Event functions
|
|
void record(
|
|
void** event,
|
|
const Stream& stream,
|
|
const DeviceIndex device_index,
|
|
const EventFlag flag) const override {
|
|
impl_->record(event, stream, device_index, flag);
|
|
}
|
|
void block(void* event, const Stream& stream) const override {
|
|
impl_->block(event, stream);
|
|
}
|
|
bool queryEvent(void* event) const override {
|
|
return impl_->queryEvent(event);
|
|
}
|
|
void destroyEvent(void* event, const DeviceIndex device_index)
|
|
const noexcept override {
|
|
impl_->destroyEvent(event, device_index);
|
|
}
|
|
|
|
bool queryStream(const Stream& stream) const override {
|
|
return impl_->queryStream(stream);
|
|
}
|
|
void synchronizeStream(const Stream& stream) const override {
|
|
impl_->synchronizeStream(stream);
|
|
}
|
|
|
|
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
|
|
const override {
|
|
impl_->recordDataPtrOnStream(data_ptr, stream);
|
|
}
|
|
|
|
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
|
|
const override {
|
|
return impl_->elapsedTime(event1, event2, device_index);
|
|
}
|
|
|
|
void synchronizeEvent(void* event) const override {
|
|
return impl_->synchronizeEvent(event);
|
|
}
|
|
|
|
void synchronizeDevice(const DeviceIndex device_index) const override {
|
|
return impl_->synchronizeDevice(device_index);
|
|
}
|
|
|
|
private:
|
|
const DeviceGuardImplInterface* impl_ = nullptr;
|
|
};
|
|
|
|
} // namespace c10::impl
|