Move get accelerator to use build time flags when possible (#146098)

This PR does two main things (they are in a single PR to show how the newly added APIs are used).

- Add isBuilt and isAvailable APIs to the AcceleratorHook interface. See inline doc for their exact semantic
- Use the newly added isBuilt for accelerator check to ensure it does not poison fork

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146098
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/EikanWang

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
albanD 2025-02-04 18:23:19 +00:00 committed by PyTorch MergeBot
parent 23fffb54d5
commit 157d81c201
7 changed files with 78 additions and 28 deletions

View file

@ -5,38 +5,53 @@
namespace at::accelerator {
std::optional<c10::DeviceType> getAccelerator(bool checked) {
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
device_type = k##device_name; \
TORCH_CHECK( \
!is_accelerator_detected, \
"Cannot have ", \
device_type.value(), \
" with other accelerators."); \
is_accelerator_detected = true; \
}
// 1. Check PrivateUse1 backends
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
// Note that this check is only for hook registration and thus is NOT initializing
// the device or poisoning fork.
if (is_privateuse1_backend_registered()) {
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
return kPrivateUse1;
}
// 2. Check runtime backends
// This state is temporary, these runtime checks should be moved to compile-time
// once they provide the new isBuilt API and we are sure they're never in the
// same binary as another accelerator.
#define DETECT_RUNTIME_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
return k##device_name; \
}
DETECT_RUNTIME_ACCELERATOR(MTIA)
DETECT_RUNTIME_ACCELERATOR(HPU)
#undef DETECT_RUNTIME_ACCELERATOR
// 2. Check compile-time backends
std::optional<c10::DeviceType> device_type = std::nullopt;
bool is_accelerator_detected = false;
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
DETECT_AND_ASSIGN_ACCELERATOR(HIP)
DETECT_AND_ASSIGN_ACCELERATOR(MPS)
DETECT_AND_ASSIGN_ACCELERATOR(HPU)
#define DETECT_AND_ASSIGN_ACCELERATOR_COMP(device_name) \
if (at::detail::get##device_name##Hooks().isBuilt()) { \
TORCH_CHECK( \
!device_type.has_value(), \
"Cannot have both " #device_name " and ", \
device_type.value(), "."); \
device_type = k##device_name; \
}
DETECT_AND_ASSIGN_ACCELERATOR_COMP(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(XPU)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
if (checked) {
TORCH_CHECK(
device_type, "Cannot access accelerator device when none is available.")
}
return device_type;
#undef DETECT_AND_ASSIGN_ACCELERATOR
#undef DETECT_AND_ASSIGN_ACCELERATOR_COMP
}
bool isAccelerator(c10::DeviceType device_type) {

View file

@ -33,6 +33,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool isBuilt() const override {return true;}
bool isAvailable() const override {return hasCUDA();}
bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;

View file

@ -20,6 +20,23 @@ struct TORCH_API AcceleratorHooksInterface {
// squelch -Werror=non-virtual-dtor
virtual ~AcceleratorHooksInterface() = default;
// Whether this backend was enabled at compilation time.
// This function should NEVER throw.
virtual bool isBuilt() const {
return false;
}
// Whether this backend can be used at runtime, meaning it was built,
// its runtime dependencies are available (driver) and at least one
// supported device can be used.
// This function should NEVER throw. This function should NOT initialize the context
// on any device (result of hasPrimaryContext below should not change).
// While it is acceptable for this function to poison fork, it is
// recommended to avoid doing so whenever possible.
virtual bool isAvailable() const {
return false;
}
// Whether the device at device_index is fully initialized or not.
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

View file

@ -54,7 +54,12 @@ struct MPSHooks : public at::MPSHooksInterface {
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
const override;
// Compatibility with Accelerator API
bool isBuilt() const override {
return true;
}
bool isAvailable() const override {
return hasMPS();
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
// When MPS is available, it is always in use for the one device.
return true;

View file

@ -85,9 +85,14 @@ bool XPUHooks::isPinnedPtr(const void* data) const {
sycl::get_pointer_type(data, c10::xpu::get_device_context());
}
bool XPUHooks::isAvailable() const {
return at::xpu::is_available();
}
bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
// The default context is utilized for each device. So it always returns true.
return true;
// The default context is utilized for each device.
// So it always returns true if a device is available.
return isAvailable();
}
DeviceIndex XPUHooks::deviceCount() const {

View file

@ -19,6 +19,11 @@ struct XPUHooks : public at::XPUHooksInterface {
DeviceIndex current_device() const override;
void deviceSynchronize(DeviceIndex device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool isBuilt() const override {
return true;
}
bool isAvailable() const override;
bool isPinnedPtr(const void* data) const override;
bool hasPrimaryContext(DeviceIndex device_index) const override;
DeviceIndex deviceCount() const override;

View file

@ -3534,8 +3534,9 @@ def fork_and_check_is_pinned():
def worker(conn):
try:
x = torch.randn(10)
x.is_pinned(device="cuda")
x = torch.ones(10, device="cuda")[0].item()
x.is_pinned()
dev = torch.accelerator.current_accelerator()
x = torch.ones(10, device=dev)[0].item()
conn.send(x)
except Exception as e:
conn.send(str(e))
@ -3555,7 +3556,7 @@ def fork_and_check_is_pinned():
x = torch.randn(10)
# check that is_pinned won't poison future fork
x.is_pinned(device="cuda")
x.is_pinned()
ret = fork_and_check_is_pinned()
print(ret)