mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Original PR broke internal
This reverts commit 5ed618132f.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103888
Approved by: https://github.com/albanD
This commit is contained in:
parent
8b418f197c
commit
c3c03e7cb8
12 changed files with 85 additions and 65 deletions
|
|
@ -107,6 +107,15 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySe
|
|||
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
|
||||
}
|
||||
|
||||
// The PreDispatch key gets a no-op fallback that just redispatches.
|
||||
// The main way this key is used is that we can register a mode to it from python (e.g. ProxyTorchDispatchMode, for pre_dispatch tracing)
|
||||
// Can't this be a fallthrough kernel, instead of a fallback that just no-ops and redispatches?
|
||||
// Unfortunately, no: we need a real kernel that is not a fallthrough, in order for the PythonDispatcher to interpose on it.
|
||||
// Alternatively, we could have hardcoded this kernel (in C++) to directly call into ProxyTorchDispatchMode.
|
||||
// Doing that in C++ is a pain though, so it's done in python using the PythonDispatcher for convenience.
|
||||
// Boxed fallback kernel for the PreDispatch key: narrows dispatch_keys to the
// FULL_AFTER set for PreDispatch and redispatches op with the same stack.
void preDispatchFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
|
||||
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PreDispatch), stack);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
|
@ -152,3 +161,7 @@ TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
|
|||
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
|
||||
}
|
||||
|
||||
// Registers preDispatchFallback as the boxed fallback kernel for the PreDispatch key.
TORCH_LIBRARY_IMPL(_, PreDispatch, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&preDispatchFallback>());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,6 +180,9 @@ const char* toString(DispatchKey t) {
|
|||
case DispatchKey::TESTING_ONLY_GenericMode:
|
||||
return "TESTING_ONLY_GenericMode";
|
||||
|
||||
case DispatchKey::PreDispatch:
|
||||
return "PreDispatch";
|
||||
|
||||
case DispatchKey::PythonDispatcher:
|
||||
return "PythonDispatcher";
|
||||
|
||||
|
|
@ -300,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
|
|||
c10::DispatchKey::TESTING_ONLY_GenericWrapper},
|
||||
{"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
|
||||
{"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
|
||||
{"PreDispatch", c10::DispatchKey::PreDispatch},
|
||||
|
||||
{"CPU", c10::DispatchKey::CPU},
|
||||
{"CUDA", c10::DispatchKey::CUDA},
|
||||
|
|
|
|||
|
|
@ -406,6 +406,12 @@ enum class DispatchKey : uint16_t {
|
|||
// for a usage example
|
||||
TESTING_ONLY_GenericMode,
|
||||
|
||||
// This key is used for pre-dispatch tracing in make_fx.
|
||||
// It has lower priority than the PythonDispatcher key
|
||||
// because we use the PythonDispatcher to intercept the key from python,
|
||||
// and avoid having to implement it in C++.
|
||||
PreDispatch,
|
||||
|
||||
// This is a bypass that allows you to skip running the C++ dispatcher
|
||||
// entirely
|
||||
PythonDispatcher,
|
||||
|
|
|
|||
|
|
@ -1828,9 +1828,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
inp = torch.randn(6, 7)
|
||||
self.assertEqual(gm(inp), f(inp))
|
||||
|
||||
# pre_autograd seems to violate new fake tensor invariants
|
||||
@unittest.expectedFailure
|
||||
def test_pre_autograd_simple(self):
|
||||
def test_pre_dispatch_simple(self):
|
||||
def f(x):
|
||||
y = torch.ones_like(x)
|
||||
return torch.matmul(x, y)
|
||||
|
|
@ -1839,7 +1837,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
f,
|
||||
torch.randn(5, 5),
|
||||
aten_graph=True,
|
||||
pre_autograd=True,
|
||||
pre_dispatch=True,
|
||||
tracing_mode="fake",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ class TestGenericProxyTensor(TestCase):
|
|||
r2 = f(*new_inps)
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
def test_pre_autograd_mode_stack(self):
|
||||
def test_pre_dispatch_mode_stack(self):
|
||||
def f(a):
|
||||
b = torch.ones(4, 4)
|
||||
return torch.matmul(a, b)
|
||||
|
|
@ -156,17 +156,28 @@ class TestGenericProxyTensor(TestCase):
|
|||
# This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
|
||||
# so our mode never sees it - it goes directly to the BackendSelect key.
|
||||
inp = torch.ones(4, 4)
|
||||
# Test that make_fx(pre_autograd=True) clears caches properly.
|
||||
# Test that make_fx(pre_dispatch=True) clears caches properly.
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
with enable_python_dispatcher():
|
||||
out1 = f(inp)
|
||||
fx_g = make_fx(f, pre_autograd=True)(inp)
|
||||
fx_g = make_fx(f, pre_dispatch=True)(inp)
|
||||
self.assertExpectedInline(fx_g.code.strip(), """\
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
|
||||
matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
|
||||
return matmul""")
|
||||
|
||||
# Traces torch.nn.functional.linear with make_fx(pre_dispatch=True) and checks
# that the traced graph module returns the same result as the eager function.
def test_pre_dispatch_linear(self):
|
||||
def f(a, b, c):
|
||||
return torch.nn.functional.linear(a, b, c)
|
||||
a = torch.ones(4, 4)
|
||||
b = torch.ones(4, 4)
|
||||
c = torch.ones(4)
|
||||
fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
|
||||
out1 = f(a, b, c)
|
||||
out2 = fx_g(a, b, c)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
|
||||
def test_make_fx_simple(self):
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -1191,6 +1191,11 @@ class _DisablePythonDispatcher:
|
|||
def __enter__(self): ...
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
# Type stub for the context manager that the C++ extension registers under the
# name "_EnablePreDispatch"; takes no constructor arguments.
class _EnablePreDispatch:
|
||||
def __init__(self): ...
|
||||
def __enter__(self): ...
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
class _DisableFuncTorch:
|
||||
def __init__(self): ...
|
||||
def __enter__(self): ...
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ import itertools
|
|||
from typing import Iterator
|
||||
import torch._ops
|
||||
|
||||
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
|
||||
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
|
||||
|
||||
no_python_dispatcher = torch._C._DisablePythonDispatcher
|
||||
enable_python_dispatcher = torch._C._EnablePythonDispatcher
|
||||
enable_pre_dispatch = torch._C._EnablePreDispatch
|
||||
|
||||
CROSSREF_FUNCTIONALIZE = False
|
||||
|
||||
|
|
|
|||
|
|
@ -791,7 +791,7 @@ def export(
|
|||
f: Callable[..., Any],
|
||||
*args,
|
||||
aten_graph: bool = False,
|
||||
pre_autograd: bool = False,
|
||||
pre_dispatch: bool = False,
|
||||
decomposition_table: Optional[
|
||||
Dict[torch._ops.OpOverload, Callable[..., Any]]
|
||||
] = None,
|
||||
|
|
@ -812,9 +812,10 @@ def export(
|
|||
aten_graph (bool): If True, exports a graph with ATen operators.
|
||||
If False, exports a graph with Python operators. Default is False.
|
||||
|
||||
pre_autograd (bool): If True, exports a graph with ATen operators,
|
||||
but before autograd has run. This can be useful if you want to apply further transformations
|
||||
on a graph before running it through autograd.
|
||||
pre_dispatch (bool): If True, exports a graph with ATen operators,
|
||||
but before any logic in the PyTorch dispatcher has run.
|
||||
This can be useful if you want to apply further transformations on a graph before running it
|
||||
through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
|
||||
This flag is only valid if aten_graph=True is set.
|
||||
Default is False.
|
||||
|
||||
|
|
@ -848,8 +849,8 @@ def export(
|
|||
assert (
|
||||
aten_graph
|
||||
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
|
||||
if pre_autograd:
|
||||
assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
|
||||
if pre_dispatch:
|
||||
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
|
||||
f = innermost_fn(f)
|
||||
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
|
||||
original_signature = inspect.signature(call_to_inspect)
|
||||
|
|
@ -1018,7 +1019,7 @@ def export(
|
|||
decomposition_table=decomposition_table,
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=True,
|
||||
pre_autograd=pre_autograd,
|
||||
pre_dispatch=pre_dispatch,
|
||||
_allow_fake_constant=_allow_fake_constant,
|
||||
)(*example_fake_inputs)
|
||||
except CondOpArgsMismatchError as e:
|
||||
|
|
|
|||
|
|
@ -98,6 +98,11 @@ struct EnablePythonDispatcher {
|
|||
c10::impl::PyInterpreter* old_;
|
||||
};
|
||||
|
||||
// Holds an IncludeDispatchKeyGuard for DispatchKey::PreDispatch for as long as
// the object lives; bound to Python as _EnablePreDispatch.
struct EnablePreDispatch {
|
||||
EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
|
||||
c10::impl::IncludeDispatchKeyGuard guard_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
|
|
@ -419,6 +424,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
|||
_C_m, "_EnablePythonDispatcher");
|
||||
py_context_manager<c10::impl::DisablePythonDispatcher>(
|
||||
_C_m, "_DisablePythonDispatcher");
|
||||
py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
|
||||
py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
|
||||
py_context_manager_DEPRECATED<MultithreadingEnabled, bool>(
|
||||
_C_m, "_MultithreadingEnabled");
|
||||
|
|
|
|||
|
|
@ -561,6 +561,7 @@ void initDispatchBindings(PyObject* module) {
|
|||
DEF_ONE(FuncTorchVmapMode)
|
||||
DEF_ONE(FuncTorchGradWrapper)
|
||||
DEF_ONE(PythonDispatcher)
|
||||
DEF_ONE(PreDispatch)
|
||||
DEF_ONE(Functionalize)
|
||||
DEF_ONE(AutocastCPU)
|
||||
DEF_ONE(AutocastXPU)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch.fx import Tracer, GraphModule
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
|
||||
import torch.fx as fx
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
|
@ -246,16 +246,7 @@ def fetch_tensor_proxy(tracer):
|
|||
|
||||
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def inside_mode(proxy_mode):
|
||||
old = proxy_mode.is_inside_mode
|
||||
proxy_mode.is_inside_mode = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
proxy_mode.is_inside_mode = old
|
||||
|
||||
def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
|
||||
def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
|
||||
unrecognized_types = []
|
||||
|
||||
def can_handle_tensor(x):
|
||||
|
|
@ -279,7 +270,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
|
|||
return r
|
||||
|
||||
# For pre-dispatch tracing, we do not want to run CompositeImplicit decomps.
|
||||
if not pre_autograd:
|
||||
if not pre_dispatch:
|
||||
with proxy_mode:
|
||||
r = func.decompose(*args, **kwargs)
|
||||
if r is not NotImplemented:
|
||||
|
|
@ -358,12 +349,6 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
|
|||
if func is torch.ops.aten.lift_fresh.default:
|
||||
func = torch.ops.aten.lift_fresh_copy.default
|
||||
|
||||
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
|
||||
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
|
||||
# then we only want it to trace out proxies the first time that we hit an op.
|
||||
if proxy_mode.is_inside_mode:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
|
||||
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
|
||||
|
||||
|
|
@ -379,8 +364,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
|
|||
else:
|
||||
args[0].proxy = proxy_out
|
||||
|
||||
with inside_mode(proxy_mode):
|
||||
out = func(*args, **kwargs)
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# In some circumstances, we will be tracing in a situation where a tensor
|
||||
# is *statically* known to be a constant (currently, this only happens if
|
||||
|
|
@ -428,12 +412,7 @@ def proxy_call(proxy_mode, func, pre_autograd, args, kwargs):
|
|||
else:
|
||||
constant = None
|
||||
|
||||
# See Note [Per-Dispatch-Key Modes Must Be Reentrant]
|
||||
# If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
|
||||
# then we only want it to trace out proxies the first time that we hit an op.
|
||||
# In particular, track_tensor_tree can call detach().
|
||||
with inside_mode(proxy_mode):
|
||||
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
|
||||
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -488,9 +467,9 @@ def dispatch_trace(
|
|||
return GraphModule(tracer.root, graph, name)
|
||||
|
||||
|
||||
def wrap_key(f, tensors, tracer, pre_autograd: bool):
|
||||
def wrap_key(f, tensors, tracer, pre_dispatch: bool):
|
||||
flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
|
||||
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
|
||||
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*proxies):
|
||||
|
|
@ -534,13 +513,13 @@ def set_original_aten_op(func):
|
|||
|
||||
|
||||
class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
def __init__(self, tracer, tracing_mode, pre_autograd=False, _allow_fake_constant=False):
|
||||
dk = torch._C.DispatchKey.AutogradFunctionality if pre_autograd else None
|
||||
def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False):
|
||||
dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
|
||||
super().__init__(dk)
|
||||
self.tracer = tracer
|
||||
self.tracing_mode = tracing_mode
|
||||
self.enable_tracing = True
|
||||
self.pre_autograd = pre_autograd
|
||||
self.pre_dispatch = pre_dispatch
|
||||
self._allow_fake_constant = _allow_fake_constant
|
||||
self.is_inside_mode = False
|
||||
self.sym_mode = ProxySymDispatchMode(tracer)
|
||||
|
|
@ -575,7 +554,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||
if func in [prim.device.default]:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return proxy_call(self, func, self.pre_autograd, args, kwargs)
|
||||
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
|
||||
|
||||
|
||||
class ProxySymDispatchMode(SymDispatchMode):
|
||||
|
|
@ -709,7 +688,7 @@ def make_fx(f,
|
|||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=False,
|
||||
*,
|
||||
pre_autograd=False,
|
||||
pre_dispatch=False,
|
||||
_allow_fake_constant=False):
|
||||
assert tracing_mode in ["real", "fake", "symbolic"]
|
||||
|
||||
|
|
@ -747,14 +726,17 @@ def make_fx(f,
|
|||
raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
|
||||
|
||||
python_dispatcher_mode: Any = nullcontext()
|
||||
pre_dispatch_mode: Any = nullcontext()
|
||||
# pre-dispatch tracing uses per-dispatch-key modes,
|
||||
# which requires the python dispatcher
|
||||
if tracing_mode == "symbolic" or pre_autograd:
|
||||
if tracing_mode == "symbolic" or pre_dispatch:
|
||||
python_dispatcher_mode = enable_python_dispatcher()
|
||||
if pre_dispatch:
|
||||
pre_dispatch_mode = enable_pre_dispatch()
|
||||
|
||||
proxy_mode = ProxyTorchDispatchMode(fx_tracer,
|
||||
tracing_mode,
|
||||
pre_autograd=pre_autograd,
|
||||
pre_dispatch=pre_dispatch,
|
||||
_allow_fake_constant=_allow_fake_constant)
|
||||
|
||||
arg_count = 0
|
||||
|
|
@ -795,9 +777,9 @@ def make_fx(f,
|
|||
# We also disable tracing by any other tensor proxy-based tracers except the current. The
|
||||
# purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
|
||||
# thus irrelevant to any external functional trace.
|
||||
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
|
||||
with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, \
|
||||
sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
|
||||
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
|
||||
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
|
||||
|
||||
# TODO: kind of a bad way to do it, should maybe figure out a better way
|
||||
if tracing_mode == "symbolic":
|
||||
|
|
|
|||
|
|
@ -83,24 +83,16 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
|
|||
for key in ks:
|
||||
op._uncache_dispatch(key)
|
||||
push_mode_for_key(k, mode)
|
||||
# Note [Per-Dispatch-Key Modes Must Be Reentrant]
|
||||
# The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
|
||||
# (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
|
||||
# kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
|
||||
# to guarantee that every op will be seen by our mode.
|
||||
# (2) We expect the mode that you push to handle being re-entrant: If we end up invoking the mode
|
||||
# at both the Autograd key and the Python key, nothing bad should happen.
|
||||
# The main use case for this is pre-autograd tracing with TorchProxyDispatchMode.
|
||||
_push_on_torch_dispatch_stack(mode)
|
||||
else:
|
||||
_push_on_torch_dispatch_stack(mode)
|
||||
|
||||
|
||||
def _pop_mode(k: Optional[DispatchKey] = None):
|
||||
m = _pop_torch_dispatch_stack()
|
||||
if k is not None:
|
||||
from torch._ops import pop_mode_for_key
|
||||
tmp = pop_mode_for_key(k)
|
||||
assert m is tmp
|
||||
return m
|
||||
return pop_mode_for_key(k)
|
||||
else:
|
||||
return _pop_torch_dispatch_stack()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
|
|||
Loading…
Reference in a new issue