Reviewer notes (free text placed ahead of the first "diff --git" header):
- Line breaks and indentation were rebuilt from the hunk line counts; verify
  the context lines against the tree before applying.
- lazyInitPrivateUse1() now tests for registered hooks before call_once, and
  comments were added, so hunk sizes/offsets were recomputed. The "index"
  lines were left untouched and no longer describe the post-image blobs.

diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index b2409c84dde..5baad73669a 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -129,6 +129,18 @@ class TORCH_API Context {
   void lazyInitHIP() {
     c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
   }
+  // Runs the registered PrivateUse1 backend's initPrivateUse1() hook at most
+  // once. The registration check sits outside call_once on purpose: a call
+  // made while no hooks are registered must not consume thp_init, otherwise
+  // the hook would never run after the hooks are registered later on.
+  void lazyInitPrivateUse1() {
+    if (!isPrivateUse1HooksRegistered()) {
+      return;
+    }
+    c10::call_once(thp_init, [&] {
+      at::GetPrivateUse1HooksInterface()->initPrivateUse1();
+    });
+  }
   static const at::cuda::NVRTC& getNVRTC() {
     return detail::getCUDAHooks().nvrtc();
   }
@@ -301,6 +313,7 @@ class TORCH_API Context {
   static bool checkCuBLASConfigDeterministic();
   c10::once_flag thc_init;
   c10::once_flag thh_init;
+  c10::once_flag thp_init;
   bool enabled_cudnn = true;
   bool deterministic_cudnn = false;
   bool _deterministic_algorithms = false;
diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp
index b25fb0f0f89..8c3861c617c 100644
--- a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp
+++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp
@@ -18,4 +18,8 @@ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface() {
   return privateuse1_hooks;
 }
 
+TORCH_API bool isPrivateUse1HooksRegistered() {
+  return privateuse1_hooks != nullptr;
+}
+
 }
diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h
index 91b197daeff..142e812d283 100644
--- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h
+++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h
@@ -18,6 +18,10 @@ struct TORCH_API PrivateUse1HooksInterface {
         false,
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
   }
+
+  // Backend initialization hook invoked through
+  // Context::lazyInitPrivateUse1(). The default implementation does nothing.
+  virtual void initPrivateUse1() const {}
 };
 
 struct TORCH_API PrivateUse1HooksArgs {};
@@ -26,4 +30,7 @@ TORCH_API void RegisterPrivateUse1HooksInterface(at::PrivateUse1HooksInterface*
 
 TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
 
+// Returns true when the stored PrivateUse1 hooks pointer is non-null.
+TORCH_API bool isPrivateUse1HooksRegistered();
+
 }
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index b9e0c43283a..bccc7ecff05 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -978,6 +978,9 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg
   if (device && device->is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
+  if (device && device->is_privateuseone()) {
+    at::globalContext().lazyInitPrivateUse1();
+  }
   if (!device && !scalarType && !copy && !opt_memory_format.has_value()) {
     Py_INCREF(self);
     return self;
@@ -1059,6 +1062,9 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
   if (device.is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
+  if (device.is_privateuseone()) {
+    at::globalContext().lazyInitPrivateUse1();
+  }
   return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
   END_HANDLE_TH_ERRORS
 }
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index 317eac83a45..e19e0ae5a10 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -332,6 +332,7 @@ c10::intrusive_ptr make_storage_impl(
     } else if (device.type() == at::DeviceType::Meta) {
       allocator = c10::GetAllocator(device.type());
     } else if (device.type() == at::DeviceType::PrivateUse1) {
+      at::globalContext().lazyInitPrivateUse1();
       allocator = c10::GetAllocator(device.type());
     } else {
       // NOLINTEND(bugprone-branch-clone)