mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Move get accelerator to use build time flags when possible (#146098)
This PR does two main things (they are in a single PR to show how the newly added APIs are used). - Add isBuilt and isAvailable APIs to the AcceleratorHook interface. See inline doc for their exact semantic - Use the newly added isBuilt for accelerator check to ensure it does not poison fork Pull Request resolved: https://github.com/pytorch/pytorch/pull/146098 Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/EikanWang Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
parent
23fffb54d5
commit
157d81c201
7 changed files with 78 additions and 28 deletions
|
|
@ -5,38 +5,53 @@
|
|||
namespace at::accelerator {
|
||||
|
||||
std::optional<c10::DeviceType> getAccelerator(bool checked) {
|
||||
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
|
||||
if (at::has##device_name()) { \
|
||||
device_type = k##device_name; \
|
||||
TORCH_CHECK( \
|
||||
!is_accelerator_detected, \
|
||||
"Cannot have ", \
|
||||
device_type.value(), \
|
||||
" with other accelerators."); \
|
||||
is_accelerator_detected = true; \
|
||||
}
|
||||
|
||||
// 1. Check PrivateUse1 backends
|
||||
// We explicitly allow PrivateUse1 and another device at the same time as we
|
||||
// use this for testing. Whenever a PrivateUse1 device is registered, use it
|
||||
// first.
|
||||
// Note that this check is only for hook registration and thus is NOT initializing
|
||||
// the device or poisoning fork.
|
||||
if (is_privateuse1_backend_registered()) {
|
||||
// We explicitly allow PrivateUse1 and another device at the same time as we
|
||||
// use this for testing. Whenever a PrivateUse1 device is registered, use it
|
||||
// first.
|
||||
return kPrivateUse1;
|
||||
}
|
||||
|
||||
// 2. Check runtime backends
|
||||
// This state is temporary, these runtime checks should be moved to compile-time
|
||||
// once they provide the new isBuilt API and we are sure they're never in the
|
||||
// same binary as another accelerator.
|
||||
#define DETECT_RUNTIME_ACCELERATOR(device_name) \
|
||||
if (at::has##device_name()) { \
|
||||
return k##device_name; \
|
||||
}
|
||||
|
||||
DETECT_RUNTIME_ACCELERATOR(MTIA)
|
||||
DETECT_RUNTIME_ACCELERATOR(HPU)
|
||||
|
||||
#undef DETECT_RUNTIME_ACCELERATOR
|
||||
|
||||
// 2. Check compile-time backends
|
||||
std::optional<c10::DeviceType> device_type = std::nullopt;
|
||||
bool is_accelerator_detected = false;
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(HIP)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(MPS)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR(HPU)
|
||||
|
||||
#define DETECT_AND_ASSIGN_ACCELERATOR_COMP(device_name) \
|
||||
if (at::detail::get##device_name##Hooks().isBuilt()) { \
|
||||
TORCH_CHECK( \
|
||||
!device_type.has_value(), \
|
||||
"Cannot have both " #device_name " and ", \
|
||||
device_type.value(), "."); \
|
||||
device_type = k##device_name; \
|
||||
}
|
||||
|
||||
DETECT_AND_ASSIGN_ACCELERATOR_COMP(CUDA)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR_COMP(XPU)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
|
||||
DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
|
||||
if (checked) {
|
||||
TORCH_CHECK(
|
||||
device_type, "Cannot access accelerator device when none is available.")
|
||||
}
|
||||
return device_type;
|
||||
|
||||
#undef DETECT_AND_ASSIGN_ACCELERATOR
|
||||
#undef DETECT_AND_ASSIGN_ACCELERATOR_COMP
|
||||
}
|
||||
|
||||
bool isAccelerator(c10::DeviceType device_type) {
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
|||
bool hasROCM() const override;
|
||||
const at::cuda::NVRTC& nvrtc() const override;
|
||||
DeviceIndex current_device() const override;
|
||||
bool isBuilt() const override {return true;}
|
||||
bool isAvailable() const override {return hasCUDA();}
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override;
|
||||
Allocator* getCUDADeviceAllocator() const override;
|
||||
Allocator* getPinnedMemoryAllocator() const override;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,23 @@ struct TORCH_API AcceleratorHooksInterface {
|
|||
// squelch -Werror=non-virtual-dtor
|
||||
virtual ~AcceleratorHooksInterface() = default;
|
||||
|
||||
// Whether this backend was enabled at compilation time.
|
||||
// This function should NEVER throw.
|
||||
virtual bool isBuilt() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Whether this backend can be used at runtime, meaning it was built,
|
||||
// its runtime dependencies are available (driver) and at least one
|
||||
// supported device can be used.
|
||||
// This function should NEVER throw. This function should NOT initialize the context
|
||||
// on any device (result of hasPrimaryContext below should not change).
|
||||
// While it is acceptable for this function to poison fork, it is
|
||||
// recommended to avoid doing so whenever possible.
|
||||
virtual bool isAvailable() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Whether the device at device_index is fully initialized or not.
|
||||
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,12 @@ struct MPSHooks : public at::MPSHooksInterface {
|
|||
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
|
||||
const override;
|
||||
|
||||
// Compatibility with Accelerator API
|
||||
bool isBuilt() const override {
|
||||
return true;
|
||||
}
|
||||
bool isAvailable() const override {
|
||||
return hasMPS();
|
||||
}
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
||||
// When MPS is available, it is always in use for the one device.
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -85,9 +85,14 @@ bool XPUHooks::isPinnedPtr(const void* data) const {
|
|||
sycl::get_pointer_type(data, c10::xpu::get_device_context());
|
||||
}
|
||||
|
||||
bool XPUHooks::isAvailable() const {
|
||||
return at::xpu::is_available();
|
||||
}
|
||||
|
||||
bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
|
||||
// The default context is utilized for each device. So it always returns true.
|
||||
return true;
|
||||
// The default context is utilized for each device.
|
||||
// So it always returns true if a device is available.
|
||||
return isAvailable();
|
||||
}
|
||||
|
||||
DeviceIndex XPUHooks::deviceCount() const {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,11 @@ struct XPUHooks : public at::XPUHooksInterface {
|
|||
DeviceIndex current_device() const override;
|
||||
void deviceSynchronize(DeviceIndex device_index) const override;
|
||||
Allocator* getPinnedMemoryAllocator() const override;
|
||||
|
||||
bool isBuilt() const override {
|
||||
return true;
|
||||
}
|
||||
bool isAvailable() const override;
|
||||
bool isPinnedPtr(const void* data) const override;
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override;
|
||||
DeviceIndex deviceCount() const override;
|
||||
|
|
|
|||
|
|
@ -3534,8 +3534,9 @@ def fork_and_check_is_pinned():
|
|||
def worker(conn):
|
||||
try:
|
||||
x = torch.randn(10)
|
||||
x.is_pinned(device="cuda")
|
||||
x = torch.ones(10, device="cuda")[0].item()
|
||||
x.is_pinned()
|
||||
dev = torch.accelerator.current_accelerator()
|
||||
x = torch.ones(10, device=dev)[0].item()
|
||||
conn.send(x)
|
||||
except Exception as e:
|
||||
conn.send(str(e))
|
||||
|
|
@ -3555,7 +3556,7 @@ def fork_and_check_is_pinned():
|
|||
|
||||
x = torch.randn(10)
|
||||
# check that is_pinned won't poison future fork
|
||||
x.is_pinned(device="cuda")
|
||||
x.is_pinned()
|
||||
ret = fork_and_check_is_pinned()
|
||||
print(ret)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue