From 6a0138fcc1a982efe9b0ec54016728d7f4d63b94 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Fri, 31 Jan 2025 19:27:38 +0000 Subject: [PATCH] Torch device backend autoload fix (#145611) Autoloading external device backends before `torch._as_tensor_fullprec` is defined causes an import failure when a backend, while being loaded, imports a module that uses `torch._as_tensor_fullprec`. Fix this by moving the `_import_device_backends()` call to the end of `torch/__init__.py`, so that every function a backend may access is already defined when the backend is loaded. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145611 Approved by: https://github.com/albanD --- torch/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 67f8dd1ef55..eea6e5c0891 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2779,10 +2779,6 @@ def _is_device_backend_autoload_enabled() -> builtins.bool: return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1" -if _is_device_backend_autoload_enabled(): - _import_device_backends() - - def _as_tensor_fullprec(t): """ Like torch.as_tensor, but when given Python data types it will keep @@ -2795,3 +2791,10 @@ def _as_tensor_fullprec(t): return torch.as_tensor(t, dtype=torch.int64) else: return torch.as_tensor(t) + + +# `_import_device_backends` should be kept at the end to ensure +# all the other functions in this module that may be accessed by +# an autoloaded backend are defined +if _is_device_backend_autoload_enabled(): + _import_device_backends()