[utils] add try_import method for importing optional modules (#145528)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145528
Approved by: https://github.com/albanD
This commit is contained in:
Marc Horowitz 2025-01-23 23:33:40 -08:00 committed by PyTorch MergeBot
parent f3304571fc
commit f2ad2cdf1c
2 changed files with 40 additions and 0 deletions

View file

@ -19,6 +19,7 @@ import torch.cuda
import torch.nn as nn
import torch.utils.cpp_extension
import torch.utils.data
from torch._utils import try_import
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
from torch.testing._internal.common_cuda import TEST_MULTIGPU
@ -1163,5 +1164,23 @@ def f(x):
self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
class TestTryImport(TestCase):
def test_import_imported(self):
self.assertIn("os", sys.modules)
os_module = try_import("os")
self.assertIs(os_module, os)
def test_import_existing(self):
self.assertNotIn("imaplib", sys.modules)
imaplib_module = try_import("imaplib")
self.assertIsNotNone(imaplib_module)
self.assertFalse(hasattr(imaplib_module, "not_attribute"))
self.assertTrue(hasattr(imaplib_module, "IMAP4"))
def test_import_missing(self):
missing_module = try_import("missing_module")
self.assertIsNone(missing_module)
if __name__ == "__main__":
run_tests()

View file

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs
import copyreg
import functools
import importlib
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from types import ModuleType
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING
from typing_extensions import deprecated, ParamSpec
@ -1012,6 +1014,25 @@ class CallbackRegistry(Generic[P]):
)
def try_import(module_name: str) -> Optional[ModuleType]:
# Implementation based on
# https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported
if (module := sys.modules.get(module_name, None)) is not None:
return module
if (spec := importlib.util.find_spec(module_name)) is not None:
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
# https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader
# "The finder should always set this attribute"
assert spec.loader is not None, "The loader attribute should always be set"
spec.loader.exec_module(module)
return module
return None
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.