From 9012e7a62f756f7cf42e73d7dd375f8c468a0c66 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 2 Dec 2024 16:05:14 +0000 Subject: [PATCH] Revert "[dynamo][pytree][1/N] make CXX pytree traceable: `tree_iter` / `tree_leaves` (#137397)" This reverts commit 07850bb2c1771ba3f5578b0aa85792e5cd70de1c. Reverted https://github.com/pytorch/pytorch/pull/137397 on behalf of https://github.com/atalman due to Failing internal test ([comment](https://github.com/pytorch/pytorch/pull/137397#issuecomment-2511934283)) --- test/dynamo/test_misc.py | 57 ++++++++++-------- torch/_dynamo/guards.py | 7 +-- torch/_dynamo/polyfills/__init__.py | 1 - torch/_dynamo/polyfills/loader.py | 1 - torch/_dynamo/polyfills/pytree.py | 89 ----------------------------- torch/_dynamo/trace_rules.py | 1 - torch/utils/_cxx_pytree.py | 40 +++++++------ 7 files changed, 57 insertions(+), 139 deletions(-) delete mode 100644 torch/_dynamo/polyfills/pytree.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8947e577010..f78f1f4e45c 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -32,7 +32,7 @@ import torch import torch._dynamo.testing import torch._inductor.test_case import torch.onnx.operators -import torch.utils._pytree as python_pytree +import torch.utils._pytree as pytree import torch.utils.cpp_extension from torch import Tensor from torch._C import FileCheck @@ -89,11 +89,9 @@ from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.logging_utils import logs_to_string -HAS_OPTREE = python_pytree._cxx_pytree_exists +HAS_OPTREE = importlib.util.find_spec("optree") if HAS_OPTREE: - import torch.utils._cxx_pytree as cxx_pytree -else: - cxx_pytree = None + import optree MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -295,9 +293,9 @@ class MiscTests(torch._inductor.test_case.TestCase): @unittest.skipIf(not HAS_OPTREE, "missing optree package") def test_optree_graph_break_message(self): - import optree - - @torch.compile(backend="eager") + @torch.compile( + backend="eager", + ) def fn(x): d = {"a": 1} optree.tree_flatten(d) @@ -8678,9 +8676,9 @@ def ___make_guard_fn(): def test_tracing_py_tree(self): def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) + return pytree.tree_unflatten(res, spec) xs = [torch.tensor(i) for i in range(3)] @@ -8690,10 +8688,12 @@ def ___make_guard_fn(): self.assertEqual(counter.op_count, 3) def test_tracing_nested_py_tree(self): + import torch.utils._pytree as pytree + def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) + return pytree.tree_unflatten(res, spec) xs = [torch.tensor(i) for i in range(3)] xsl = [xs, xs, xs, xs] @@ -8706,10 +8706,12 @@ def ___make_guard_fn(): self.assertEqual(counter.op_count, 12) def test_tracing_nested_py_tree_tuples(self): + import torch.utils._pytree as pytree + def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) + return pytree.tree_unflatten(res, spec) xs = [torch.tensor(i) for i in range(3)] xsl = (xs, xs, xs, xs) @@ -8722,10 +8724,12 @@ def ___make_guard_fn(): self.assertEqual(counter.op_count, 12) def test_tracing_nested_py_tree_dicts(self): + import torch.utils._pytree as pytree + def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) + return pytree.tree_unflatten(res, spec) xs = [torch.tensor(i) for i in range(3)] xsl = { @@ -8758,10 +8762,12 @@ def ___make_guard_fn(): self.assertEqual(counter.op_count, 2) def test_tracing_nested_py_tree_mixed_all(self): + import torch.utils._pytree as pytree + def fn(xs): - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) res = [x.clone() for x in flat_xs] - return python_pytree.tree_unflatten(res, spec) + return pytree.tree_unflatten(res, spec) xs = [torch.tensor(i) for i in range(3)] xsa = (xs, xs) @@ -8806,12 +8812,13 @@ def ___make_guard_fn(): self.assertEqual(cnt.frame_count, 2) def test_tracing_py_tree_tensor_subclass(self): + import torch.utils._pytree as pytree from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import checkpoint def fn(xs): nested_xs = [[xs]] - flat_xs, spec = python_pytree.tree_flatten(xs) + flat_xs, spec = pytree.tree_flatten(xs) return flat_xs[0].clone() # use checkpoint to trigger a "sourceless" tensor subclass @@ -8826,11 +8833,13 @@ def ___make_guard_fn(): self.assertEqual(counter.op_count, 2) def test_tracing_tree_map_only(self): + import torch.utils._pytree as pytree + def fn(xs): def mapper(x): return x.clone() - y = python_pytree.tree_map_only(torch.Tensor, mapper, xs) + y = pytree.tree_map_only(torch.Tensor, mapper, xs) return y xs = [torch.tensor(i) for i in range(3)] + ["hi"] @@ -10184,9 +10193,7 @@ def ___make_guard_fn(): self.assertEqual(actual, expected) def test_pytree_tree_leaves(self): - implemtations = [("python", python_pytree)] - if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) + implemtations = [("python", pytree)] for name, module in implemtations: with self.subTest(f"pytree implement: {name}"): @@ -10218,7 +10225,7 @@ def ___make_guard_fn(): self.assertEqual(actual, expected) def test_pytree_tree_flatten_unflatten(self): - implemtations = [("python", python_pytree)] + implemtations = [("python", pytree)] for name, module in implemtations: with self.subTest(f"pytree implement: {name}"): @@ -10267,7 +10274,7 @@ def ___make_guard_fn(): self.assertEqual(actual, expected) def test_pytree_tree_map(self): - implemtations = [("python", python_pytree)] + implemtations = [("python", pytree)] for name, module in implemtations: with self.subTest(f"pytree implement: {name}"): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index dd18ad98d09..b3497f6ad85 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2080,11 +2080,10 @@ class GuardBuilder(GuardBuilderBase): obj_ref = None # Not necessary to have weakref for Enum type, but there is a bug that # makes hasattr(guarded_object.__class__, "__weakref__") return True. - supports_weakref = ( - getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0 - ) # See D64140537 for why we are checking for tuple. - if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)): + if hasattr(guarded_object.__class__, "__weakref__") and not isinstance( + guarded_object, (enum.Enum, tuple) + ): obj_ref = weakref.ref(guarded_object) guard.set_export_info( diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 8beb32a1b9e..c4cc3624a03 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: itertools as itertools, operator as operator, os as os, - pytree as pytree, sys as sys, ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 8dd9eddee14..c67a5d907cf 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -18,7 +18,6 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "itertools", "operator", "os", - "pytree", "sys", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py deleted file mode 100644 index 5538792c437..00000000000 --- a/torch/_dynamo/polyfills/pytree.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Python polyfills for torch.utils.pytree -""" - -from __future__ import annotations - -from typing import Any, Callable, Iterable, TYPE_CHECKING - -import torch.utils._pytree as python_pytree - -from ..decorators import substitute_in_graph - - -if TYPE_CHECKING: - from torch.utils._cxx_pytree import PyTree - - -__all__: list[str] = [] - - -if python_pytree._cxx_pytree_exists: - import optree - import optree._C - - import torch.utils._cxx_pytree as cxx_pytree - - @substitute_in_graph( - optree._C.is_dict_insertion_ordered, - can_constant_fold_through=True, - ) - def _(*args: Any, **kwargs: Any) -> bool: - # In namespace 'torch', the dictionary is always traversed in insertion order. - # This function returns True. - raise ValueError( - "Should not be called directly " - "because the original function will be called in the constant fold path." - ) - - __name = "" - for __name in ( - "is_namedtuple", - "is_namedtuple_class", - "is_namedtuple_instance", - "is_structseq", - "is_structseq_class", - "is_structseq_instance", - "namedtuple_fields", - "structseq_fields", - ): - __func = getattr(optree, __name) - substitute_in_graph(__func, can_constant_fold_through=True)( - __func.__python_implementation__ - ) - del __func - del __name - - @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False) - def tree_iter( - tree: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - ) -> Iterable[Any]: - stack = [tree] - while stack: - node = stack.pop() - if node is None or (is_leaf is not None and is_leaf(node)): - yield node - continue - if optree.register_pytree_node.get(type(node), namespace="torch") is None: # type: ignore[attr-defined] - yield node - continue - - children, *_ = optree.tree_flatten_one_level( - node, - is_leaf=is_leaf, - none_is_leaf=True, - namespace="torch", - ) - stack.extend(reversed(children)) - - __all__ += ["tree_iter"] - - @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True) - def tree_leaves( - tree: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - ) -> list[Any]: - return list(tree_iter(tree, is_leaf=is_leaf)) - - __all__ += ["tree_leaves"] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 1d0954dd83b..d3fae78a630 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3310,7 +3310,6 @@ MOD_INLINELIST = [ "torch.testing", "torch.utils._content_store", "torch.utils._contextlib", - "torch.utils._cxx_pytree", "torch.utils._device", "torch.utils._foreach_utils", "torch.utils._python_dispatch", diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index baf272c9426..7b7cdc8a7e2 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -30,10 +30,10 @@ from typing import ( from typing_extensions import deprecated import optree -from optree import PyTreeSpec as TreeSpec # direct import for type annotations +from optree import PyTreeSpec # direct import for type annotations -import torch.utils._pytree as python_pytree -from torch.utils._pytree import KeyEntry as KeyEntry +import torch.utils._pytree as _pytree +from torch.utils._pytree import KeyEntry __all__ = [ @@ -79,6 +79,7 @@ R = TypeVar("R") Context = Any PyTree = Any +TreeSpec = PyTreeSpec FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] @@ -150,7 +151,9 @@ def register_pytree_node( from_dumpable_context=from_dumpable_context, ) - python_pytree._private_register_pytree_node( + from . import _pytree as python + + python._private_register_pytree_node( cls, flatten_fn, unflatten_fn, @@ -868,19 +871,24 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: f"treespec_dumps(spec): Expected `spec` to be instance of " f"TreeSpec but got item of type {type(treespec)}." ) + from ._pytree import ( + tree_structure as _tree_structure, + treespec_dumps as _treespec_dumps, + ) - dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec) - orig_treespec = python_pytree.tree_structure(dummy_tree) - return python_pytree.treespec_dumps(orig_treespec, protocol=protocol) + orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec)) + return _treespec_dumps(orig_treespec, protocol=protocol) def treespec_loads(serialized: str) -> TreeSpec: """Deserialize a treespec from a JSON string.""" - orig_treespec = python_pytree.treespec_loads(serialized) - dummy_tree = python_pytree.tree_unflatten( - [0] * orig_treespec.num_leaves, - orig_treespec, + from ._pytree import ( + tree_unflatten as _tree_unflatten, + treespec_loads as _treespec_loads, ) + + orig_treespec = _treespec_loads(serialized) + dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec) treespec = tree_structure(dummy_tree) return treespec @@ -994,10 +1002,6 @@ def key_get(obj: Any, kp: KeyPath) -> Any: raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") -with python_pytree._NODE_REGISTRY_LOCK: - python_pytree._cxx_pytree_imported = True - args, kwargs = (), {} # type: ignore[var-annotated] - for args, kwargs in python_pytree._cxx_pytree_pending_imports: - _private_register_pytree_node(*args, **kwargs) - python_pytree._cxx_pytree_pending_imports.clear() - del args, kwargs +_pytree._cxx_pytree_imported = True +for args, kwargs in _pytree._cxx_pytree_pending_imports: + _private_register_pytree_node(*args, **kwargs)