From d7fe3c4123ffe85fd95dadbbca741601a134d475 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Thu, 9 May 2024 14:03:37 -0700 Subject: [PATCH] [RELAND] Switch default behavoir of export IR to be predispatch (#125860) This PR switches export IR from aot-dispatch to pre-dispatch IR. **What is pre-dispatch IR and why should you care?** Currently the default IR returned by torch.export can contain only functional ATen operators after ALL pytorch dispatcher decompositions (for example, CompositeImplicitAutograd) run. In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result: You can train the pre-dispatch IR in eager mode as the IR contains necessary information for the autograd engine to automatically generate a backward graph. You can write sound graph transformations more easily as the IR is functional. Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is unique in this IR. If you want to get the core aten IR out of torch.export, you will need to: ``` ep = torch.export.export(M(), inputs) ep_for_core_aten = ep.run_decompositions() ``` Differential Revision: [D57172986](https://our.internmc.facebook.com/intern/diff/D57172986) Pull Request resolved: https://github.com/pytorch/pytorch/pull/125860 Approved by: https://github.com/zhxchen17 --- .../_tensor/experimental/test_tp_transform.py | 24 +-- test/export/test_experimental.py | 4 +- test/export/test_export.py | 36 ++-- test/export/test_safeguard.py | 157 ------------------ test/export/test_serialize.py | 2 +- test/onnx/test_fx_op_consistency.py | 15 +- .../dispatch_and_compile_graph.py | 6 +- .../traced_function_transforms.py | 4 +- torch/_subclasses/functional_tensor.py | 6 +- torch/export/__init__.py | 1 + torch/onnx/_internal/exporter.py | 2 +- torch/utils/_python_dispatch.py | 21 +-- 12 files changed, 66 insertions(+), 212 deletions(-) delete mode 100644 test/export/test_safeguard.py diff --git a/test/distributed/_tensor/experimental/test_tp_transform.py b/test/distributed/_tensor/experimental/test_tp_transform.py index 636870264f8..fc094150cc5 100644 --- a/test/distributed/_tensor/experimental/test_tp_transform.py +++ b/test/distributed/_tensor/experimental/test_tp_transform.py @@ -72,10 +72,10 @@ class TensorParallelTest(DTensorTestBase): inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),) with torch.no_grad(): res = model(*inputs) - exported_program = torch.export.export( - model, - inputs, - ) + exported_program = torch.export.export( + model, + inputs, + ).run_decompositions() tp_exported_program = tensor_parallel_transformation( exported_program, self.rank, @@ -110,10 +110,10 @@ class TensorParallelTest(DTensorTestBase): with torch.inference_mode(): res = model(*inputs) - exported_program = torch.export.export( - model, - inputs, - ) + exported_program = torch.export.export( + model, + inputs, + ).run_decompositions() tp_exported_program = tensor_parallel_transformation( exported_program, self.rank, @@ -146,10 +146,10 @@ class TensorParallelTest(DTensorTestBase): with torch.inference_mode(): res = model(*inputs) - exported_program = torch.export.export( - model, - inputs, - ) + exported_program = torch.export.export( + model, + inputs, + ).run_decompositions() tp_exported_program = tensor_parallel_transformation( exported_program, self.rank, diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 66f1a60ca97..2d7e88bfc11 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -47,9 +47,9 @@ def forward(self, b_submodule_buffer1, x): sin = torch.ops.aten.sin.default(x) strict_graph_0 = self.strict_graph_0 strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None - getitem = strict_mode[0]; strict_mode = None + getitem_2 = strict_mode[0]; strict_mode = None add = torch.ops.aten.add.Tensor(x, 3); x = None - return (getitem, add)""", + return (getitem_2, add)""", ) self.assertExpectedInline( diff --git a/test/export/test_export.py b/test/export/test_export.py index 7c711ffecb1..586fc403da9 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -576,7 +576,6 @@ class TestExport(TestCase): ) # Predispatch has different expected results - @testing.expectedFailureSerDerPreDispatch def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self): @@ -591,7 +590,7 @@ class TestExport(TestCase): x = x + x return x - ep1 = export(M1(), (torch.randn(3, 3),)) + ep1 = export(M1(), (torch.randn(3, 3),)).run_decompositions() expected_result = [ ("linear_1", "builtin_function_or_method.linear"), ("linear_1", "builtin_function_or_method.linear"), @@ -616,7 +615,9 @@ class TestExport(TestCase): x = torch.add(x, x) return x - ep2 = export(M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3))) + ep2 = export( + M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3)) + ).run_decompositions() expected_result = [ ("linear_1", "builtin_function_or_method.linear"), ("linear_1", "builtin_function_or_method.linear"), @@ -3061,8 +3062,6 @@ def forward(self, x): with self.assertRaisesRegex(ValueError, "Trying to flatten user inputs"): exported_program.module()(torch.rand(2, 3), torch.rand(2, 3)) - @testing.expectedFailureSerDerPreDispatch # linear shouldn't decompose - @testing.expectedFailurePreDispatchRunDecomp # no action needed here def test_export_decomps_simple(self): class M(torch.nn.Module): def __init__(self): @@ -3077,9 +3076,6 @@ def forward(self, x): ep = export(m, inp) state_dict = ep.state_dict - FileCheck().check_count("torch.ops.aten.t.default", 1, exactly=True).run( - ep.graph_module.code - ) self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) core_aten_ep = ep.run_decompositions() @@ -3786,8 +3782,8 @@ def forward(self, x): inp = (torch.randn(4, 4),) mod = Foo() - ep_strict = torch.export.export(mod, inp) - ep_non_strict = torch.export.export(mod, inp, strict=False) + ep_strict = torch.export.export(mod, inp).run_decompositions() + ep_non_strict = torch.export.export(mod, inp, strict=False).run_decompositions() gm_unflat_non_strict = unflatten(ep_non_strict) self.assertTrue(hasattr(gm_unflat_non_strict, "bar")) @@ -3804,8 +3800,8 @@ graph(): %x : [num_users=1] = placeholder[target=x] %weight : [num_users=1] = get_attr[target=weight] %bias : [num_users=1] = get_attr[target=bias] - %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {}) - %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %t), kwargs = {}) + %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%weight, [1, 0]), kwargs = {}) + %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %permute), kwargs = {}) return addmm""", ) @@ -3876,9 +3872,8 @@ graph(): %x : [num_users=1] = placeholder[target=x] %weight : [num_users=1] = get_attr[target=weight] %bias : [num_users=1] = get_attr[target=bias] - %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {}) - %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %t), kwargs = {}) - return addmm""", + %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %weight, %bias), kwargs = {}) + return linear""", ) self.assertExpectedInline( str(gm_unflat_non_strict.bar_different.leaf.linear.graph).strip(), @@ -3887,9 +3882,8 @@ graph(): %add_2 : [num_users=1] = placeholder[target=add_2] %weight : [num_users=1] = get_attr[target=weight] %bias : [num_users=1] = get_attr[target=bias] - %t_1 : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {}) - %addmm_1 : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %add_2, %t_1), kwargs = {}) - return addmm_1""", + %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%add_2, %weight, %bias), kwargs = {}) + return linear_1""", ) gm_flat_non_strict = ep_non_strict.module() @@ -4434,7 +4428,7 @@ def forward(self, x): inps = (torch.ones(5),) - ep = torch.export.export(M(), inps) + ep = torch.export.export(M(), inps).run_decompositions() self.assertExpectedInline( str(ep.graph_module.code.strip()), """\ @@ -4952,7 +4946,9 @@ class TestOneOffModelExportResult(TestCase): k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") - ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) + ep = torch.export.export( + ScaledDotProductAttention(), (q, k, v) + ).run_decompositions() self.assertExpectedInline( ep.graph_module.code.strip(), """\ diff --git a/test/export/test_safeguard.py b/test/export/test_safeguard.py deleted file mode 100644 index 1d4ffa030c7..00000000000 --- a/test/export/test_safeguard.py +++ /dev/null @@ -1,157 +0,0 @@ -# Owner(s): ["oncall: export"] -import unittest - -import torch -import torch._dynamo as torchdynamo -from torch.export import export -from torch.testing._internal.common_utils import run_tests, TestCase - - -@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") -class TestSafeguard(TestCase): - # If the autograd state doesn't change, dynamo eliminates autograd state manager op and later export can succeed. - # Otherwise, autograd can be preserved in the produced gragh, and export will fail. - def test_global_autograd(self): - class F1(torch.nn.Module): - def forward(self, a): - with torch.no_grad(): - b = a + a - return b - - f1 = F1() - - class F2(torch.nn.Module): - def forward(self, a): - with torch.enable_grad(): - b = a + a - return b - - f2 = F2() - - class F3(torch.nn.Module): - def forward(self, a): - with torch.set_grad_enabled(False): - b = a + a - return b - - f3 = F3() - - class F4(torch.nn.Module): - def forward(self, a): - with torch.set_grad_enabled(True): - b = a + a - return b - - f4 = F4() - - a = torch.randn(10) - with torch.no_grad(): - export(f1, (a,)) - export(f2, (a,)) - export(f3, (a,)) - export(f4, (a,)) - - with torch.enable_grad(): - export(f2, (a,)) - export(f4, (a,)) - - with self.assertRaisesRegex( - RuntimeError, "Encountered autograd state manager op.*" - ): - export(f1, (a,)) - - with self.assertRaisesRegex( - RuntimeError, "Encountered autograd state manager op.*" - ): - export(f3, (a,)) - - def test_tensor_autograd(self): - # dynamo errors when Tensor.requires_grad_ change the autograd state - class F1(torch.nn.Module): - def forward(self, a): - a.requires_grad_(True) - b = a + a - return b - - f1 = F1() - - # dynamo errors when Tensor.requires_grad_ change the autograd state - class F2(torch.nn.Module): - def forward(self, a): - a.requires_grad_(False) - b = a + a - return b - - f2 = F2() - - # dynamo always errors on Tensor.requires_grad - class F3(torch.nn.Module): - def forward(self, a): - a.requires_grad = False - b = a + a - return b - - f3 = F3() - - export(f1, (torch.randn(10, requires_grad=True),)) - export(f2, (torch.randn(10, requires_grad=False),)) - - with self.assertRaises(RuntimeError): - export(f1, (torch.randn(10, requires_grad=False),)) - with self.assertRaises(RuntimeError): - export(f2, (torch.randn(10, requires_grad=True),)) - with self.assertRaises(RuntimeError): - export(f3, (torch.randn(10, requires_grad=False),)) - - def test_global_autograd_exempt_predispatch(self): - class F1(torch.nn.Module): - def forward(self, a): - with torch.no_grad(): - b = a + a - return b - - f1 = F1() - - class F2(torch.nn.Module): - def forward(self, a): - with torch.enable_grad(): - b = a + a - return b - - f2 = F2() - - class F3(torch.nn.Module): - def forward(self, a): - with torch.set_grad_enabled(False): - b = a + a - return b - - f3 = F3() - - class F4(torch.nn.Module): - def forward(self, a): - with torch.set_grad_enabled(True): - b = a + a - return b - - f4 = F4() - - a = torch.randn(10) - - from torch.export._trace import _export - - with torch.no_grad(): - _export(f1, (a,), pre_dispatch=True) - _export(f2, (a,), pre_dispatch=True) - _export(f3, (a,), pre_dispatch=True) - _export(f4, (a,), pre_dispatch=True) - - with torch.enable_grad(): - _export(f1, (a,), pre_dispatch=True) - _export(f2, (a,), pre_dispatch=True) - _export(f3, (a,), pre_dispatch=True) - _export(f4, (a,), pre_dispatch=True) - - -if __name__ == "__main__": - run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 9a862984321..bfa709f7e74 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -183,7 +183,7 @@ class TestSerialize(TestCase): torch.ones([512]), torch.ones([512]), ), - ) + ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 004574d3e8c..0a1cabf19d4 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -1873,7 +1873,20 @@ def _run_test_output_match( == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM ): try: - model = torch.export.export(model, inputs) + # TODO (tugsbayasgalan) Migrate to pre-dispatch IR + # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32 + # has unexpected success, but don't know how to remove from xfail list + # BUG2: User output to_sparse is not in the correct order or is not found in the + # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328) + # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32 + # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4): + # [ShapeInferenceError] + # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0. + # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32 + from torch.export import _trace + + model = _trace._export(model, inputs, pre_dispatch=False) + except AssertionError as e: # NOTE: avoid fake_mode detection bug in torch.export.export pytest.xfail( diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 2c644e1a4c5..78dbb260332 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -11,10 +11,12 @@ import torch.utils._pytree as pytree import torch.utils.dlpack from torch import Tensor from torch._dispatch.python import enable_python_dispatcher + from torch._dynamo.utils import lazy_format_graph_code from torch._logging import getArtifactLogger, trace_structured from torch._subclasses.functional_tensor import FunctionalTensorMode from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._python_dispatch import _detect_infra_mode from .. import config from .functional_utils import ( @@ -116,9 +118,7 @@ def aot_dispatch_base_graph( ) fake = buffer.from_functional() # The fake tensor in turn is associated with a proxy node. - proxy_mode = torch._C._get_dispatch_mode( - torch._C._TorchDispatchModeKey.PROXY - ) + proxy_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.PROXY) assert proxy_mode is not None proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot( fake, proxy_mode.tracer diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index cce76e0108f..42c4b6ebd0f 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -376,8 +376,8 @@ def create_functionalized_fn( # Populate the current FunctionalTensorMode with the tokens per # operator. See Note [FunctionalTensorMode is Stateful] - functional_tensor_mode = ( - torch.utils._python_dispatch._detect_functional_mode() + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL ) assert functional_tensor_mode is not None for i, k in enumerate(meta.tokens.keys()): diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index fb2a81b8aeb..1762059eedf 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -8,7 +8,7 @@ import torch.utils._pytree as pytree from torch._C import _functionalization_reapply_views_tls as _reapply_views from torch._ops import _get_dispatch_mode_pre_dispatch from torch.utils._python_dispatch import ( - _detect_functional_mode, + _detect_infra_mode, _disable_infra_mode, return_and_correct_aliasing, TorchDispatchMode, @@ -185,7 +185,7 @@ class FunctionalTensor(torch.Tensor): # and otherwise the sym_size() call will go to the proxy mode before hitting # FunctionalTensor.__torch_dispatch__ - functional_mode = _detect_functional_mode() + functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) assert functional_mode is not None with functional_mode: @@ -219,7 +219,7 @@ class FunctionalTensor(torch.Tensor): return [elem.tolist() for elem in self.elem] def to(self, *args, **kwargs): - if _detect_functional_mode().export: + if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export: # If copy is specified as pos arg, it's always the second one. if len([arg for arg in args if isinstance(arg, bool)]) <= 1: return super().to(*args, **{**kwargs, "copy": True}) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 5eea452dbe1..5396646f675 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -178,6 +178,7 @@ def export( dynamic_shapes, strict=strict, preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=True, ) diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index 7831a362aef..5315e034ecd 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -814,7 +814,7 @@ class ONNXProgram: ... ) # Mutate buffer through in-place addition ... return output >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3)) - >>> exported_program = torch.export.export(CustomModule(), args=inputs) + >>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({}) >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs) >>> pprint.pprint(onnx_program.model_signature) ExportGraphSignature(input_specs=[InputSpec(kind=, diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index ec24f006a70..d22b550c6d1 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -104,24 +104,25 @@ def _get_current_dispatch_mode(): return None -def _detect_functional_mode(): +def _detect_infra_mode(key): + assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY] from torch._ops import _get_dispatch_mode_pre_dispatch - pre_dispatch_functional_mode = _get_dispatch_mode_pre_dispatch( - torch._C._TorchDispatchModeKey.FUNCTIONAL + pre_dispatch_mode = _get_dispatch_mode_pre_dispatch( + key ) - post_dispatch_functional_mode = torch._C._get_dispatch_mode( - torch._C._TorchDispatchModeKey.FUNCTIONAL + post_dispatch_mode = torch._C._get_dispatch_mode( + key ) - assert (pre_dispatch_functional_mode is None) or ( - post_dispatch_functional_mode is None + assert (pre_dispatch_mode is None) or ( + post_dispatch_mode is None ) - if pre_dispatch_functional_mode is None: - return post_dispatch_functional_mode + if pre_dispatch_mode is None: + return post_dispatch_mode - return pre_dispatch_functional_mode + return pre_dispatch_mode def _unset_infra_mode(key):