mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[utils] add try_import method for importing optional modules (#145528)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145528 Approved by: https://github.com/albanD
This commit is contained in:
parent
f3304571fc
commit
f2ad2cdf1c
2 changed files with 40 additions and 0 deletions
|
|
@ -19,6 +19,7 @@ import torch.cuda
|
|||
import torch.nn as nn
|
||||
import torch.utils.cpp_extension
|
||||
import torch.utils.data
|
||||
from torch._utils import try_import
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
|
|
@ -1163,5 +1164,23 @@ def f(x):
|
|||
self.assertIn("test_captured_traceback_format_all", "".join(rs[0]))
|
||||
|
||||
|
||||
class TestTryImport(TestCase):
|
||||
def test_import_imported(self):
|
||||
self.assertIn("os", sys.modules)
|
||||
os_module = try_import("os")
|
||||
self.assertIs(os_module, os)
|
||||
|
||||
def test_import_existing(self):
|
||||
self.assertNotIn("imaplib", sys.modules)
|
||||
imaplib_module = try_import("imaplib")
|
||||
self.assertIsNotNone(imaplib_module)
|
||||
self.assertFalse(hasattr(imaplib_module, "not_attribute"))
|
||||
self.assertTrue(hasattr(imaplib_module, "IMAP4"))
|
||||
|
||||
def test_import_missing(self):
|
||||
missing_module = try_import("missing_module")
|
||||
self.assertIsNone(missing_module)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import copyreg
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Generic, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
|
||||
|
|
@ -1012,6 +1014,25 @@ class CallbackRegistry(Generic[P]):
|
|||
)
|
||||
|
||||
|
||||
def try_import(module_name: str) -> Optional[ModuleType]:
|
||||
# Implementation based on
|
||||
# https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported
|
||||
if (module := sys.modules.get(module_name, None)) is not None:
|
||||
return module
|
||||
|
||||
if (spec := importlib.util.find_spec(module_name)) is not None:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
# https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader
|
||||
# "The finder should always set this attribute"
|
||||
assert spec.loader is not None, "The loader attribute should always be set"
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
|
||||
# for use in the weights_only Unpickler.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue