Revert "Consolidate SymDispatchMode into ProxyTensorMode (#132674)"

This reverts commit ffdf48e63b.

Reverted https://github.com/pytorch/pytorch/pull/132674 on behalf of https://github.com/PaliC due to We need to now revert https://github.com/pytorch/pytorch/pull/132216 in OSS and there is a dependency on this pr ([comment](https://github.com/pytorch/pytorch/pull/132674#issuecomment-2274062785))
This commit is contained in:
PyTorch MergeBot 2024-08-07 18:25:32 +00:00
parent 9d476fee53
commit a9ff190867
9 changed files with 136 additions and 90 deletions

1
.github/labeler.yml vendored
View file

@ -29,6 +29,7 @@
- torch/fx/experimental/recording.py
- torch/fx/experimental/sym_node.py
- torch/fx/experimental/validator.py
- torch/fx/experimental/_sym_dispatch_mode.py
- torch/fx/experimental/proxy_tensor.py
- test/distributed/_tensor/test_dtensor_compile.py
- test/distributed/tensor/parallel/test_fsdp_2d_parallel.py

View file

@ -850,6 +850,7 @@ coverage_ignore_functions = [
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"make_fx",
"maybe_disable_fake_tensor_mode",
"maybe_handle_decomp",
"proxy_call",

View file

@ -51,17 +51,3 @@ torch.fx.experimental.symbolic_shapes
compute_unbacked_bindings
rebind_unbacked
resolve_unbacked_bindings
torch.fx.experimental.proxy_tensor
-------------------------------------
.. currentmodule:: torch.fx.experimental.proxy_tensor
.. automodule:: torch.fx.experimental.proxy_tensor
.. autosummary::
:toctree: generated
:nosignatures:
make_fx
handle_sym_dispatch
get_proxy_mode

View file

@ -1143,6 +1143,7 @@ API Reference
.. py:module:: torch.fx.experimental.normalize
.. py:module:: torch.fx.experimental.optimization
.. py:module:: torch.fx.experimental.partitioner_utils
.. py:module:: torch.fx.experimental.proxy_tensor
.. py:module:: torch.fx.experimental.recording
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter

View file

@ -112,6 +112,7 @@ class AutogradCompilerInstance:
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
self.stack.enter_context(self.fake_tensor_mode)
self.stack.enter_context(self.proxy_mode.sym_mode)
self.stack.enter_context(self.proxy_mode)
self.stack.enter_context(disable_autocast_cache())
self.stack.enter_context(preserve_node_meta())

View file

@ -25,6 +25,7 @@ from weakref import ReferenceType
import torch
import torch._logging
import torch.fx.experimental._sym_dispatch_mode
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._guards import compile_context, CompileContext, CompileId, tracing
@ -1227,7 +1228,9 @@ class CatchErrorsWrapper:
frame, cache_entry, self.hooks, frame_state
)
with compile_lock, _disable_current_modes():
with (
compile_lock
), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
# skip=1: skip this frame
return self._torchdynamo_orig_callable(
frame, cache_entry, self.hooks, frame_state, skip=1

View file

@ -0,0 +1,72 @@
# mypy: allow-untyped-defs
import contextlib
from typing import List, Optional, Type
__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
# SymDispatchMode gets invoked whenever an operation is processed on
# a PySymInt. When this occurs, you get called at __sym_dispatch__
# with the operation in question. This is symmetric to TorchDispatchMode
# but with some caveats:
#
# - In TorchDispatchMode, you get the same arguments as what a user
# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
# you get (a, b) as args to your call. In SymDispatchMode, if
# you call a + b (where a and b are SymInts), you will get
# (a.node, b.node) as your args (these are PySymInts)
#
# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
# So you have to manually call Tracer/create_node to write into
# the graph. See ProxySymDispatchMode for an example
#
class SymDispatchMode:
def __sym_dispatch__(self, func, types, args, kwargs):
raise NotImplementedError
def __enter__(self):
global SYM_FUNCTION_MODE
old = SYM_FUNCTION_MODE
if hasattr(self, "inner"):
raise RuntimeError(
f"{self} has already been used as a mode. Please use a fresh version"
)
else:
self.inner = old
SYM_FUNCTION_MODE = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global SYM_FUNCTION_MODE
SYM_FUNCTION_MODE = self.inner
def handle_sym_dispatch(func, args, kwargs):
global SYM_FUNCTION_MODE
mode = sym_function_mode()
assert mode
SYM_FUNCTION_MODE = mode.inner
try:
# TODO: properly compute types
types: List[Type] = []
return mode.__sym_dispatch__(func, types, args, kwargs)
finally:
SYM_FUNCTION_MODE = mode
def sym_function_mode():
return SYM_FUNCTION_MODE
@contextlib.contextmanager
def disable_sym_dispatch():
global SYM_FUNCTION_MODE
old = SYM_FUNCTION_MODE
SYM_FUNCTION_MODE = None
try:
yield
finally:
SYM_FUNCTION_MODE = old

View file

@ -22,13 +22,13 @@ import warnings
import weakref
from ._backward_state import BackwardState
from ._sym_dispatch_mode import SymDispatchMode
from .sym_node import SymNode
from torch.utils._thunk import Thunk
from collections import defaultdict
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
from dataclasses import dataclass
from torch import SymInt, SymBool, Tensor
import torch._ops
from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@ -59,10 +59,7 @@ if TYPE_CHECKING:
from torch.fx._symbolic_trace import PHBase
from torch.types import IntLikeType
__all__ = [
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
]
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
@ -1009,10 +1006,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
class ProxyTorchDispatchMode(TorchDispatchMode):
# Ensure this is read-only; this exists only for legacy reasons
@property
def enable_tracing(self) -> bool:
return True
_managers: List[AbstractContextManager]
def __init__(
self,
@ -1026,9 +1020,12 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
super().__init__(dk)
self.tracer = tracer
self.tracing_mode = tracing_mode
self.enable_tracing = True
self.pre_dispatch = pre_dispatch
self._allow_fake_constant = _allow_fake_constant
self._error_on_data_dependent_ops = _error_on_data_dependent_ops
self.sym_mode = ProxySymDispatchMode(tracer)
self._managers = []
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
self._mode_key = torch._C._TorchDispatchModeKey.PROXY
@ -1048,10 +1045,14 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None
) -> object:
with set_original_aten_op(func):
with self.sym_mode.enable(False), set_original_aten_op(func):
return self.inner_torch_dispatch(func, types, args, kwargs)
def __enter__(self) -> Self:
# sym mode first, then us...
m = self.sym_mode.enable(True)
self._managers.append(m)
m.__enter__()
# Stash and store the previous proxy mode (there may or may not be one)
maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
self.enter_stack.append(maybe_prev_proxy_mode)
@ -1063,6 +1064,8 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType]
) -> Optional[bool]:
m = self._managers.pop()
# ...exit us first, then sym mode
b = super().__exit__(exc_type, exc_value, traceback)
# Re-enable the previous proxy mode, if there was one.
@ -1070,7 +1073,11 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
if mb_previous_proxy_mode is not None:
_push_mode(mb_previous_proxy_mode)
return b
if not b:
return m.__exit__(exc_type, exc_value, traceback)
else:
return m.__exit__(None, None, None)
def inner_torch_dispatch(
self,
@ -1081,6 +1088,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
) -> object:
kwargs = kwargs or {}
if not self.enable_tracing:
return func(*args, **kwargs)
if func in (prim.device.default,):
return func(*args, **kwargs)
@ -1090,6 +1100,25 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def is_infra_mode(cls) -> bool:
return True
class ProxySymDispatchMode(SymDispatchMode):
def __init__(self, tracer: _ProxyTracer) -> None:
super().__init__()
self.tracer = tracer
# When false, we don't trace operations. If you do this, you MUST
# call track_tensor/track_tensor_tree on all results of the operation
# to ensure we can adequately track the results
self.enable_tracing = True
@contextmanager
def enable(self, b: bool) -> Generator[None, None, None]:
old = self.enable_tracing
self.enable_tracing = b
try:
yield
finally:
self.enable_tracing = old
def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
n_args = tuple(
get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
@ -1110,6 +1139,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
args: Tuple[object, ...],
kwargs: Dict[str, object]
) -> object:
if not self.enable_tracing:
return func(*args, **kwargs)
# Peephole optimize multiply by one
# NB: be careful not to trigger guards here!
if func == operator.mul:
@ -1695,6 +1727,7 @@ class _MakefxTracer:
stack.enter_context(self.fake_tensor_mode)
stack.enter_context(self.python_dispatcher_mode)
stack.enter_context(self.proxy_function_mode)
stack.enter_context(proxy_mode.sym_mode)
stack.enter_context(self.torch_fn_metadata_mode)
stack.enter_context(proxy_mode)
stack.enter_context(disable_autocast_cache())
@ -1754,14 +1787,9 @@ def make_fx(
_allow_fake_constant: bool = False,
_error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
"""
Given a function f, return a new function which when executed with valid
arguments to f, returns an FX GraphModule representing the set of operations that
were executed during the course of execution.
"""
assert tracing_mode in ["real", "fake", "symbolic"]
make_fx_tracer = _MakefxTracer(
decomposition_table,
tracing_mode,
@ -1782,38 +1810,8 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
# TODO: this is a legacy name, there is only ever one proxy mode as it's an
# infra mode
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
return get_proxy_mode()
def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
"""
Current the currently active proxy tracing mode, or None if
we are not currently tracing. This includes pre-dispatch proxy
tracing.
"""
pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
return pre_dispatch_mode or mode
def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
"""
Call into the currently active proxy tracing mode to do a
SymInt/SymFloat/SymBool dispatch trace on a function that operates on
these arguments.
"""
mode = get_proxy_mode()
assert mode
# Have to do it manually, because we're not doing the normal torch
# dispatch machinery which disables it for us
with disable_proxy_modes_tracing():
# TODO: properly compute types
types: List[Type] = []
return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value]
def get_innermost_proxy_mode() -> ProxyTorchDispatchMode:
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
@contextmanager

View file

@ -31,6 +31,10 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
@ -1051,10 +1055,6 @@ def _make_node_magic(method, func):
method_attr = method
def binary_magic_impl(self, other):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
if alternate_impl and out_hint is not None:
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
if get_proxy_mode():
if sym_function_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
@ -1129,14 +1129,10 @@ def _make_node_magic(method, func):
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
def unary_magic_impl(self):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
if get_proxy_mode():
if sym_function_mode():
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
# TODO: consider constant prop here
expr = self.expr
@ -1171,14 +1167,10 @@ def _make_node_magic(method, func):
elif method == "sym_ite":
def sym_ite_impl(pred_node, then_node, else_node):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
out_hint = then_node.hint if pred_node.hint else else_node.hint
if get_proxy_mode():
if sym_function_mode():
return to_node(
pred_node,
handle_sym_dispatch(
@ -1216,14 +1208,10 @@ def _make_node_magic(method, func):
elif method == "round":
def round_impl(self, ndigits=None):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = builtins.round
if get_proxy_mode():
if sym_function_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
)
@ -1268,13 +1256,8 @@ def _make_node_sizes_strides(method, func):
# NB: don't LRU cache, lots of arguments
def sizes_strides_impl(self, sizes, strides):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
op = getattr(sys.modules[__name__], method)
if get_proxy_mode():
if sym_function_mode():
return to_node(
self,
handle_sym_dispatch(