From c3c03e7cb88ea6366a8144fa56985e50d9627804 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Wed, 21 Jun 2023 00:43:00 +0000
Subject: [PATCH] Reland of https://github.com/pytorch/pytorch/pull/101818
 (#103888)

The original PR broke internal builds.

This reverts commit 5ed618132f466440ad76c884240e07796c7e2c6b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103888
Approved by: https://github.com/albanD
---
 aten/src/ATen/core/PythonFallbackKernel.cpp | 13 +++++
 c10/core/DispatchKey.cpp                    |  4 ++
 c10/core/DispatchKey.h                      |  6 +++
 test/dynamo/test_export.py                  |  6 +--
 test/test_proxy_tensor.py                   | 17 +++++--
 torch/_C/__init__.pyi.in                    |  5 ++
 torch/_dispatch/python.py                   |  3 +-
 torch/_dynamo/eval_frame.py                 | 15 +++---
 torch/csrc/autograd/init.cpp                |  6 +++
 torch/csrc/utils/python_dispatch.cpp        |  1 +
 torch/fx/experimental/proxy_tensor.py       | 56 +++++++--------------
 torch/utils/_python_dispatch.py             | 18 ++-----
 12 files changed, 85 insertions(+), 65 deletions(-)

diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp
index 2530f2184a6..688dfb3c4da 100644
--- a/aten/src/ATen/core/PythonFallbackKernel.cpp
+++ b/aten/src/ATen/core/PythonFallbackKernel.cpp
@@ -107,6 +107,15 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySe
   op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
 }

+// The PreDispatch key gets a no-op fallback that just redispatches.
+// The main way this key is used is that we can register a mode to it from python (e.g. TorchProxyDispatchMode, for pre_dispatch tracing).
+// Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
+// Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
+// Alternatively, we could have hardcoded this kernel (in C++) to directly call into TorchProxyDispatchMode.
+// Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
+void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
+}

 } // anonymous namespace

@@ -152,3 +161,7 @@ TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
 TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
 }
+
+TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
+}
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index 954f25aa12b..657815d6039 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -180,6 +180,9 @@ const char* toString(DispatchKey t) {
     case DispatchKey::TESTING_ONLY_GenericMode:
       return "TESTING_ONLY_GenericMode";

+    case DispatchKey::PreDispatch:
+      return "PreDispatch";
+
     case DispatchKey::PythonDispatcher:
       return "PythonDispatcher";

@@ -300,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
        c10::DispatchKey::TESTING_ONLY_GenericWrapper},
       {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
       {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
+      {"PreDispatch", c10::DispatchKey::PreDispatch},

       {"CPU", c10::DispatchKey::CPU},
       {"CUDA", c10::DispatchKey::CUDA},
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 40be3f9dc5d..85fabdb30c3 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -406,6 +406,12 @@ enum class DispatchKey : uint16_t {
   // for a usage example
   TESTING_ONLY_GenericMode,

+  // This key is used for pre-dispatch tracing in make_fx.
+  // It has lower priority than the PythonDispatcher key
+  // because we use the PythonDispatcher to intercept the key from python,
+  // and avoid having to implement it in C++.
+  PreDispatch,
+
   // This is a bypass that allows you to skip running the C++ dispatcher
   // entirely
   PythonDispatcher,
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 4ee1884d59a..7a365d13de8 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1828,9 +1828,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         inp = torch.randn(6, 7)
         self.assertEqual(gm(inp), f(inp))

-    # pre_autograd seems to violate new fake tensor invariants
-    @unittest.expectedFailure
-    def test_pre_autograd_simple(self):
+    def test_pre_dispatch_simple(self):
         def f(x):
             y = torch.ones_like(x)
             return torch.matmul(x, y)
@@ -1839,7 +1837,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
             f,
             torch.randn(5, 5),
             aten_graph=True,
-            pre_autograd=True,
+            pre_dispatch=True,
             tracing_mode="fake",
         )
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index b74a51f2c1a..862ca3676cf 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -147,7 +147,7 @@ class TestGenericProxyTensor(TestCase):
         r2 = f(*new_inps)
         self.assertEqual(r1, r2)

-    def test_pre_autograd_mode_stack(self):
+    def test_pre_dispatch_mode_stack(self):
         def f(a):
             b = torch.ones(4, 4)
             return torch.matmul(a, b)
@@ -156,17 +156,28 @@ class TestGenericProxyTensor(TestCase):
         # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
         # so our mode never sees it - it goes directly to the BackendSelect key.
         inp = torch.ones(4, 4)
-        # Test that make_fx(pre_autograd=True) clears caches properly.
+        # Test that make_fx(pre_dispatch=True) clears caches properly.
         from torch._dispatch.python import enable_python_dispatcher
         with enable_python_dispatcher():
             out1 = f(inp)
-        fx_g = make_fx(f, pre_autograd=True)(inp)
+        fx_g = make_fx(f, pre_dispatch=True)(inp)
         self.assertExpectedInline(fx_g.code.strip(), """\
 def forward(self, a_1):
     ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
     matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
     return matmul""")

+    def test_pre_dispatch_linear(self):
+        def f(a, b, c):
+            return torch.nn.functional.linear(a, b, c)
+        a = torch.ones(4, 4)
+        b = torch.ones(4, 4)
+        c = torch.ones(4)
+        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
+        out1 = f(a, b, c)
+        out2 = fx_g(a, b, c)
+        self.assertEqual(out1, out2)
+
     def test_make_fx_simple(self):
         def f(x):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 796e4cda80d..6a36329b70e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1191,6 +1191,11 @@ class _DisablePythonDispatcher:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...

+class _EnablePreDispatch:
+    def __init__(self): ...
+    def __enter__(self): ...
+    def __exit__(self, exc_type, exc_value, traceback): ...
+
 class _DisableFuncTorch:
     def __init__(self): ...
     def __enter__(self): ...
diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
index 3abf339b6fc..cc1e5337135 100644
--- a/torch/_dispatch/python.py
+++ b/torch/_dispatch/python.py
@@ -7,10 +7,11 @@ import itertools
 from typing import Iterator

 import torch._ops

-__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
+__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']

 no_python_dispatcher = torch._C._DisablePythonDispatcher
 enable_python_dispatcher = torch._C._EnablePythonDispatcher
+enable_pre_dispatch = torch._C._EnablePreDispatch

 CROSSREF_FUNCTIONALIZE = False
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index b107fdd21f9..d7b661583ed 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -791,7 +791,7 @@ def export(
     f: Callable[..., Any],
     *args,
     aten_graph: bool = False,
-    pre_autograd: bool = False,
+    pre_dispatch: bool = False,
     decomposition_table: Optional[
         Dict[torch._ops.OpOverload, Callable[..., Any]]
     ] = None,
@@ -812,9 +812,10 @@ def export(
         aten_graph (bool): If True, exports a graph with ATen operators.
         If False, exports a graph with Python operators. Default is False.

-        pre_autograd (bool): If True, exports a graph with ATen operators,
-        but before autograd has run. This can be useful if you want to apply further tranformations
-        on a graph before running it through autograd.
+        pre_dispatch (bool): If True, exports a graph with ATen operators,
+        but before any logic in the PyTorch dispatcher has run.
+        This can be useful if you want to apply further transformations on a graph before running it
+        through autograd, autocast, or any other functionality that is integrated into the dispatcher.
         This flag is only valid if aten_graph=True is set. Default is False.
@@ -848,8 +849,8 @@ def export(
     assert (
         aten_graph
     ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
-    if pre_autograd:
-        assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
+    if pre_dispatch:
+        assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
     f = innermost_fn(f)
     call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
     original_signature = inspect.signature(call_to_inspect)
@@ -1018,7 +1019,7 @@ def export(
             decomposition_table=decomposition_table,
             tracing_mode="real",
             _allow_non_fake_inputs=True,
-            pre_autograd=pre_autograd,
+            pre_dispatch=pre_dispatch,
             _allow_fake_constant=_allow_fake_constant,
         )(*example_fake_inputs)
     except CondOpArgsMismatchError as e:
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index cfddb7e2b23..9ffd71197b8 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -98,6 +98,11 @@ struct EnablePythonDispatcher {
   c10::impl::PyInterpreter* old_;
 };

+struct EnablePreDispatch {
+  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
+  c10::impl::IncludeDispatchKeyGuard guard_;
+};
+
 } // namespace

 PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
@@ -419,6 +424,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       _C_m, "_EnablePythonDispatcher");
   py_context_manager<DisablePythonDispatcher>(
       _C_m, "_DisablePythonDispatcher");
+  py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
   py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
   py_context_manager_DEPRECATED<MultithreadingEnabled, bool>(
       _C_m, "_MultithreadingEnabled");
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 89ba803182b..c1f3b74bbbf 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -561,6 +561,7 @@ void initDispatchBindings(PyObject* module) {
       DEF_ONE(FuncTorchVmapMode)
       DEF_ONE(FuncTorchGradWrapper)
       DEF_ONE(PythonDispatcher)
+      DEF_ONE(PreDispatch)
       DEF_ONE(Functionalize)
       DEF_ONE(AutocastCPU)
       DEF_ONE(AutocastXPU)
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index e3e9ec4e01f..fe9685b23b5 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -10,7 +10,7 @@ import torch
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch._dispatch.python import enable_python_dispatcher
+from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
 import torch.fx as fx
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from contextlib import contextmanager, nullcontext
@@ -246,16 +246,7 @@ def fetch_tensor_proxy(tracer):

 HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)

-@contextlib.contextmanager
-def inside_mode(proxy_mode):
-    old = proxy_mode.is_inside_mode
-    proxy_mode.is_inside_mode = True
-    try:
-        yield
-    finally:
-        proxy_mode.is_inside_mode = old
-
-def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
+def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
     unrecognized_types = []

     def can_handle_tensor(x):
@@ -279,7 +270,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
             return r

     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
-    if not pre_autograd:
+    if not pre_dispatch:
         with proxy_mode:
             r = func.decompose(*args, **kwargs)
             if r is not NotImplemented:
                 return r
@@ -358,12 +349,6 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     if func is torch.ops.aten.lift_fresh.default:
         func = torch.ops.aten.lift_fresh_copy.default

-    # See Note [Per-Dispatch-Key Modes Must Be Reentrant]
-    # If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
-    # then we only want it to trace out proxies the first time that we hit an op.
-    if proxy_mode.is_inside_mode:
-        return func(*args, **kwargs)
-
     proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
                                                name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
@@ -379,8 +364,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     else:
         args[0].proxy = proxy_out

-    with inside_mode(proxy_mode):
-        out = func(*args, **kwargs)
+    out = func(*args, **kwargs)

     # In some circumstances, we will be tracing in a situation where a tensor
     # is *statically* known to be a constant (currently, this only happens if
@@ -428,12 +412,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
     else:
         constant = None

-    # See Note [Per-Dispatch-Key Modes Must Be Reentrant]
-    # If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
-    # then we only want it to trace out proxies the first time that we hit an op.
-    # In particular, track_tensor_tree can call detach().
-    with inside_mode(proxy_mode):
-        track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)

     return out

@@ -488,9 +467,9 @@ def dispatch_trace(
     return GraphModule(tracer.root, graph, name)


-def wrap_key(f, tensors, tracer, pre_autograd: bool):
+def wrap_key(f, tensors, tracer, pre_dispatch: bool):
     flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
-    dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
+    dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None

     @functools.wraps(f)
     def wrapped(*proxies):
@@ -534,13 +513,13 @@ def set_original_aten_op(func):


 class ProxyTorchDispatchMode(TorchDispatchMode):
-    def __init__(self, tracer, tracing_mode, pre_autograd=False, _allow_fake_constant=False):
-        dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
+    def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False):
+        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
         super().__init__(dk)
         self.tracer = tracer
         self.tracing_mode = tracing_mode
         self.enable_tracing = True
-        self.pre_autograd = pre_autograd
+        self.pre_dispatch = pre_dispatch
         self._allow_fake_constant = _allow_fake_constant
         self.is_inside_mode = False
         self.sym_mode = ProxySymDispatchMode(tracer)
@@ -575,7 +554,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         if func in [prim.device.default]:
             return func(*args, **kwargs)

-        return proxy_call(self, func, self.pre_autograd, args, kwargs)
+        return proxy_call(self, func, self.pre_dispatch, args, kwargs)


 class ProxySymDispatchMode(SymDispatchMode):
@@ -709,7 +688,7 @@ def make_fx(f,
             tracing_mode="real",
             _allow_non_fake_inputs=False,
             *,
-            pre_autograd=False,
+            pre_dispatch=False,
             _allow_fake_constant=False):
     assert tracing_mode in ["real", "fake", "symbolic"]

@@ -747,14 +726,17 @@ def make_fx(f,
         raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

     python_dispatcher_mode: Any = nullcontext()
+    pre_dispatch_mode: Any = nullcontext()
     # pre-autograd tracing uses per-dispatch-key modes,
     # which requires the python dispatcher
-    if tracing_mode == "symbolic" or pre_autograd:
+    if tracing_mode == "symbolic" or pre_dispatch:
         python_dispatcher_mode = enable_python_dispatcher()
+    if pre_dispatch:
+        pre_dispatch_mode = enable_pre_dispatch()

     proxy_mode = ProxyTorchDispatchMode(fx_tracer,
                                         tracing_mode,
-                                        pre_autograd=pre_autograd,
+                                        pre_dispatch=pre_dispatch,
                                         _allow_fake_constant=_allow_fake_constant)

     arg_count = 0
@@ -795,9 +777,9 @@ def make_fx(f,
         # We also disable tracing by any other tensor proxy-based tracers except the current. The
         # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
         # thus irrelevant to any external functional trace.
-        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
+        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
             sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
-            t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
+            t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))

     # TODO: kind of a bad way to do it, should maybe figure out a better way
     if tracing_mode == "symbolic":
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index bd7673f1d4e..8412c645970 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -83,24 +83,16 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
             for key in ks:
                 op._uncache_dispatch(key)
         push_mode_for_key(k, mode)
-    # Note [Per-Dispatch-Key Modes Must Be Reentrant]
-    # The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
-    # (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
-    #     kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
-    #     to guarantee that every op will be seen by our mode.
-    # (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
-    #     at both the Autograd key and the Python key, nothing bad should happen.
-    #     The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
-    _push_on_torch_dispatch_stack(mode)
+    else:
+        _push_on_torch_dispatch_stack(mode)


 def _pop_mode(k: Optional[DispatchKey] = None):
-    m = _pop_torch_dispatch_stack()
     if k is not None:
         from torch._ops import pop_mode_for_key
-        tmp = pop_mode_for_key(k)
-        assert m is tmp
-        return m
+        return pop_mode_for_key(k)
+    else:
+        return _pop_torch_dispatch_stack()


 @contextlib.contextmanager
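
Note for reviewers: a minimal usage sketch of the new flag (not part of the
patch itself). The toy function below is made up for illustration; it relies
only on make_fx(..., pre_dispatch=True) as added in this patch, and on the
behavior stated above that CompositeImplicit decompositions are skipped under
pre-dispatch tracing.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        w = torch.ones(4, 4)
        b = torch.ones(4)
        # linear is a CompositeImplicitAutograd op; since proxy_call skips
        # func.decompose() when pre_dispatch is set, the traced graph should
        # record aten.linear itself rather than its decomposition.
        return torch.nn.functional.linear(x, w, b)

    # pre_dispatch=True enables the PreDispatch dispatch key (plus the
    # PythonDispatcher), so ProxyTorchDispatchMode sees each op before
    # autograd, autocast, or any other dispatcher logic has run.
    gm = make_fx(f, pre_dispatch=True)(torch.ones(4, 4))
    print(gm.code)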
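
The same flag on the dynamo side, mirroring test_pre_dispatch_simple above
(again an illustrative sketch, not part of the patch). pre_dispatch=True is
only legal together with aten_graph=True; the assert added in export()
enforces this.

    import torch
    import torch._dynamo

    def f(x):
        y = torch.ones_like(x)
        return torch.matmul(x, y)

    # export() returns the traced GraphModule along with its guards; with
    # pre_dispatch=True the ATen graph is captured before the dispatcher runs.
    gm, guards = torch._dynamo.export(
        f,
        torch.randn(5, 5),
        aten_graph=True,
        pre_dispatch=True,
        tracing_mode="fake",
    )
    print(gm)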