[RELAND] Switch default behavior of export IR to be predispatch (#125860)

This PR switches the default IR returned by torch.export from aot-dispatch IR to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently, the default IR returned by torch.export contains only functional ATen operators, produced after ALL PyTorch dispatcher decompositions (for example, CompositeImplicitAutograd) have run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:

- You can train the pre-dispatch IR in eager mode, as the IR contains the information the autograd engine needs to automatically generate a backward graph (see the sketch below).
- You can write sound graph transformations more easily, as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is the unique form in this IR.
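
As a minimal sketch of that training workflow (the toy module M, its inputs, and the variable names below are illustrative, not part of this PR; it assumes the exported parameters keep requires_grad=True, as they do for an eager nn.Linear):
```
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x).sum()

inputs = (torch.randn(2, 3),)

# torch.export now returns pre-dispatch IR by default.
ep = torch.export.export(M(), inputs)

# The unlifted module behaves like the eager module: autograd can build a
# backward graph through the captured ATen operators.
loss = ep.module()(*inputs)
loss.backward()
```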
If you want to get the core ATen IR out of torch.export, you will need to:
```
ep = torch.export.export(M(), inputs)  # pre-dispatch IR (the new default)
ep_for_core_aten = ep.run_decompositions()  # lower to the core ATen opset
```
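
Note that run_decompositions() also accepts an explicit decomposition table. A minimal sketch of the two common call patterns (the variable names here are illustrative):
```
# No argument: decompose using the default core ATen decomposition table.
core_ep = ep.run_decompositions()

# An empty table applies no additional decompositions but still lowers the
# program out of the pre-dispatch stage; the ONNX docstring change below
# uses run_decompositions({}) for this reason.
no_decomp_ep = ep.run_decompositions({})
```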

Differential Revision: [D57172986](https://our.internmc.facebook.com/intern/diff/D57172986)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125860
Approved by: https://github.com/zhxchen17
Author: Tugsbayasgalan Manlaibaatar
Date: 2024-05-09 14:03:37 -07:00 (committed by PyTorch MergeBot)
parent 4996a3fda3
commit d7fe3c4123
12 changed files with 66 additions and 212 deletions

View file

@@ -72,10 +72,10 @@ class TensorParallelTest(DTensorTestBase):
         inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
         with torch.no_grad():
             res = model(*inputs)
-        exported_program = torch.export.export(
-            model,
-            inputs,
-        )
+        exported_program = torch.export.export(
+            model,
+            inputs,
+        ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,
             self.rank,
@@ -110,10 +110,10 @@ class TensorParallelTest(DTensorTestBase):
         with torch.inference_mode():
             res = model(*inputs)
-        exported_program = torch.export.export(
-            model,
-            inputs,
-        )
+        exported_program = torch.export.export(
+            model,
+            inputs,
+        ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,
             self.rank,
@@ -146,10 +146,10 @@ class TensorParallelTest(DTensorTestBase):
         with torch.inference_mode():
             res = model(*inputs)
-        exported_program = torch.export.export(
-            model,
-            inputs,
-        )
+        exported_program = torch.export.export(
+            model,
+            inputs,
+        ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,
             self.rank,

View file

@@ -47,9 +47,9 @@ def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
-    getitem = strict_mode[0]; strict_mode = None
+    getitem_2 = strict_mode[0]; strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3); x = None
-    return (getitem, add)""",
+    return (getitem_2, add)""",
        )
        self.assertExpectedInline(

View file

@@ -576,7 +576,6 @@ class TestExport(TestCase):
        )

    # Predispatch has different expected results
-    @testing.expectedFailureSerDerPreDispatch
    def test_torch_fn(self):
        class M1(torch.nn.Module):
            def __init__(self):
@@ -591,7 +590,7 @@ class TestExport(TestCase):
                x = x + x
                return x

-        ep1 = export(M1(), (torch.randn(3, 3),))
+        ep1 = export(M1(), (torch.randn(3, 3),)).run_decompositions()
        expected_result = [
            ("linear_1", "builtin_function_or_method.linear"),
            ("linear_1", "builtin_function_or_method.linear"),
@@ -616,7 +615,9 @@ class TestExport(TestCase):
                x = torch.add(x, x)
                return x

-        ep2 = export(M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3)))
+        ep2 = export(
+            M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3))
+        ).run_decompositions()
        expected_result = [
            ("linear_1", "builtin_function_or_method.linear"),
            ("linear_1", "builtin_function_or_method.linear"),
@@ -3061,8 +3062,6 @@ def forward(self, x):
        with self.assertRaisesRegex(ValueError, "Trying to flatten user inputs"):
            exported_program.module()(torch.rand(2, 3), torch.rand(2, 3))

-    @testing.expectedFailureSerDerPreDispatch # linear shouldn't decompose
-    @testing.expectedFailurePreDispatchRunDecomp # no action needed here
    def test_export_decomps_simple(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -3077,9 +3076,6 @@ def forward(self, x):
        ep = export(m, inp)
        state_dict = ep.state_dict

-        FileCheck().check_count("torch.ops.aten.t.default", 1, exactly=True).run(
-            ep.graph_module.code
-        )
        self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))

        core_aten_ep = ep.run_decompositions()
@@ -3786,8 +3782,8 @@ def forward(self, x):
        inp = (torch.randn(4, 4),)
        mod = Foo()
-        ep_strict = torch.export.export(mod, inp)
-        ep_non_strict = torch.export.export(mod, inp, strict=False)
+        ep_strict = torch.export.export(mod, inp).run_decompositions()
+        ep_non_strict = torch.export.export(mod, inp, strict=False).run_decompositions()

        gm_unflat_non_strict = unflatten(ep_non_strict)
        self.assertTrue(hasattr(gm_unflat_non_strict, "bar"))
@@ -3804,8 +3800,8 @@ graph():
    %x : [num_users=1] = placeholder[target=x]
    %weight : [num_users=1] = get_attr[target=weight]
    %bias : [num_users=1] = get_attr[target=bias]
-    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
-    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %t), kwargs = {})
+    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%weight, [1, 0]), kwargs = {})
+    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %permute), kwargs = {})
    return addmm""",
        )
@@ -3876,9 +3872,8 @@ graph():
    %x : [num_users=1] = placeholder[target=x]
    %weight : [num_users=1] = get_attr[target=weight]
    %bias : [num_users=1] = get_attr[target=bias]
-    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
-    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %t), kwargs = {})
-    return addmm""",
+    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %weight, %bias), kwargs = {})
+    return linear""",
        )
        self.assertExpectedInline(
            str(gm_unflat_non_strict.bar_different.leaf.linear.graph).strip(),
@@ -3887,9 +3882,8 @@ graph():
    %add_2 : [num_users=1] = placeholder[target=add_2]
    %weight : [num_users=1] = get_attr[target=weight]
    %bias : [num_users=1] = get_attr[target=bias]
-    %t_1 : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
-    %addmm_1 : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %add_2, %t_1), kwargs = {})
-    return addmm_1""",
+    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%add_2, %weight, %bias), kwargs = {})
+    return linear_1""",
        )

        gm_flat_non_strict = ep_non_strict.module()
@@ -4434,7 +4428,7 @@ def forward(self, x):
        inps = (torch.ones(5),)

-        ep = torch.export.export(M(), inps)
+        ep = torch.export.export(M(), inps).run_decompositions()
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
@@ -4952,7 +4946,9 @@ class TestOneOffModelExportResult(TestCase):
        k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")

-        ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))
+        ep = torch.export.export(
+            ScaledDotProductAttention(), (q, k, v)
+        ).run_decompositions()
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\

View file

@@ -1,157 +0,0 @@
-# Owner(s): ["oncall: export"]
-import unittest
-
-import torch
-import torch._dynamo as torchdynamo
-from torch.export import export
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-
-@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
-class TestSafeguard(TestCase):
-    # If the autograd state doesn't change, dynamo eliminates autograd state manager op and later export can succeed.
-    # Otherwise, autograd can be preserved in the produced gragh, and export will fail.
-    def test_global_autograd(self):
-        class F1(torch.nn.Module):
-            def forward(self, a):
-                with torch.no_grad():
-                    b = a + a
-                return b
-
-        f1 = F1()
-
-        class F2(torch.nn.Module):
-            def forward(self, a):
-                with torch.enable_grad():
-                    b = a + a
-                return b
-
-        f2 = F2()
-
-        class F3(torch.nn.Module):
-            def forward(self, a):
-                with torch.set_grad_enabled(False):
-                    b = a + a
-                return b
-
-        f3 = F3()
-
-        class F4(torch.nn.Module):
-            def forward(self, a):
-                with torch.set_grad_enabled(True):
-                    b = a + a
-                return b
-
-        f4 = F4()
-
-        a = torch.randn(10)
-        with torch.no_grad():
-            export(f1, (a,))
-            export(f2, (a,))
-            export(f3, (a,))
-            export(f4, (a,))
-
-        with torch.enable_grad():
-            export(f2, (a,))
-            export(f4, (a,))
-
-        with self.assertRaisesRegex(
-            RuntimeError, "Encountered autograd state manager op.*"
-        ):
-            export(f1, (a,))
-
-        with self.assertRaisesRegex(
-            RuntimeError, "Encountered autograd state manager op.*"
-        ):
-            export(f3, (a,))
-
-    def test_tensor_autograd(self):
-        # dynamo errors when Tensor.requires_grad_ change the autograd state
-        class F1(torch.nn.Module):
-            def forward(self, a):
-                a.requires_grad_(True)
-                b = a + a
-                return b
-
-        f1 = F1()
-
-        # dynamo errors when Tensor.requires_grad_ change the autograd state
-        class F2(torch.nn.Module):
-            def forward(self, a):
-                a.requires_grad_(False)
-                b = a + a
-                return b
-
-        f2 = F2()
-
-        # dynamo always errors on Tensor.requires_grad
-        class F3(torch.nn.Module):
-            def forward(self, a):
-                a.requires_grad = False
-                b = a + a
-                return b
-
-        f3 = F3()
-
-        export(f1, (torch.randn(10, requires_grad=True),))
-        export(f2, (torch.randn(10, requires_grad=False),))
-
-        with self.assertRaises(RuntimeError):
-            export(f1, (torch.randn(10, requires_grad=False),))
-        with self.assertRaises(RuntimeError):
-            export(f2, (torch.randn(10, requires_grad=True),))
-        with self.assertRaises(RuntimeError):
-            export(f3, (torch.randn(10, requires_grad=False),))
-
-    def test_global_autograd_exempt_predispatch(self):
-        class F1(torch.nn.Module):
-            def forward(self, a):
-                with torch.no_grad():
-                    b = a + a
-                return b
-
-        f1 = F1()
-
-        class F2(torch.nn.Module):
-            def forward(self, a):
-                with torch.enable_grad():
-                    b = a + a
-                return b
-
-        f2 = F2()
-
-        class F3(torch.nn.Module):
-            def forward(self, a):
-                with torch.set_grad_enabled(False):
-                    b = a + a
-                return b
-
-        f3 = F3()
-
-        class F4(torch.nn.Module):
-            def forward(self, a):
-                with torch.set_grad_enabled(True):
-                    b = a + a
-                return b
-
-        f4 = F4()
-
-        a = torch.randn(10)
-        from torch.export._trace import _export
-
-        with torch.no_grad():
-            _export(f1, (a,), pre_dispatch=True)
-            _export(f2, (a,), pre_dispatch=True)
-            _export(f3, (a,), pre_dispatch=True)
-            _export(f4, (a,), pre_dispatch=True)
-
-        with torch.enable_grad():
-            _export(f1, (a,), pre_dispatch=True)
-            _export(f2, (a,), pre_dispatch=True)
-            _export(f3, (a,), pre_dispatch=True)
-            _export(f4, (a,), pre_dispatch=True)
-
-
-if __name__ == "__main__":
-    run_tests()

View file

@@ -183,7 +183,7 @@ class TestSerialize(TestCase):
                torch.ones([512]),
                torch.ones([512]),
            ),
-        )
+        ).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]

View file

@@ -1873,7 +1873,20 @@ def _run_test_output_match(
        == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
    ):
        try:
-            model = torch.export.export(model, inputs)
+            # TODO (tugsbayasgalan) Migrate to pre-dispatch IR
+            # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32
+            # has unexpected success, but don't know how to remove from xfail list
+            # BUG2: User output to_sparse is not in the correct order or is not found in the
+            # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328)
+            # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32
+            # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4):
+            # [ShapeInferenceError]
+            # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0.
+            # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32
+            from torch.export import _trace
+
+            model = _trace._export(model, inputs, pre_dispatch=False)
        except AssertionError as e:
            # NOTE: avoid fake_mode detection bug in torch.export.export
            pytest.xfail(

View file

@@ -11,10 +11,12 @@ import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
+from torch._dynamo.utils import lazy_format_graph_code
+from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
+from torch.utils._python_dispatch import _detect_infra_mode

from .. import config
from .functional_utils import (
@@ -116,9 +118,7 @@ def aot_dispatch_base_graph(
        )
        fake = buffer.from_functional()
        # The fake tensor in turn is associated with a proxy node.
-        proxy_mode = torch._C._get_dispatch_mode(
-            torch._C._TorchDispatchModeKey.PROXY
-        )
+        proxy_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
        assert proxy_mode is not None
        proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot(
            fake, proxy_mode.tracer

View file

@@ -376,8 +376,8 @@ def create_functionalized_fn(
        # Populate the current FunctionalTensorMode with the tokens per
        # operator. See Note [FunctionalTensorMode is Stateful]
-        functional_tensor_mode = (
-            torch.utils._python_dispatch._detect_functional_mode()
+        functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
+            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        assert functional_tensor_mode is not None
        for i, k in enumerate(meta.tokens.keys()):

View file

@@ -8,7 +8,7 @@ import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch.utils._python_dispatch import (
-    _detect_functional_mode,
+    _detect_infra_mode,
    _disable_infra_mode,
    return_and_correct_aliasing,
    TorchDispatchMode,
@@ -185,7 +185,7 @@ class FunctionalTensor(torch.Tensor):
        # and otherwise the sym_size() call will go to the proxy mode before hitting
        # FunctionalTensor.__torch_dispatch__
-        functional_mode = _detect_functional_mode()
+        functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        assert functional_mode is not None
        with functional_mode:
@@ -219,7 +219,7 @@ class FunctionalTensor(torch.Tensor):
        return [elem.tolist() for elem in self.elem]

    def to(self, *args, **kwargs):
-        if _detect_functional_mode().export:
+        if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
            # If copy is specified as pos arg, it's always the second one.
            if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
                return super().to(*args, **{**kwargs, "copy": True})

View file

@@ -178,6 +178,7 @@ def export(
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
+        pre_dispatch=True,
    )

View file

@@ -814,7 +814,7 @@ class ONNXProgram:
        ...     ) # Mutate buffer through in-place addition
        ...     return output
        >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
-        >>> exported_program = torch.export.export(CustomModule(), args=inputs)
+        >>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({})
        >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
        >>> pprint.pprint(onnx_program.model_signature)
        ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,

View file

@@ -104,24 +104,25 @@ def _get_current_dispatch_mode():
    return None


-def _detect_functional_mode():
+def _detect_infra_mode(key):
+    assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
    from torch._ops import _get_dispatch_mode_pre_dispatch

-    pre_dispatch_functional_mode = _get_dispatch_mode_pre_dispatch(
-        torch._C._TorchDispatchModeKey.FUNCTIONAL
+    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(
+        key
    )
-    post_dispatch_functional_mode = torch._C._get_dispatch_mode(
-        torch._C._TorchDispatchModeKey.FUNCTIONAL
+    post_dispatch_mode = torch._C._get_dispatch_mode(
+        key
    )

-    assert (pre_dispatch_functional_mode is None) or (
-        post_dispatch_functional_mode is None
+    assert (pre_dispatch_mode is None) or (
+        post_dispatch_mode is None
    )

-    if pre_dispatch_functional_mode is None:
-        return post_dispatch_functional_mode
+    if pre_dispatch_mode is None:
+        return post_dispatch_mode

-    return pre_dispatch_functional_mode
+    return pre_dispatch_mode


def _unset_infra_mode(key):