mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[dynamo] Refactor test cross importing (#113242)"
This reverts commit 4309d38f5d.
Reverted https://github.com/pytorch/pytorch/pull/113242 on behalf of https://github.com/huydhn due to Sorry for reverting your stack, but it is failing to list test internally with buck2 ([comment](https://github.com/pytorch/pytorch/pull/113242#issuecomment-1811674395))
This commit is contained in:
parent
6bffde99b0
commit
92e3f45f0e
6 changed files with 49 additions and 49 deletions
|
|
@ -3,19 +3,32 @@ import unittest
|
|||
import warnings
|
||||
|
||||
from torch._dynamo import config
|
||||
from torch._dynamo.testing import load_test_module, make_test_cls_with_patches
|
||||
from torch._dynamo.testing import make_test_cls_with_patches
|
||||
from torch.fx.experimental import _config as fx_config
|
||||
from torch.testing._internal.common_utils import TEST_Z3
|
||||
|
||||
test_aot_autograd = load_test_module(__file__, "dynamo.test_aot_autograd")
|
||||
test_ctx_manager = load_test_module(__file__, "dynamo.test_ctx_manager")
|
||||
test_export = load_test_module(__file__, "dynamo.test_export")
|
||||
test_functions = load_test_module(__file__, "dynamo.test_functions")
|
||||
test_higher_order_ops = load_test_module(__file__, "dynamo.test_higher_order_ops")
|
||||
test_misc = load_test_module(__file__, "dynamo.test_misc")
|
||||
test_modules = load_test_module(__file__, "dynamo.test_modules")
|
||||
test_repros = load_test_module(__file__, "dynamo.test_repros")
|
||||
test_subgraphs = load_test_module(__file__, "dynamo.test_subgraphs")
|
||||
try:
|
||||
from . import (
|
||||
test_aot_autograd,
|
||||
test_ctx_manager,
|
||||
test_export,
|
||||
test_functions,
|
||||
test_higher_order_ops,
|
||||
test_misc,
|
||||
test_modules,
|
||||
test_repros,
|
||||
test_subgraphs,
|
||||
)
|
||||
except ImportError:
|
||||
import test_aot_autograd
|
||||
import test_ctx_manager
|
||||
import test_export
|
||||
import test_functions
|
||||
import test_higher_order_ops
|
||||
import test_misc
|
||||
import test_modules
|
||||
import test_repros
|
||||
import test_subgraphs
|
||||
|
||||
|
||||
test_classes = {}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ import torch._dynamo.test_case
|
|||
import torch._dynamo.testing
|
||||
from torch._dynamo.testing import same
|
||||
|
||||
utils = torch._dynamo.testing.load_test_module(__file__, "dynamo.utils")
|
||||
try:
|
||||
from . import utils
|
||||
except ImportError:
|
||||
import utils
|
||||
|
||||
|
||||
class Pair: # noqa: B903
|
||||
|
|
|
|||
|
|
@ -21,10 +21,10 @@ from torch._dynamo.testing import expectedFailureDynamic, same
|
|||
from torch.nn.modules.lazy import LazyModuleMixin
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
|
||||
test_functions = torch._dynamo.testing.load_test_module(
|
||||
__file__, "dynamo.test_functions"
|
||||
)
|
||||
try:
|
||||
from . import test_functions
|
||||
except ImportError:
|
||||
import test_functions
|
||||
|
||||
|
||||
class BasicModule(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -18,14 +18,15 @@ from torch._dynamo.skipfiles import (
|
|||
LEGACY_MOD_INLINELIST,
|
||||
MOD_INLINELIST,
|
||||
)
|
||||
from torch._dynamo.testing import load_test_module
|
||||
from torch._dynamo.trace_rules import get_torch_obj_rule_map, load_object
|
||||
from torch._dynamo.utils import is_safe_constant, istype
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
create_dummy_module_and_function = load_test_module(
|
||||
__file__, "dynamo.utils"
|
||||
).create_dummy_module_and_function
|
||||
try:
|
||||
from .utils import create_dummy_module_and_function
|
||||
except ImportError:
|
||||
from utils import create_dummy_module_and_function
|
||||
|
||||
|
||||
ignored_torch_name_rule_set = {
|
||||
"torch.ExcludeDispatchKeyGuard",
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
import functools
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import _inductor as inductor
|
||||
from torch._dynamo import compiled_autograd
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.testing import load_test_module
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
|
||||
|
|
@ -374,7 +377,15 @@ class TestCompiledAutograd(TestCase):
|
|||
eager_check()
|
||||
|
||||
|
||||
test_autograd = load_test_module(__file__, "test_autograd")
|
||||
def load_test_module(name):
|
||||
testdir = Path(__file__).absolute().parent.parent
|
||||
with mock.patch("sys.path", [*sys.path, str(testdir)]):
|
||||
return SourceFileLoader(
|
||||
name, str(testdir / f"{name.replace('.', '/')}.py")
|
||||
).load_module()
|
||||
|
||||
|
||||
test_autograd = load_test_module("test_autograd")
|
||||
|
||||
|
||||
class EagerAutogradTests(TestCase):
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ import re
|
|||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
|
|
@ -374,28 +371,3 @@ def reset_rng_state(use_xla=False):
|
|||
import torch_xla.core.xla_model as xm
|
||||
|
||||
xm.set_rng_state(1337, str(xm.xla_device()))
|
||||
|
||||
|
||||
def load_test_module(from_test_file, wanted_module):
|
||||
"""
|
||||
Import a module from pytorch/test/* in a robust way.
|
||||
|
||||
Args:
|
||||
from_test_file: filename of the test calling this, used to file root path
|
||||
wanted_module: module name to import
|
||||
|
||||
Returns:
|
||||
a Python module
|
||||
"""
|
||||
if wanted_module in sys.modules:
|
||||
return sys.modules[wanted_module]
|
||||
|
||||
testdir = Path(from_test_file).absolute().parent
|
||||
# go up at most 3 directories to find the test root
|
||||
for _ in range(3):
|
||||
target = testdir / f"{wanted_module.replace('.', '/')}.py"
|
||||
if target.exists():
|
||||
with mock.patch("sys.path", [str(testdir), *sys.path]):
|
||||
return SourceFileLoader(wanted_module, str(target)).load_module()
|
||||
testdir = testdir.parent
|
||||
raise ImportError(f"failed to find {wanted_module} from {from_test_file}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue