From f2ad2cdf1cf5d1436c6ed507428756c93d4d4602 Mon Sep 17 00:00:00 2001 From: Marc Horowitz Date: Thu, 23 Jan 2025 23:33:40 -0800 Subject: [PATCH] [utils] add try_import method for importing optional modules (#145528) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145528 Approved by: https://github.com/albanD --- test/test_utils.py | 19 +++++++++++++++++++ torch/_utils.py | 21 +++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 0fdd59edce5..5f69ecdfe35 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,6 +19,7 @@ import torch.cuda import torch.nn as nn import torch.utils.cpp_extension import torch.utils.data +from torch._utils import try_import from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -1163,5 +1164,23 @@ def f(x): self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) +class TestTryImport(TestCase): + def test_import_imported(self): + self.assertIn("os", sys.modules) + os_module = try_import("os") + self.assertIs(os_module, os) + + def test_import_existing(self): + self.assertNotIn("imaplib", sys.modules) + imaplib_module = try_import("imaplib") + self.assertIsNotNone(imaplib_module) + self.assertFalse(hasattr(imaplib_module, "not_attribute")) + self.assertTrue(hasattr(imaplib_module, "IMAP4")) + + def test_import_missing(self): + missing_module = try_import("missing_module") + self.assertIsNone(missing_module) + + if __name__ == "__main__": run_tests() diff --git a/torch/_utils.py b/torch/_utils.py index 5b1b9b03cb8..7c645435f87 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import copyreg import functools +import importlib import logging import sys import traceback import warnings from collections import defaultdict +from types import ModuleType from typing import Any, Callable, Generic, Optional, TYPE_CHECKING from typing_extensions import deprecated, ParamSpec @@ -1012,6 +1014,25 @@ class CallbackRegistry(Generic[P]): ) +def try_import(module_name: str) -> Optional[ModuleType]: + # Implementation based on + # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported + if (module := sys.modules.get(module_name, None)) is not None: + return module + + if (spec := importlib.util.find_spec(module_name)) is not None: + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + # https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader + # "The finder should always set this attribute" + assert spec.loader is not None, "The loader attribute should always be set" + spec.loader.exec_module(module) + return module + + return None + + # IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py # for use in the weights_only Unpickler.