The original PR broke internal (Meta-internal) builds/tests, so it is being reverted.

This reverts commit 5ed618132f.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103888
Approved by: https://github.com/albanD
This commit is contained in:
Brian Hirsh 2023-06-21 00:43:00 +00:00 committed by PyTorch MergeBot
parent 8b418f197c
commit c3c03e7cb8
12 changed files with 85 additions and 65 deletions

View file

@ -107,6 +107,15 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySe
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
}
// The PreDispatch key gets a no-op fallback that just redispatches.
// The main way this key is used is that we can register a mode to it from
// python (e.g. TorchProxyDispatchMode, for pre_dispatch tracing).
// Can't this be a fallthrough kernel, instead of a fallback that just no-ops
// and redispatches? Unfortunately, no: we need a real kernel that is not a
// fallthrough, in order for the PythonDispatcher to interpose on it.
// Alternatively, we could have hardcoded this kernel (in C++) to directly call
// into TorchProxyDispatchMode, but doing that in C++ is a pain, so it's done
// in python via the PythonDispatcher for convenience.
void preDispatchFallback(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet dispatch_keys,
    torch::jit::Stack* stack) {
  // Mask the key set down to everything strictly after PreDispatch, so the
  // boxed redispatch lands on the next-lower-priority dispatch key.
  const auto keys_after_pre_dispatch = c10::DispatchKeySet(
      c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch);
  op.redispatchBoxed(dispatch_keys & keys_after_pre_dispatch, stack);
}
} // anonymous namespace
@ -152,3 +161,7 @@ TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}
// Register preDispatchFallback as the boxed fallback for *all* operators
// (the "_" namespace) at the PreDispatch dispatch key: every op hitting
// PreDispatch with no more specific kernel just redispatches past the key.
TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
}

View file

@ -180,6 +180,9 @@ const char* toString(DispatchKey t) {
case DispatchKey::TESTING_ONLY_GenericMode:
return "TESTING_ONLY_GenericMode";
case DispatchKey::PreDispatch:
return "PreDispatch";
case DispatchKey::PythonDispatcher:
return "PythonDispatcher";
@ -300,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
{"PreDispatch", c10::DispatchKey::PreDispatch},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},

View file

@ -406,6 +406,12 @@ enum class DispatchKey : uint16_t {
// for a usage example
TESTING_ONLY_GenericMode,
// This key is used for pre-dispatch tracing in make_fx.
// It has lower priority than the PythonDispatcher key
// because we use the PythonDispatcher to intercept the key from python,
// and avoid having to implement it in C++.
PreDispatch,
// This is a bypass that allows you to skip running the C++ dispatcher
// entirely
PythonDispatcher,

View file

@ -1828,9 +1828,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
inp = torch.randn(6, 7)
self.assertEqual(gm(inp), f(inp))
# pre_autograd seems to violate new fake tensor invariants
@unittest.expectedFailure
def test_pre_autograd_simple(self):
def test_pre_dispatch_simple(self):
def f(x):
y = torch.ones_like(x)
return torch.matmul(x, y)
@ -1839,7 +1837,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
f,
torch.randn(5, 5),
aten_graph=True,
pre_autograd=True,
pre_dispatch=True,
tracing_mode="fake",
)

View file

@ -147,7 +147,7 @@ class TestGenericProxyTensor(TestCase):
r2 = f(*new_inps)
self.assertEqual(r1, r2)
def test_pre_autograd_mode_stack(self):
def test_pre_dispatch_mode_stack(self):
def f(a):
b = torch.ones(4, 4)
return torch.matmul(a, b)
@ -156,17 +156,28 @@ class TestGenericProxyTensor(TestCase):
# This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
# so our mode never sees it - it goes directly to the BackendSelect key.
inp = torch.ones(4, 4)
# Test that make_fx(pre_autograd=True) clears caches properly.
# Test that make_fx(pre_dispatch=True) clears caches properly.
from torch._dispatch.python import enable_python_dispatcher
with enable_python_dispatcher():
out1 = f(inp)
fx_g = make_fx(f, pre_autograd=True)(inp)
fx_g = make_fx(f, pre_dispatch=True)(inp)
self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
return matmul""")
def test_pre_dispatch_linear(self):
def f(a, b, c):
return torch.nn.functional.linear(a, b, c)
a = torch.ones(4, 4)
b = torch.ones(4, 4)
c = torch.ones(4)
fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
out1 = f(a, b, c)
out2 = fx_g(a, b, c)
self.assertEqual(out1, out2)
def test_make_fx_simple(self):
def f(x):

View file

@ -1191,6 +1191,11 @@ class _DisablePythonDispatcher:
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...
class _EnablePreDispatch:
    # Type stub for the C++-backed context manager (registered on torch._C
    # as "_EnablePreDispatch"): while entered, the PreDispatch dispatch key
    # is force-included via an IncludeDispatchKeyGuard. Behavior lives in C++;
    # this stub only declares the context-manager protocol.
    def __init__(self): ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
class _DisableFuncTorch:
def __init__(self): ...
def __enter__(self): ...

View file

@ -7,10 +7,11 @@ import itertools
from typing import Iterator
import torch._ops
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch
CROSSREF_FUNCTIONALIZE = False

View file

@ -791,7 +791,7 @@ def export(
f: Callable[..., Any],
*args,
aten_graph: bool = False,
pre_autograd: bool = False,
pre_dispatch: bool = False,
decomposition_table: Optional[
Dict[torch._ops.OpOverload, Callable[..., Any]]
] = None,
@ -812,9 +812,10 @@ def export(
aten_graph (bool): If True, exports a graph with ATen operators.
If False, exports a graph with Python operators. Default is False.
pre_autograd (bool): If True, exports a graph with ATen operators,
but before autograd has run. This can be useful if you want to apply further transformations
on a graph before running it through autograd.
pre_dispatch (bool): If True, exports a graph with ATen operators,
but before any logic in the PyTorch dispatcher has run.
This can be useful if you want to apply further transformations on a graph before running it
through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
This flag is only valid if aten_graph=True is set.
Default is False.
@ -848,8 +849,8 @@ def export(
assert (
aten_graph
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
if pre_autograd:
assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
if pre_dispatch:
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
f = innermost_fn(f)
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
original_signature = inspect.signature(call_to_inspect)
@ -1018,7 +1019,7 @@ def export(
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
pre_autograd=pre_autograd,
pre_dispatch=pre_dispatch,
_allow_fake_constant=_allow_fake_constant,
)(*example_fake_inputs)
except CondOpArgsMismatchError as e:

View file

@ -98,6 +98,11 @@ struct EnablePythonDispatcher {
c10::impl::PyInterpreter* old_;
};
// RAII helper exposed to Python as torch._C._EnablePreDispatch (see the
// py_context_manager registration in THPAutograd_initExtension below).
// While an instance is alive, IncludeDispatchKeyGuard force-includes
// DispatchKey::PreDispatch in the thread-local included dispatch key set;
// the guard's destructor restores the previous state.
struct EnablePreDispatch {
  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
  c10::impl::IncludeDispatchKeyGuard guard_;
};
} // namespace
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
@ -419,6 +424,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
_C_m, "_EnablePythonDispatcher");
py_context_manager<c10::impl::DisablePythonDispatcher>(
_C_m, "_DisablePythonDispatcher");
py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
py_context_manager_DEPRECATED<MultithreadingEnabled, bool>(
_C_m, "_MultithreadingEnabled");

View file

@ -561,6 +561,7 @@ void initDispatchBindings(PyObject* module) {
DEF_ONE(FuncTorchVmapMode)
DEF_ONE(FuncTorchGradWrapper)
DEF_ONE(PythonDispatcher)
DEF_ONE(PreDispatch)
DEF_ONE(Functionalize)
DEF_ONE(AutocastCPU)
DEF_ONE(AutocastXPU)

View file

@ -10,7 +10,7 @@ import torch
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dispatch.python import enable_python_dispatcher
from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
@ -246,16 +246,7 @@ def fetch_tensor_proxy(tracer):
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
@contextlib.contextmanager
def inside_mode(proxy_mode):
    """Flip ``proxy_mode.is_inside_mode`` on for the duration of the ``with``
    body, restoring the previous value on exit (even if the body raises)."""
    saved, proxy_mode.is_inside_mode = proxy_mode.is_inside_mode, True
    try:
        yield
    finally:
        proxy_mode.is_inside_mode = saved
def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
unrecognized_types = []
def can_handle_tensor(x):
@ -279,7 +270,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
return r
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if not pre_autograd:
if not pre_dispatch:
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
@ -358,12 +349,6 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
if func is torch.ops.aten.lift_fresh.default:
func = torch.ops.aten.lift_fresh_copy.default
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
# then we only want it to trace out proxies the first time that we hit an op.
if proxy_mode.is_inside_mode:
return func(*args, **kwargs)
proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
@ -379,8 +364,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
else:
args[0].proxy = proxy_out
with inside_mode(proxy_mode):
out = func(*args, **kwargs)
out = func(*args, **kwargs)
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@ -428,12 +412,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
else:
constant = None
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
# then we only want it to trace out proxies the first time that we hit an op.
# In particular, track_tensor_tree can call detach().
with inside_mode(proxy_mode):
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
return out
@ -488,9 +467,9 @@ def dispatch_trace(
return GraphModule(tracer.root, graph, name)
def wrap_key(f, tensors, tracer, pre_autograd: bool):
def wrap_key(f, tensors, tracer, pre_dispatch: bool):
flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
@functools.wraps(f)
def wrapped(*proxies):
@ -534,13 +513,13 @@ def set_original_aten_op(func):
class ProxyTorchDispatchMode(TorchDispatchMode):
def __init__(self, tracer, tracing_mode, pre_autograd=False, _allow_fake_constant=False):
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False):
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
super().__init__(dk)
self.tracer = tracer
self.tracing_mode = tracing_mode
self.enable_tracing = True
self.pre_autograd = pre_autograd
self.pre_dispatch = pre_dispatch
self._allow_fake_constant = _allow_fake_constant
self.is_inside_mode = False
self.sym_mode = ProxySymDispatchMode(tracer)
@ -575,7 +554,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
if func in [prim.device.default]:
return func(*args, **kwargs)
return proxy_call(self, func, self.pre_autograd, args, kwargs)
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
class ProxySymDispatchMode(SymDispatchMode):
@ -709,7 +688,7 @@ def make_fx(f,
tracing_mode="real",
_allow_non_fake_inputs=False,
*,
pre_autograd=False,
pre_dispatch=False,
_allow_fake_constant=False):
assert tracing_mode in ["real", "fake", "symbolic"]
@ -747,14 +726,17 @@ def make_fx(f,
raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
python_dispatcher_mode: Any = nullcontext()
pre_dispatch_mode: Any = nullcontext()
# pre-autograd tracing uses per-dispatch-key modes,
# which requires the python dispatcher
if tracing_mode == "symbolic" or pre_autograd:
if tracing_mode == "symbolic" or pre_dispatch:
python_dispatcher_mode = enable_python_dispatcher()
if pre_dispatch:
pre_dispatch_mode = enable_pre_dispatch()
proxy_mode = ProxyTorchDispatchMode(fx_tracer,
tracing_mode,
pre_autograd=pre_autograd,
pre_dispatch=pre_dispatch,
_allow_fake_constant=_allow_fake_constant)
arg_count = 0
@ -795,9 +777,9 @@ def make_fx(f,
# We also disable tracing by any other tensor proxy-based tracers except the current. The
# purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
# thus irrelevant to any external functional trace.
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
# TODO: kind of a bad way to do it, should maybe figure out a better way
if tracing_mode == "symbolic":

View file

@ -83,24 +83,16 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
for key in ks:
op._uncache_dispatch(key)
push_mode_for_key(k, mode)
# Note [Per-Dispatch-Key Modes Must Be Reentrant]
# The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
# (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
# kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
# to guarantee that every op will be seen by our mode.
# (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
# at both the Autograd key and the Python key, nothing bad should happen.
# The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
_push_on_torch_dispatch_stack(mode)
else:
_push_on_torch_dispatch_stack(mode)
def _pop_mode(k: Optional[DispatchKey] = None):
m = _pop_torch_dispatch_stack()
if k is not None:
from torch._ops import pop_mode_for_key
tmp = pop_mode_for_key(k)
assert m is tmp
return m
return pop_mode_for_key(k)
else:
return _pop_torch_dispatch_stack()
@contextlib.contextmanager