diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index ca7ead3bbfc..f59f83b08aa 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -43,9 +43,19 @@ class TORCH_API Context { if (device_type == at::kCPU) { return at::detail::getDefaultCPUGenerator(); + } else if (device_type == at::kCUDA) { + return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index()); + } else if (device_type == at::kMPS) { + return at::detail::getMPSHooks().getDefaultMPSGenerator(); + } else if (device_type == at::kXPU) { + return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index()); + } else if (device_type == at::kIPU) { + return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index()); + } else if (device_type == at::kPrivateUse1) { + return at::detail::getPrivateUse1Hooks().getDefaultGenerator( + device.index()); } else { - return getAcceleratorHooksInterface(device_type) - .getDefaultGenerator(device.index()); + AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); } } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index a6439f7c3e5..d5b4c3ae62b 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -102,7 +102,7 @@ void CUDAHooks::init() const { #endif } -const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const { +const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const { return at::cuda::detail::getDefaultCUDAGenerator(device_index); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index ea190c9e1a5..2dbc336778c 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -21,8 +21,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { void init() const override; Device getDeviceFromPtr(void* data) const override; bool isPinnedPtr(const void* data) const override; - const Generator& getDefaultGenerator( - DeviceIndex device_index = -1) const override; + const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; bool hasCUDA() const override; bool hasMAGMA() const override; bool hasCuDNN() const override; diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index d5da8592874..4eab4d24f71 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -1,13 +1,9 @@ #pragma once -#include - -#include #include #include - +#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") - namespace at { // AcceleratorHooksInterface is a shared interface provided by all @@ -62,18 +58,7 @@ struct TORCH_API AcceleratorHooksInterface { virtual Device getDeviceFromPtr(void* data) const { TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()"); } - - virtual const Generator& getDefaultGenerator( - C10_UNUSED DeviceIndex device_index = -1) const { - TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()"); - } - - virtual Generator getNewGenerator( - C10_UNUSED DeviceIndex device_index = -1) const { - TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()"); - } }; } // namespace at - C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index dc7bf51ad72..144643e5297 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -6,13 +6,16 @@ #include -// NB: Class must live in `at` due to limitations of Registry.h. +// Forward-declares at::Generator and at::cuda::NVRTC namespace at { - -// Forward-declares at::cuda::NVRTC +struct Generator; namespace cuda { struct NVRTC; } // namespace cuda +} // namespace at + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { #ifdef _MSC_VER constexpr const char* CUDA_HELP = @@ -66,8 +69,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP); } - const Generator& getDefaultGenerator( - [[maybe_unused]] DeviceIndex device_index = -1) const override { + virtual const Generator& getDefaultCUDAGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { TORCH_CHECK( false, "Cannot get default CUDA generator without ATen_cuda library. ", diff --git a/aten/src/ATen/detail/HIPHooksInterface.h b/aten/src/ATen/detail/HIPHooksInterface.h index 7cb8dcbd0e8..f852db8d600 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.h +++ b/aten/src/ATen/detail/HIPHooksInterface.h @@ -1,13 +1,19 @@ #pragma once #include +#include #include + #include #include #include +namespace at { +class Context; +} + // NB: Class must live in `at` due to limitations of Registry.h. namespace at { @@ -24,9 +30,8 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); } - const Generator& getDefaultGenerator( - C10_UNUSED DeviceIndex device_index = -1) const override { - TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); + virtual std::unique_ptr initHIPGenerator(Context*) const { + AT_ERROR("Cannot initialize HIP generator without ATen_hip library."); } virtual bool hasHIP() const { @@ -45,6 +50,10 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Pinned memory requires HIP."); } + virtual void registerHIPTypes(Context*) const { + AT_ERROR("Cannot registerHIPTypes() without ATen_hip library."); + } + virtual int getNumGPUs() const { return 0; } diff --git a/aten/src/ATen/detail/IPUHooksInterface.h b/aten/src/ATen/detail/IPUHooksInterface.h index 6d92c9dbef2..20dbb703d57 100644 --- a/aten/src/ATen/detail/IPUHooksInterface.h +++ b/aten/src/ATen/detail/IPUHooksInterface.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -8,7 +9,7 @@ namespace at { -struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface { +struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface { ~IPUHooksInterface() override = default; void init() const override { @@ -20,14 +21,16 @@ struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface { return false; } - const Generator& getDefaultGenerator( - C10_UNUSED DeviceIndex device_index = -1) const override { - TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + virtual const Generator& getDefaultIPUGenerator( + DeviceIndex device_index [[maybe_unused]] = -1) const { + AT_ERROR( + "Cannot get the default IPU generator: the IPU backend is not " + "available."); } - Generator getNewGenerator( - DeviceIndex device_index [[maybe_unused]] = -1) const override { - TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + virtual Generator newIPUGenerator(DeviceIndex device_index [[maybe_unused]] = -1) const { + AT_ERROR( + "Cannot create a new IPU generator: the IPU backend is not available."); } }; diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index f2235c3faf1..e3f8d3132bb 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -2,9 +2,9 @@ #pragma once -#include - #include +#include +#include #include #include @@ -31,8 +31,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const { FAIL_MPSHOOKS_FUNC(__func__); } - const Generator& getDefaultGenerator( - C10_UNUSED DeviceIndex device_index = -1) const override { + virtual const Generator& getDefaultMPSGenerator() const { FAIL_MPSHOOKS_FUNC(__func__); } virtual Allocator* getMPSDeviceAllocator() const { diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index 17927046d2e..3820c960dfe 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -1,20 +1,18 @@ #pragma once +#include #include #include #include #include #include - C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") - namespace at { struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { ~PrivateUse1HooksInterface() override = default; - - const at::Generator& getDefaultGenerator( - c10::DeviceIndex device_index) const override { + virtual const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`."); @@ -26,17 +24,17 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`."); } - bool isPinnedPtr(const void* data) const override { + virtual bool isPinnedPtr(const void* data) const override { return false; } - Allocator* getPinnedMemoryAllocator() const override { + virtual Allocator* getPinnedMemoryAllocator() const override { TORCH_CHECK( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`."); } - bool hasPrimaryContext(DeviceIndex device_index) const override { + virtual bool hasPrimaryContext(DeviceIndex device_index) const override { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index 3786b5444d6..8cb5497e62c 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -4,6 +4,7 @@ #include #include +#include #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") @@ -31,15 +32,15 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library."); } - const Generator& getDefaultGenerator( - [[maybe_unused]] DeviceIndex device_index = -1) const override { - TORCH_CHECK( - false, "Cannot get default XPU generator without ATen_xpu library."); + virtual Generator getXPUGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library."); } - Generator getNewGenerator( - [[maybe_unused]] DeviceIndex device_index = -1) const override { - TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library."); + virtual const Generator& getDefaultXPUGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK( + false, "Cannot get default XPU generator without ATen_xpu library."); } virtual DeviceIndex getNumGPUs() const { diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index a5c64d72d4d..20662be4369 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -19,8 +19,7 @@ struct MPSHooks : public at::MPSHooksInterface { bool isOnMacOSorNewer(unsigned major, unsigned minor) const override; // MPSGeneratorImpl interface - const Generator& getDefaultGenerator( - DeviceIndex device_index = -1) const override; + const Generator& getDefaultMPSGenerator() const override; // MPSStream interface void deviceSynchronize() const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index bacf1a2bef8..983bb516a31 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -59,7 +59,7 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const { return at::mps::GetMPSAllocator(); } -const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex device_index) const { +const Generator& MPSHooks::getDefaultMPSGenerator() const { return at::mps::detail::getDefaultMPSGenerator(); } diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 427efb391d8..744649e121f 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -2363,7 +2363,8 @@ DropoutState& get_dropout_state( std::unique_lock lock{state_cache_mut}; auto& state = dropout_state_cache.at(device); if (train && dropout_p > 0) { - const auto& gen = at::detail::getCUDAHooks().getDefaultGenerator(device); + const auto& gen = + at::detail::getCUDAHooks().getDefaultCUDAGenerator(device); auto gen_impl = gen.get(); bool reset_rnn_state = gen_impl->reset_rnn_state(); if (!state.buffer.defined() || reset_rnn_state) { diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp index d176d3544ce..05d4482fe97 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.cpp +++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp @@ -34,12 +34,13 @@ int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const { #endif } -const Generator& XPUHooks::getDefaultGenerator(DeviceIndex device_index) const { - return at::xpu::detail::getDefaultXPUGenerator(device_index); +Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const { + return make_generator(device_index); } -Generator XPUHooks::getNewGenerator(DeviceIndex device_index) const { - return make_generator(device_index); +const Generator& XPUHooks::getDefaultXPUGenerator( + DeviceIndex device_index) const { + return at::xpu::detail::getDefaultXPUGenerator(device_index); } Device XPUHooks::getDeviceFromPtr(void* data) const { diff --git a/aten/src/ATen/xpu/detail/XPUHooks.h b/aten/src/ATen/xpu/detail/XPUHooks.h index 4cc25e1fe84..6c1c064bae8 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.h +++ b/aten/src/ATen/xpu/detail/XPUHooks.h @@ -11,9 +11,9 @@ struct XPUHooks : public at::XPUHooksInterface { bool hasXPU() const override; std::string showConfig() const override; int32_t getGlobalIdxFromDevice(const at::Device& device) const override; - const Generator& getDefaultGenerator( + Generator getXPUGenerator(DeviceIndex device_index = -1) const override; + const Generator& getDefaultXPUGenerator( DeviceIndex device_index = -1) const override; - Generator getNewGenerator(DeviceIndex device_index = -1) const override; Device getDeviceFromPtr(void* data) const override; c10::DeviceIndex getNumGPUs() const override; DeviceIndex current_device() const override; diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index b35a871bbc9..1da0a3229db 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -73,9 +73,9 @@ static PyObject* THPGenerator_pynew( } #endif else if (device.type() == at::kXPU) { - self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index()); + self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index()); } else if (device.type() == at::kIPU) { - self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index()); + self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index()); } else if (device.type() == at::kPrivateUse1) { self->cdata = at::GetGeneratorForPrivateuse1(device.index()); } else { diff --git a/torch/csrc/api/src/cuda.cpp b/torch/csrc/api/src/cuda.cpp index 5f708ca42e4..5d7624a9976 100644 --- a/torch/csrc/api/src/cuda.cpp +++ b/torch/csrc/api/src/cuda.cpp @@ -28,7 +28,7 @@ bool cudnn_is_available() { void manual_seed(uint64_t seed) { if (is_available()) { auto index = at::detail::getCUDAHooks().getCurrentDevice(); - auto gen = at::detail::getCUDAHooks().getDefaultGenerator(index); + auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); @@ -41,7 +41,8 @@ void manual_seed(uint64_t seed) { void manual_seed_all(uint64_t seed) { auto num_gpu = device_count(); for (const auto i : c10::irange(num_gpu)) { - auto gen = at::detail::getCUDAHooks().getDefaultGenerator(i); + auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator( + static_cast(i)); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); diff --git a/torch/csrc/api/src/mps.cpp b/torch/csrc/api/src/mps.cpp index b7432c39e3c..7477adb5a82 100644 --- a/torch/csrc/api/src/mps.cpp +++ b/torch/csrc/api/src/mps.cpp @@ -10,7 +10,7 @@ bool is_available() { /// Sets the seed for the MPS's default generator. void manual_seed(uint64_t seed) { if (is_available()) { - auto gen = at::detail::getMPSHooks().getDefaultGenerator(); + auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator(); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); diff --git a/torch/csrc/api/src/xpu.cpp b/torch/csrc/api/src/xpu.cpp index b64feae2ff0..75837b831d9 100644 --- a/torch/csrc/api/src/xpu.cpp +++ b/torch/csrc/api/src/xpu.cpp @@ -14,7 +14,7 @@ bool is_available() { void manual_seed(uint64_t seed) { if (is_available()) { auto index = at::detail::getXPUHooks().getCurrentDevice(); - auto gen = at::detail::getXPUHooks().getDefaultGenerator(index); + auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); @@ -27,7 +27,8 @@ void manual_seed(uint64_t seed) { void manual_seed_all(uint64_t seed) { auto num_gpu = device_count(); for (const auto i : c10::irange(num_gpu)) { - auto gen = at::detail::getXPUHooks().getDefaultGenerator(i); + auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator( + static_cast(i)); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 6b201409e26..468d0bf2e5d 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -44,7 +44,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator( HANDLE_TH_ERRORS track_bad_mps_fork(); return THPGenerator_initDefaultGenerator( - at::detail::getMPSHooks().getDefaultGenerator()); + at::detail::getMPSHooks().getDefaultMPSGenerator()); END_HANDLE_TH_ERRORS }