From 1e57154af3813a7c3b3d39be53027f631df3f76d Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 Jan 2025 10:38:28 -0800 Subject: [PATCH] Require that all HOPs be imported at `import torch` time (#145939) E.g. torch.ops.higher_order.cond does not exist until it is imported, which is bad if it shows up in an FX graph or is used in some code somewhere. This PR also makes some more HOPs get imported at `import torch` time. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/145939 Approved by: https://github.com/ydwu4 ghstack dependencies: #145938 --- test/test_hop_infra.py | 58 +++++++++++++++++++++ torch/_higher_order_ops/__init__.py | 30 +++++++++++ torch/_higher_order_ops/associative_scan.py | 3 +- torch/_higher_order_ops/wrap.py | 9 +++- torch/testing/_internal/hop_db.py | 10 ++++ 5 files changed, 108 insertions(+), 2 deletions(-) diff --git a/test/test_hop_infra.py b/test/test_hop_infra.py index 99a10b673d0..291c86f330a 100644 --- a/test/test_hop_infra.py +++ b/test/test_hop_infra.py @@ -1,4 +1,8 @@ # Owner(s): ["module: higher order operators"] +import importlib +import pkgutil + +import torch from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase from torch.testing._internal.hop_db import ( FIXME_hop_that_doesnt_have_opinfo_test_allowlist, @@ -6,9 +10,21 @@ from torch.testing._internal.hop_db import ( ) +def do_imports(): + for mod in pkgutil.walk_packages( + torch._higher_order_ops.__path__, "torch._higher_order_ops." + ): + modname = mod.name + importlib.import_module(modname) + + +do_imports() + + @skipIfTorchDynamo("not applicable") class TestHOPInfra(TestCase): def test_all_hops_have_opinfo(self): + """All HOPs should have an OpInfo in torch/testing/_internal/hop_db.py""" from torch._ops import _higher_order_ops hops_that_have_op_info = {k.name for k in hop_db} @@ -28,6 +44,48 @@ class TestHOPInfra(TestCase): f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py", ) + def test_all_hops_are_imported(self): + """All HOPs should be listed in torch._higher_order_ops.__all__ + + Some constraints (see test_testing.py::TestImports) + - Sympy must be lazily imported + - Dynamo must be lazily imported + """ + imported_hops = torch._higher_order_ops.__all__ + registered_hops = torch._ops._higher_order_ops.keys() + + # Please don't add anything here. + # We want to ensure that all HOPs are imported at "import torch" time. + # It is bad if someone tries to access torch.ops.higher_order.cond + # and it doesn't exist (this may happen if your HOP isn't imported at + # "import torch" time). + FIXME_ALLOWLIST = { + "autograd_function_apply", + "run_with_rng_state", + "map_impl", + "_export_tracepoint", + "run_and_save_rng_state", + "map", + "custom_function_call", + "trace_wrapped", + "triton_kernel_wrapper_functional", + "triton_kernel_wrapper_mutation", + "wrap", # Really weird failure -- importing this causes Dynamo to choke on checkpoint + } + not_imported_hops = registered_hops - imported_hops + not_imported_hops = not_imported_hops - FIXME_ALLOWLIST + self.assertEqual( + not_imported_hops, + set(), + msg="All HOPs must be listed under torch/_higher_order_ops/__init__.py's __all__.", + ) + + def test_imports_from_all_work(self): + """All APIs listed in torch._higher_order_ops.__all__ must be importable""" + stuff = torch._higher_order_ops.__all__ + for attr in stuff: + getattr(torch._higher_order_ops, attr) + if __name__ == "__main__": run_tests() diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index b2e118893cd..9c0d6c1713c 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -1,4 +1,11 @@ +from torch._higher_order_ops.associative_scan import associative_scan +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) from torch._higher_order_ops.cond import cond +from torch._higher_order_ops.effects import with_effects +from torch._higher_order_ops.executorch_call_delegate import executorch_call_delegate from torch._higher_order_ops.flex_attention import ( flex_attention, flex_attention_backward, @@ -6,9 +13,19 @@ from torch._higher_order_ops.flex_attention import ( from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.invoke_subgraph import invoke_subgraph +from torch._higher_order_ops.out_dtype import out_dtype from torch._higher_order_ops.prim_hop_base import PrimHOPBase +from torch._higher_order_ops.run_const_graph import run_const_graph from torch._higher_order_ops.scan import scan +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.torchbind import call_torchbind from torch._higher_order_ops.while_loop import while_loop +from torch._higher_order_ops.wrap import ( + tag_activation_checkpoint, + wrap_activation_checkpoint, + wrap_with_autocast, + wrap_with_set_grad_enabled, +) __all__ = [ @@ -22,4 +39,17 @@ __all__ = [ "PrimHOPBase", "foreach_map", "_foreach_map", + "with_effects", + "tag_activation_checkpoint", + "auto_functionalized", + "auto_functionalized_v2", + "associative_scan", + "out_dtype", + "executorch_call_delegate", + "call_torchbind", + "run_const_graph", + "wrap_with_set_grad_enabled", + "wrap_with_autocast", + "wrap_activation_checkpoint", + "strict_mode", ] diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index b43859cbb83..98088499263 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -16,7 +16,6 @@ from torch._higher_order_ops.utils import ( reenter_make_fx, unique_graph_id, ) -from torch._inductor.utils import is_pointwise_use from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( @@ -342,6 +341,8 @@ def generic_associative_scan(operator, leaves, dim=0): def trace_associative_scan( proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor] ): + from torch._inductor.utils import is_pointwise_use + with disable_proxy_modes_tracing(): sample_xs = [first_slice_copy(x) for x in itertools.chain(xs, xs)] combine_graph = reenter_make_fx(combine_fn)(*sample_xs) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 9310c55ddac..604c545e586 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -7,7 +7,6 @@ from typing import Optional from torch._logging import warning_once from torch._ops import HigherOrderOperator from torch.types import _dtype -from torch.utils.checkpoint import checkpoint, CheckpointPolicy log = logging.getLogger(__name__) @@ -120,6 +119,8 @@ class WrapActivationCheckpoint(HigherOrderOperator): kwargs["preserve_rng_state"] = False # Using interpreter allows preservation of metadata through torch.compile stack. with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + return checkpoint(Interpreter(function).run, *args, **kwargs) @@ -167,6 +168,8 @@ class TagActivationCheckpoint(HigherOrderOperator): We do sorting to ensure same graph from run to run for better debuggability. It is not required for correctness. """ + from torch.utils.checkpoint import checkpoint + ckpt_signature = inspect.signature(checkpoint) checkpoint_keys = set() for name in ckpt_signature.parameters: @@ -186,6 +189,8 @@ class TagActivationCheckpoint(HigherOrderOperator): return checkpoint_kwargs, gmod_kwargs def tag_nodes(self, gmod, is_sac): + from torch.utils.checkpoint import CheckpointPolicy + unique_graph_id = next(uid) for node in gmod.graph.nodes: if node.op in ("call_function", "call_method", "call_module"): @@ -226,6 +231,8 @@ Please make sure the checkpointed region does not contain in-place ops (e.g. tor gmod = self.tag_nodes(gmod, is_sac=True) # Using interpreter allows preservation of metadata through torch.compile stack. with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + return checkpoint(Interpreter(gmod).run, *args, **kwargs) else: gmod = self.tag_nodes(gmod, is_sac=False) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 8267a25023d..5e8c97b5e13 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -77,6 +77,16 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [ "run_with_rng_state", "out_dtype", "trace_wrapped", + 'tag_activation_checkpoint', + 'executorch_call_delegate', + 'wrap', + 'wrap_with_set_grad_enabled', + 'auto_functionalized_v2', + 'associative_scan', + 'wrap_with_autocast', + 'wrap_activation_checkpoint', + 'run_const_graph', + 'auto_functionalized', "map", # T183144629 "map_impl", "with_effects",