Require that all HOPs be imported at import torch time (#145939)

E.g. torch.ops.higher_order.cond does not exist until the module defining it
is imported, which is bad if the HOP shows up in an FX graph or is referenced
by other code before that import happens.

This PR also makes some more HOPs get imported at `import torch` time.
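
A minimal sketch (not part of this PR) of the failure mode in a fresh process,
assuming only `import torch` has run:

    import torch

    # If the module defining the HOP has not been imported yet, the op can simply
    # be absent from the higher_order namespace:
    hasattr(torch.ops.higher_order, "cond")  # can be False before this change

    # Importing the defining module registers the HigherOrderOperator, after which
    # the attribute lookup resolves:
    import torch._higher_order_ops.cond
    torch.ops.higher_order.cond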

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145939
Approved by: https://github.com/ydwu4
ghstack dependencies: #145938
rzou 2025-01-29 10:38:28 -08:00 committed by PyTorch MergeBot
parent 2141c1aebe
commit 1e57154af3
5 changed files with 108 additions and 2 deletions
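
A recurring pattern in the diffs below: to keep every HOP module safe to import at
`import torch` time, heavyweight dependencies (e.g. torch._inductor,
torch.utils.checkpoint) are moved from module level into the functions that use
them. A rough sketch of that pattern, with a hypothetical helper name:

    # Hypothetical helper illustrating the deferred-import pattern used below.
    def _combine_is_pointwise(use):
        # Imported inside the function so that importing this HOP module at
        # "import torch" time does not pull in torch._inductor eagerly.
        from torch._inductor.utils import is_pointwise_use
        return is_pointwise_use(use)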


@@ -1,4 +1,8 @@
# Owner(s): ["module: higher order operators"]
import importlib
import pkgutil
import torch
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.hop_db import (
FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
@@ -6,9 +10,21 @@ from torch.testing._internal.hop_db import (
)
def do_imports():
for mod in pkgutil.walk_packages(
torch._higher_order_ops.__path__, "torch._higher_order_ops."
):
modname = mod.name
importlib.import_module(modname)
do_imports()
@skipIfTorchDynamo("not applicable")
class TestHOPInfra(TestCase):
def test_all_hops_have_opinfo(self):
"""All HOPs should have an OpInfo in torch/testing/_internal/hop_db.py"""
from torch._ops import _higher_order_ops
hops_that_have_op_info = {k.name for k in hop_db}
@@ -28,6 +44,48 @@ class TestHOPInfra(TestCase):
f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py",
)
def test_all_hops_are_imported(self):
"""All HOPs should be listed in torch._higher_order_ops.__all__
Some constraints (see test_testing.py::TestImports)
- Sympy must be lazily imported
- Dynamo must be lazily imported
"""
imported_hops = torch._higher_order_ops.__all__
registered_hops = torch._ops._higher_order_ops.keys()
# Please don't add anything here.
# We want to ensure that all HOPs are imported at "import torch" time.
# It is bad if someone tries to access torch.ops.higher_order.cond
# and it doesn't exist (this may happen if your HOP isn't imported at
# "import torch" time).
FIXME_ALLOWLIST = {
"autograd_function_apply",
"run_with_rng_state",
"map_impl",
"_export_tracepoint",
"run_and_save_rng_state",
"map",
"custom_function_call",
"trace_wrapped",
"triton_kernel_wrapper_functional",
"triton_kernel_wrapper_mutation",
"wrap", # Really weird failure -- importing this causes Dynamo to choke on checkpoint
}
not_imported_hops = registered_hops - imported_hops
not_imported_hops = not_imported_hops - FIXME_ALLOWLIST
self.assertEqual(
not_imported_hops,
set(),
msg="All HOPs must be listed under torch/_higher_order_ops/__init__.py's __all__.",
)
def test_imports_from_all_work(self):
"""All APIs listed in torch._higher_order_ops.__all__ must be importable"""
stuff = torch._higher_order_ops.__all__
for attr in stuff:
getattr(torch._higher_order_ops, attr)
if __name__ == "__main__":
run_tests()


@@ -1,4 +1,11 @@
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.auto_functionalize import (
auto_functionalized,
auto_functionalized_v2,
)
from torch._higher_order_ops.cond import cond
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.executorch_call_delegate import executorch_call_delegate
from torch._higher_order_ops.flex_attention import (
flex_attention,
flex_attention_backward,
@@ -6,9 +13,19 @@ from torch._higher_order_ops.flex_attention import (
from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.out_dtype import out_dtype
from torch._higher_order_ops.prim_hop_base import PrimHOPBase
from torch._higher_order_ops.run_const_graph import run_const_graph
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.torchbind import call_torchbind
from torch._higher_order_ops.while_loop import while_loop
from torch._higher_order_ops.wrap import (
tag_activation_checkpoint,
wrap_activation_checkpoint,
wrap_with_autocast,
wrap_with_set_grad_enabled,
)
__all__ = [
@@ -22,4 +39,17 @@ __all__ = [
"PrimHOPBase",
"foreach_map",
"_foreach_map",
"with_effects",
"tag_activation_checkpoint",
"auto_functionalized",
"auto_functionalized_v2",
"associative_scan",
"out_dtype",
"executorch_call_delegate",
"call_torchbind",
"run_const_graph",
"wrap_with_set_grad_enabled",
"wrap_with_autocast",
"wrap_activation_checkpoint",
"strict_mode",
]


@@ -16,7 +16,6 @@ from torch._higher_order_ops.utils import (
reenter_make_fx,
unique_graph_id,
)
from torch._inductor.utils import is_pointwise_use
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@@ -342,6 +341,8 @@ def generic_associative_scan(operator, leaves, dim=0):
def trace_associative_scan(
proxy_mode, func_overload, combine_fn: Callable, xs: list[torch.Tensor]
):
from torch._inductor.utils import is_pointwise_use
with disable_proxy_modes_tracing():
sample_xs = [first_slice_copy(x) for x in itertools.chain(xs, xs)]
combine_graph = reenter_make_fx(combine_fn)(*sample_xs)


@@ -7,7 +7,6 @@ from typing import Optional
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.types import _dtype
from torch.utils.checkpoint import checkpoint, CheckpointPolicy
log = logging.getLogger(__name__)
@@ -120,6 +119,8 @@ class WrapActivationCheckpoint(HigherOrderOperator):
kwargs["preserve_rng_state"] = False
# Using interpreter allows preservation of metadata through torch.compile stack.
with fx_traceback.preserve_node_meta():
from torch.utils.checkpoint import checkpoint
return checkpoint(Interpreter(function).run, *args, **kwargs)
@@ -167,6 +168,8 @@ class TagActivationCheckpoint(HigherOrderOperator):
We do sorting to ensure same graph from run to run for better
debuggability. It is not required for correctness.
"""
from torch.utils.checkpoint import checkpoint
ckpt_signature = inspect.signature(checkpoint)
checkpoint_keys = set()
for name in ckpt_signature.parameters:
@@ -186,6 +189,8 @@
return checkpoint_kwargs, gmod_kwargs
def tag_nodes(self, gmod, is_sac):
from torch.utils.checkpoint import CheckpointPolicy
unique_graph_id = next(uid)
for node in gmod.graph.nodes:
if node.op in ("call_function", "call_method", "call_module"):
@@ -226,6 +231,8 @@ Please make sure the checkpointed region does not contain in-place ops (e.g. tor
gmod = self.tag_nodes(gmod, is_sac=True)
# Using interpreter allows preservation of metadata through torch.compile stack.
with fx_traceback.preserve_node_meta():
from torch.utils.checkpoint import checkpoint
return checkpoint(Interpreter(gmod).run, *args, **kwargs)
else:
gmod = self.tag_nodes(gmod, is_sac=False)


@@ -77,6 +77,16 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
"run_with_rng_state",
"out_dtype",
"trace_wrapped",
'tag_activation_checkpoint',
'executorch_call_delegate',
'wrap',
'wrap_with_set_grad_enabled',
'auto_functionalized_v2',
'associative_scan',
'wrap_with_autocast',
'wrap_activation_checkpoint',
'run_const_graph',
'auto_functionalized',
"map", # T183144629
"map_impl",
"with_effects",