Enables private_use_one lazy_init by PrivateUse1HooksInterface (#115067)

Fixes https://github.com/pytorch/pytorch/issues/112369

In my last PR (https://github.com/pytorch/pytorch/pull/113343), I wanted to implement lazy_init for other devices through `REGISTER_LAZY_INIT`. But that might be too big of a change.

Recently, my team found that `torch.load` without `lazy_init` also results in the same error.
bbd5b935e4/torch/csrc/Storage.cpp (L319-L321)
bbd5b935e4/torch/csrc/Storage.cpp (L334-L335)

So, I want to use `PrivateUse1HooksInterface` to implement lazy_init for `PrivateUse1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115067
Approved by: https://github.com/ezyang
This commit is contained in:
feifan 2024-01-09 20:12:05 +00:00 committed by PyTorch MergeBot
parent ab1ac43752
commit 29ae4f22bf
5 changed files with 23 additions and 0 deletions

View file

@ -129,6 +129,13 @@ class TORCH_API Context {
// Run HIP initialization exactly once for the lifetime of the process.
void lazyInitHIP() {
  // The lambda references only free functions, so no captures are needed.
  c10::call_once(thh_init, []() { detail::getHIPHooks().initHIP(); });
}
// Run lazy initialization for a custom PrivateUse1 backend at most once.
// If no backend has registered its hooks interface, this does nothing.
void lazyInitPrivateUse1() {
  c10::call_once(thp_init, []() {
    // Guard clause: bail out unless a backend installed its hooks.
    if (!isPrivateUse1HooksRegistered()) {
      return;
    }
    at::GetPrivateUse1HooksInterface()->initPrivateUse1();
  });
}
// Access the NVRTC interface through the CUDA hooks singleton.
static const at::cuda::NVRTC& getNVRTC() {
  const auto& cuda_hooks = detail::getCUDAHooks();
  return cuda_hooks.nvrtc();
}
@ -301,6 +308,7 @@ class TORCH_API Context {
static bool checkCuBLASConfigDeterministic();
// One-shot guards consumed by the lazyInit* methods above.
c10::once_flag thc_init; // presumably CUDA — not referenced in this view; confirm
c10::once_flag thh_init; // HIP (consumed by lazyInitHIP)
c10::once_flag thp_init; // PrivateUse1 (consumed by lazyInitPrivateUse1)
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool _deterministic_algorithms = false;

View file

@ -18,4 +18,8 @@ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface() {
// NOTE(review): returns the raw registered pointer, which is nullptr when no
// PrivateUse1 backend has registered (see isPrivateUse1HooksRegistered()).
return privateuse1_hooks;
}
// True once a PrivateUse1 backend has installed its hooks interface.
TORCH_API bool isPrivateUse1HooksRegistered() {
  // A non-null hooks pointer means registration has happened.
  return static_cast<bool>(privateuse1_hooks);
}
}

View file

@ -18,6 +18,8 @@ struct TORCH_API PrivateUse1HooksInterface {
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
}
// Lazy one-time initialization hook for a PrivateUse1 backend.
// Default is a no-op; out-of-tree backends override this to set up their runtime.
virtual void initPrivateUse1() const {}
};
// Empty argument bundle for the PrivateUse1 hooks — presumably passed through
// the hooks registration/construction machinery; confirm against the registry.
struct TORCH_API PrivateUse1HooksArgs {};
@ -26,4 +28,6 @@ TORCH_API void RegisterPrivateUse1HooksInterface(at::PrivateUse1HooksInterface*
TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
TORCH_API bool isPrivateUse1HooksRegistered();
}

View file

@ -978,6 +978,9 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg
// Moving to a CUDA device requires the CUDA runtime to be initialized first.
if (device && device->is_cuda()) {
torch::utils::cuda_lazy_init();
}
// Likewise, run the lazy-init hook for out-of-tree PrivateUse1 backends.
if (device && device->is_privateuseone()) {
at::globalContext().lazyInitPrivateUse1();
}
// Fast path: no device/dtype/copy/memory-format change requested — return self.
if (!device && !scalarType && !copy && !opt_memory_format.has_value()) {
Py_INCREF(self);
return self;
@ -1059,6 +1062,9 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
// Converting to a CUDA type: make sure CUDA is initialized first.
if (device.is_cuda()) {
torch::utils::cuda_lazy_init();
}
// Likewise initialize a PrivateUse1 backend before dispatching to it.
if (device.is_privateuseone()) {
at::globalContext().lazyInitPrivateUse1();
}
return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
END_HANDLE_TH_ERRORS
}

View file

@ -332,6 +332,7 @@ c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
} else if (device.type() == at::DeviceType::Meta) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::PrivateUse1) {
// Initialize the custom backend before fetching its allocator, so that
// e.g. torch.load onto a privateuse1 device finds the allocator registered.
at::globalContext().lazyInitPrivateUse1();
allocator = c10::GetAllocator(device.type());
} else {
// NOLINTEND(bugprone-branch-clone)