From 2dbadd1eae7b0ffcc6884df53a2003e2917fe977 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 26 Jul 2023 16:21:29 +0000 Subject: [PATCH] [export] Remove experimental runtime assertion configs from export API. (#105043) Test Plan: CI Differential Revision: D47390794 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105043 Approved by: https://github.com/larryliu0820 --- test/export/test_passes.py | 271 +----------------- test/export/test_serialize.py | 59 ++-- test/export/test_upgrade.py | 2 +- test/test_out_dtype_op.py | 1 - torch/_export/__init__.py | 13 +- torch/_export/exported_program.py | 34 ++- ...runtime_assertions_for_constraints_pass.py | 197 +++++++------ torch/_export/serde/upgrade.py | 2 +- 8 files changed, 170 insertions(+), 409 deletions(-) diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 401cfad4d47..8676164b61d 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -14,7 +14,6 @@ from torch.testing import FileCheck from torch._dynamo.eval_frame import is_dynamo_supported from torch._export import export, dynamic_dim from torch._export.constraints import constrain_as_value, constrain_as_size -from torch._export.exported_program import ExportGraphSignature from torch._export.passes import ( ReplaceViewOpsWithViewCopyOpsPass, ) @@ -27,8 +26,7 @@ from torch._export.passes.functionalize_side_effectful_ops_pass import ( ) from functorch.experimental.control_flow import cond from torch.fx.passes.operator_support import OperatorSupport -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition -from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.passes.infra.partitioner import Partition from torch.utils._pytree import tree_flatten @@ -94,12 +92,6 @@ class TestPasses(TestCase): ep = export(M(), (x,), constraints=[dynamic_dim(x, 1) >= 2, dynamic_dim(x, 1) <= 6]) - num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) - num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - - self.assertEqual(num_assert, 3) - self.assertEqual(num_scalar_tensor, 3) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(2, 7, 3)) @@ -125,12 +117,6 @@ class TestPasses(TestCase): ep = export(M(), (x, y), constraints=constraints) - num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) - num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - - self.assertEqual(num_assert, 6) - self.assertEqual(num_scalar_tensor, 6) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) @@ -156,13 +142,6 @@ class TestPasses(TestCase): ep = export(M(), (x, y), constraints=constraints) - num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) - num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - - # there are 3 asserts from y and 2 from dynamic x dims and 1 from static x dim - self.assertEqual(num_assert, 6) - self.assertEqual(num_scalar_tensor, 6) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) @@ -194,13 +173,6 @@ class TestPasses(TestCase): ep = export(M(), (x, y), constraints=constraints) - num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) - num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - - # there are 4 asserts from y and 3 from x - 
self.assertEqual(num_assert, 7) - self.assertEqual(num_scalar_tensor, 7) - with self.assertRaisesRegex(RuntimeError, "Input arg0_1"): ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) @@ -278,12 +250,6 @@ class TestPasses(TestCase): mod = M() ep = export(mod, (x,)) - num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) - num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - # 1 constraint for shape of x, 2 constraints for b - self.assertEqual(num_assert, 3) - self.assertEqual(num_scalar_tensor, 3) - with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense_default is outside of inline constraint \[2, 5\]."): ep(torch.tensor([6])) @@ -344,6 +310,8 @@ class TestPasses(TestCase): y = torch.tensor([5]) mod = M() ep = export(mod, (torch.tensor(True), x, y)) + + with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."): ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) @@ -409,239 +377,6 @@ class TestPasses(TestCase): "torch.ops.aten.sym_constrain_range.default", 0, exactly=True ).run(gm.code) - dep_token_node = next(n for n in gm.graph.nodes if n.name == "dep_token3") - constrain_node = next( - n - for n in gm.graph.nodes - if n.target == torch.ops.aten._functional_sym_constrain_range - ) - self.assertEqual(constrain_node.kwargs["dep_token"], dep_token_node) - - def test_functionalize_input_constraints(self) -> None: - def f(x): - return x * 2 - - inp = torch.zeros(4, 8) - ep = torch._export.export( - f, - (inp,), - constraints=[ - dynamic_dim(inp, 0) < 10, - dynamic_dim(inp, 0) >= 3, - ], - ) - FileCheck().check_count( - "torch.ops.aten._assert_async.msg", 3, exactly=True - ).run(ep.graph_module.code) - - gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module - with self.assertRaisesRegex( - RuntimeError, - r"Input arg0_1.shape\[0\] is outside of specified dynamic range \[3, 9\]", - ): - gm(torch.ones(11, 8)) - - inp = torch.ones(6, 8) - self.assertEqual(gm(inp)[0], f(inp)) - FileCheck().check_count( - "torch.ops.aten._functional_assert_async.msg", 3, exactly=True - ).run(gm.code) - FileCheck().check_count( - "torch.ops.aten._assert_async.msg", 0, exactly=True - ).run(gm.code) - - def test_functionalization(self) -> None: - def f(x, y): - a = x.item() - constrain_as_size(a, 4, 7) - return x + 4, x + y * 2 - - inps = (torch.tensor([5]), torch.zeros((3, 4))) - ep = torch._export.export( - f, - inps, - constraints=[dynamic_dim(inps[1], 1) < 6], - _functionalize_runtime_assertions=True, - ) - FileCheck().check_count( - "torch.ops.aten._functional_sym_constrain_range", 1, exactly=True - ).run(ep.graph_module.code) - inps = (torch.tensor([7]), torch.ones((3, 5))) - self.assertTrue(torch._dynamo.utils.same(ep(*inps), f(*inps))) - - def test_functionalization_with_native_python_assertion(self) -> None: - def f(x): - b = x.sin() - assert x[0] == 3 - return x.cos() + b - - inp = torch.Tensor([3, 4, 5]) - ep = torch._export.export(f, (inp,), _functionalize_runtime_assertions=True) - - # Check native assertion has corresponding functional assertion nodes generated. 
- select_int_node = next( - n - for n in ep.graph_module.graph.nodes - if n.target == torch.ops.aten.select.int - ) - equal_scalar_node = select_int_node.next - dep_token_node = next( - n - for n in ep.graph_module.graph.nodes - if ( - n.target == torch.ops.aten._functional_assert_async.msg - and n.args[0] == equal_scalar_node - ) - ) - self.assertIn( - "call_function[target=torch.ops.aten._functional_assert_async.msg]" - "(args = (%eq_scalar, assertion error), kwargs = {dep_token: %dep_token1}", - dep_token_node.format_node(), - ) - - def test_functionalization_with_mutated_buffer(self) -> None: - buf = torch.ones(6, 2) - weight = 0.01 - bias = 0.2 - d_in = 3 - d_out = 4 - - class Foo(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.register_buffer("buf", buf) - - self.linear = torch.nn.Linear(d_in, d_out) - self.linear.weight.data.fill_(weight) - self.linear.bias.data.fill_(bias) - - def forward(self, x): - self.buf.add_(5) - return self.linear(x).cos() + self.buf.sum() - - inp = torch.ones(4, 3) - ep = torch._export.export( - Foo(), - (inp,), - constraints=[dynamic_dim(inp, 0) >= 3], - _functionalize_runtime_assertions=True, - ) - - gs = ep.graph_signature - self.assertEqual( - gs, - ExportGraphSignature( - parameters=["L__self___linear.weight", "L__self___linear.bias"], - buffers=["L__self___buf"], - user_inputs=["arg3_1"], - user_outputs=["add_tensor_1"], - inputs_to_parameters={ - "arg0_1": "L__self___linear.weight", - "arg1_1": "L__self___linear.bias", - }, - inputs_to_buffers={"arg2_1": "L__self___buf"}, - buffers_to_mutate={"add_tensor": "L__self___buf"}, - backward_signature=None, - assertion_dep_token={2: "dep_token7"}, - ), - ) - outputs = next(n for n in ep.graph.nodes if n.op == "output").args[0] - self.assertEqual( - [str(o) for o in outputs], - ["add_tensor", "add_tensor_1", "dep_token7"], - ) - self.assertEqual( - len(outputs), len(gs.buffers_to_mutate) + len(gs.user_outputs) + 1, - ) - inp = torch.randn(5, 3) - self.assertTrue( - torch._dynamo.utils.same( - # Directly check run output of `ep.graph_module` which is - # functionalized. 
- ep.graph_module( - torch.full((d_out, d_in), weight), - torch.full((d_out,), bias), - buf.clone(), - inp, - ), - (buf.add(5), Foo()(inp), torch.empty(0)), - ) - ) - self.assertTrue(torch._dynamo.utils.same(ep(inp), Foo()(inp))) - - def test_graph_partition_after_assertion_functionalization(self) -> None: - def f1(a, b): - add = a + b - add_1 = add + b - add_2 = add_1 + add - - relu_1 = add_2.relu() # blocked by this - - add_3 = add_2 + relu_1 - add_4 = add_2 + add_3 - return add_4, add_2 - - partitioner1 = CapabilityBasedPartitioner( - graph_module=symbolic_trace(f1), - operator_support=_AddOperatorSupport(), - ) - partitions1 = partitioner1.propose_partitions() - - self.assertEqual( - _to_partition_names(partitions1), - [{"add_3", "add_4"}, {"add", "add_1", "add_2"}], - ) - - def f2(a, b): - add = a + b - add_1 = add + b - add_2 = add_1 + add - - assert add_1[0] == 5 - - relu_1 = add_2.relu() # blocked by this - - add_3 = add_2 + relu_1 - add_4 = add_2 + add_3 - return add_4, add_2 - - inps = (torch.tensor([1, 3, 2]), torch.tensor([2, 3, 4])) - gm = export( - f2, - inps, - constraints=[dynamic_dim(inps[0], 0) == dynamic_dim(inps[1], 0)], - _functionalize_runtime_assertions=True, - ).graph_module - partitioner2 = CapabilityBasedPartitioner( - graph_module=gm, - operator_support=_AtenAddOperatorSupport(), - ) - partitions2 = partitioner2.propose_partitions() - - self.assertEqual( - _to_partition_names(partitions2), - [ - {"add_tensor_3", "add_tensor_4"}, - {"add_tensor_1", "add_tensor_2", "add_tensor"}, - ] - ) - - fused_gm1 = partitioner1.fuse_partitions(partitions1) - fused_gm2 = partitioner2.fuse_partitions(partitions2) - - inps = (torch.tensor([1, 4, 6]), torch.tensor([2, 4, 6])) - self.assertTrue( - torch._dynamo.utils.same(fused_gm1(*inps)[0], fused_gm2(*inps)[0]), - ) - - # Sub-module `fused_1` is for logic `add = ..., ..., add_2 = ...` - output_names1 = _get_output_names(fused_gm1.get_submodule("fused_1")) - output_names2 = _get_output_names(fused_gm2.get_submodule("fused_1")) - - self.assertEqual(output_names1, ["add_2"]) - # The extra output `add_tensor_1` is consumed by assertion. 
- self.assertEqual(output_names2, ["add_tensor_1", "add_tensor_2"]) - if __name__ == '__main__': run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f065095b463..7e8f2b08fa8 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -199,39 +199,34 @@ class TestDeserialize(TestCase): self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes)) for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes): - # Check "val" metadata - val1 = node1.meta.get("val", None) - val2 = node2.meta.get("val", None) + self.assertEqual(node1.op, node2.op) + if node1.op == "call_function": + # Check "val" metadata + val1 = node1.meta.get("val", None) + val2 = node2.meta.get("val", None) + if val1 is None or val2 is None: + # Either both are None + self.assertEqual(val1, val2) + elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor): + # Or both are fake tensors with the same shape/dtype + self.assertEqual(len(val1.shape), len(val2.shape)) + for s1, s2 in zip(val1.shape, val2.shape): + if is_concrete_int(s1) and is_concrete_int(s2): + self.assertEqual(s1, s2) + else: + self.assertEqual(str(s1), str(s2)) + self.assertEqual(val1.dtype, val2.dtype) + elif isinstance(val1, list) and isinstance(val2, list): + # Or both are fake tensors lists with one element and with the + # same shape/dtype + self.assertTrue(len(val1) == 1 and len(val2) == 1) + self.assertEqual(val1[0].shape, val2[0].shape) + self.assertEqual(val1[0].dtype, val2[0].dtype) + else: + # For expressions like 's0 < 10' can only compare through string + self.assertEqual(str(val1), str(val2)) - if val1 is None or val2 is None: - # Either both are None - self.assertEqual(val1, val2) - elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor): - # Or both are fake tensors with the same shape/dtype - self.assertEqual(len(val1.shape), len(val2.shape)) - for s1, s2 in zip(val1.shape, val2.shape): - if is_concrete_int(s1) and is_concrete_int(s2): - self.assertEqual(s1, s2) - else: - self.assertEqual(str(s1), str(s2)) - self.assertEqual(val1.dtype, val2.dtype) - elif isinstance(val1, list) and isinstance(val2, list): - # Or both are fake tensors lists with one element and with the - # same shape/dtype - self.assertTrue(len(val1) == 1 and len(val2) == 1) - self.assertEqual(val1[0].shape, val2[0].shape) - self.assertEqual(val1[0].dtype, val2[0].dtype) - else: - # For expressions like 's0 < 10' can only compare through string - self.assertEqual(str(val1), str(val2)) - - # Check "stack_trace" metadata - if "None" in node1.meta.get("stack_trace"): - self.assertTrue( - node2.meta.get("stack_trace") is None - or "None" in node2.meta.get("stack_trace") - ) - else: + # Check "stack_trace" metadata self.assertEqual( node1.meta.get("stack_trace", None), node2.meta.get("stack_trace", None), diff --git a/test/export/test_upgrade.py b/test/export/test_upgrade.py index 8eec5e37646..69913c12247 100644 --- a/test/export/test_upgrade.py +++ b/test/export/test_upgrade.py @@ -117,7 +117,7 @@ def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Opti return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc') inputs = (torch.ones([2, 3]) * 4, 2.) 
- ep = export(fn, inputs, {}, [], _add_runtime_assertions=False) + ep = export(fn, inputs, {}, []) compiler_opset_version = {"aten": 4} model_opset_version = {"aten": 3} upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS) diff --git a/test/test_out_dtype_op.py b/test/test_out_dtype_op.py index b897fe3bd70..cbc25826136 100644 --- a/test/test_out_dtype_op.py +++ b/test/test_out_dtype_op.py @@ -59,7 +59,6 @@ class TestOutDtypeOp(TestCase): ep = torch._export.export( m, (x,), - _add_runtime_assertions=False, ) FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code) self.assertTrue(torch.allclose(m(x), ep(x))) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 771b18e64cf..9c1e4fea887 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -40,7 +40,7 @@ from .exported_program import ( ExportGraphSignature, ) from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass - +from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAssertionsForInlineConstraintsPass # Note - [On Export Dynamic Dimension UX] # @@ -127,9 +127,6 @@ def export( args: Tuple[Any], kwargs: Optional[Dict[str, Any]] = None, constraints: Optional[List[Constraint]] = None, - *, - _add_runtime_assertions=True, - _functionalize_runtime_assertions=False, ) -> ExportedProgram: """ Traces either an nn.Module's forward function or just a callable with PyTorch @@ -324,11 +321,9 @@ def export( equality_constraints, ) - if _add_runtime_assertions: - exported_program = exported_program._add_runtime_assertions( - functionalize=_functionalize_runtime_assertions, - ) - + exported_program = exported_program.transform( + _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints, equality_constraints) + ) return exported_program.transform(_ReplaceSymSizeOpPass()) except (ConstraintViolationError, ValueRangeError) as e: diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py index b874a43384a..a9078a5e45e 100644 --- a/torch/_export/exported_program.py +++ b/torch/_export/exported_program.py @@ -1,16 +1,19 @@ -from collections import defaultdict import copy import dataclasses -import sympy +from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, Union -from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName + +import sympy import torch -from torch.fx.passes.infra.pass_manager import PassManager import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree -from torch.fx.experimental.symbolic_shapes import SymInt +from torch import fx +from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental.symbolic_shapes import SymInt +from torch.fx.passes.infra.pass_manager import PassManager + from . 
import error from .pass_base import PassType from .passes.add_runtime_assertions_for_constraints_pass import ( @@ -122,7 +125,8 @@ class ExportedProgram: f"{received_spec}" ) - param_buffer_values = (value for _, value in self.state_dict.items()) + param_buffer_values = tuple(value for _, value in self.state_dict.items()) + self._check_input_constraints(*param_buffer_values, *args) with torch.no_grad(): res = torch.fx.Interpreter(self.graph_module).run( @@ -208,6 +212,24 @@ class ExportedProgram: transformed_ep.graph_module.meta.update(res.graph_module.meta) return transformed_ep + def _check_input_constraints(self, *args): + # TODO(zhxchen17) Remove _add_runtime_assertions. + # TODO(zhxchen17) Don't generate a runtime graph on the fly. + _assertion_graph = fx.GraphModule({}, fx.Graph()) + for p in self.graph.nodes: + if p.op != "placeholder": + continue + new_p = _assertion_graph.graph.placeholder(p.name) + new_p.meta = p.meta + _assertion_graph.graph.output(()) + _assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass( + self.range_constraints, + self.equality_constraints, + )(_assertion_graph) + assert _assertion_graph_res is not None + _assertion_graph = _assertion_graph_res.graph_module + _assertion_graph(*args) + def _add_runtime_assertions( self, functionalize: bool, diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index cba6d6199b1..7882060d26f 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -50,7 +50,7 @@ def _convert_range_to_int(range: RangeConstraint): return min_val, max_val -class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase): +class _AddRuntimeAssertionsForInlineConstraintsPass(ExportPassBase): def __init__( self, range_constraints: Dict[sympy.Symbol, RangeConstraint], @@ -60,6 +60,101 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase): self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints + def _assert_range_constraint(self, proxy, lower, upper, assert_msg): + if lower > -math.inf: + self._insert_assert_async(operator.ge, proxy, lower, assert_msg) + + if upper < math.inf: + self._insert_assert_async(operator.le, proxy, upper, assert_msg) + + def _insert_assert_async(self, operator, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. + """ + cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata()) + cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata()) + super().call_operator( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + self._create_dummy_node_metadata(), + ) + + def call_operator(self, op, args, kwargs, meta) -> ProxyValue: + ret = super().call_operator(op, args, kwargs, meta) + if "val" not in meta: + return ret + + val = meta["val"] + + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. 
+ # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requries the proxy for ret[1], etc. + def add_assertions(val): + call_backs = [] + messages = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node._expr + if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"): + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'i' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." + call_backs.append( + partial(self._assert_range_constraint, lower=min_val, upper=max_val) + ) + messages.append(assert_msg) + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + def sym_size_cb(proxy, assert_msg, dim): + dim_proxy = super( + _AddRuntimeAssertionsForInlineConstraintsPass, + self + ).call_operator( + torch.ops.aten.sym_size.int, + (proxy, dim), + {}, + self._create_dummy_node_metadata(), + ) + cb(proxy=dim_proxy, assert_msg=assert_msg) + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(proxy=ret, assert_msg=f"{ret.node}" + msg) + return ret + + def call(self, graph_module): + # Add runtime asserts for inline constraints + val = super().call(graph_module) + + # Populate the stack trace with dummy vals to respect IR + for node in val.graph_module.graph.nodes: + if not hasattr(node.meta, "stack_trace"): + node.meta["stack_trace"] = traceback.format_exc(-1) + + return PassResult(val.graph_module, val.modified) + + +class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass): + def __init__( + self, + range_constraints: Dict[sympy.Symbol, RangeConstraint], + equality_constraints: List[Tuple[InputDim, InputDim]], + ): + super().__init__(range_constraints, equality_constraints) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph_module = copy.deepcopy(graph_module) graph = graph_module.graph @@ -98,11 +193,10 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase): if isinstance(shape, SymInt): # If the shape is dynamic, add range assertions symbol = shape.node._expr - assert symbol in self.range_constraints - - self._insert_range_assert_inplace( - graph, input_dim, dim_node, self.range_constraints[symbol] - ) + if symbol in self.range_constraints: + self._insert_range_assert_inplace( + graph, input_dim, dim_node, self.range_constraints[symbol] + ) else: # If no dynamism is specified, we assume all dimensions # # are specialized @@ -112,20 +206,13 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase): ) # Add runtime assertions on equality constraints on the inputs - with graph.inserting_after( - list(inputdim_to_node.values())[-1] - ): - self._insert_equality_assert_inplace(graph, inputdim_to_node) + if len(inputdim_to_node) > 0: + with graph.inserting_after( + list(inputdim_to_node.values())[-1] + ): + self._insert_equality_assert_inplace(graph, inputdim_to_node) - # Add runtime asserts for inline 
constraints - val = super().call(graph_module) - - # Populate the stack trace with dummy vals to respect IR - for node in val.graph_module.graph.nodes: - if not hasattr(node.meta, "stack_trace"): - node.meta["stack_trace"] = traceback.format_exc(-1) - - return PassResult(val.graph_module, val.modified) + return super().call(graph_module) def _insert_specialized_shape_assert_inplace( self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int, @@ -202,75 +289,3 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase): _ = graph.call_function( torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg) ) - - def call_operator(self, op, args, kwargs, meta) -> ProxyValue: - ret = super().call_operator(op, args, kwargs, meta) - if "val" not in meta: - return ret - - val = meta["val"] - - # In general, we may have to deal the case such as: ret[1].shape[0]. - # We need first find out what symbols require assertion, then we need to follow the path - # from ret to the symbol, construct the proxies along the way and construct the messages - # piece-wise at the same time. - # - # We use post-order traversal to collect all the proxies callbacks needed, construct - # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. - # We need the callbacks because, in order to call the function to create a proxy for shape[0], we - # need the proxy for shape, which further requries the proxy for ret[1], etc. - def add_assertions(val): - call_backs = [] - messages = [] - if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): - symbol = val.node._expr - if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"): - # We only care about unbacked symints for these inline - # constraints, which are prefixed with 'i' - constraint = self.range_constraints[symbol] - min_val, max_val = _convert_range_to_int(constraint) - assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." - call_backs.append( - partial(self._assert_range_constraint, lower=min_val, upper=max_val) - ) - messages.append(assert_msg) - elif isinstance(val, torch.Tensor): - for i, sym in enumerate(val.shape): - cbs, msgs = add_assertions(sym) - for cb, msg in zip(cbs, msgs): - def sym_size_cb(proxy, assert_msg, dim): - dim_proxy = super(_AddRuntimeAssertionsForConstraintsPass, self).call_operator( - torch.ops.aten.sym_size.int, - (proxy, dim), - {}, - self._create_dummy_node_metadata(), - ) - cb(proxy=dim_proxy, assert_msg=assert_msg) - call_backs.append(partial(sym_size_cb, dim=i)) - messages.append(f".shape[{i}]" + msg) - return call_backs, messages - callbacks, messages = add_assertions(val) - for cb, msg in zip(callbacks, messages): - cb(proxy=ret, assert_msg=f"{ret.node}" + msg) - return ret - - def _assert_range_constraint(self, proxy, lower, upper, assert_msg): - if lower > -math.inf: - self._insert_assert_async(operator.ge, proxy, lower, assert_msg) - - if upper < math.inf: - self._insert_assert_async(operator.le, proxy, upper, assert_msg) - - def _insert_assert_async(self, operator, lower, upper, assert_msg): - """ - Inserts assert_async call_function nodes in the graph. This function is - called **during** the interpreter-based pass. 
- """ - cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata()) - cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata()) - super().call_operator( - torch.ops.aten._assert_async.msg, - (cmp_tensor, assert_msg), - {}, - self._create_dummy_node_metadata(), - ) diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py index c6e2d413414..d9afaf0a078 100644 --- a/torch/_export/serde/upgrade.py +++ b/torch/_export/serde/upgrade.py @@ -196,7 +196,7 @@ class GraphModuleOpUpgrader: upgraded_program = exported_program.transform(_pass) # NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn of # _add_runtime_assertions because dynamo is not happy with sym_size.int. - exported_program = export(upgraded_program.graph_module, inputs, {}, _add_runtime_assertions=False) + exported_program = export(upgraded_program.graph_module, inputs, {}) exported_program.call_spec = upgraded_program.call_spec return exported_program