[export] Remove experimental runtime assertion configs from export API. (#105043)

Test Plan: CI

Differential Revision: D47390794

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105043
Approved by: https://github.com/larryliu0820
Authored by Zhengxu Chen on 2023-07-26 16:21:29 +00:00, committed by PyTorch MergeBot
parent 8af25cfc24
commit 2dbadd1eae
8 changed files with 170 additions and 409 deletions
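For context, this drops the two experimental keyword-only flags from torch._export.export. A minimal sketch of the call-site change, modeled on the tests in this diff (the module M and the input shapes below are illustrative, not taken from the patch):

import torch
from torch._export import export, dynamic_dim

class M(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()

x = torch.zeros(2, 4, 3)

# Before this change, experimental flags could disable or functionalize the
# runtime assertions at export time:
#   ep = export(M(), (x,), _add_runtime_assertions=False)
#   ep = export(M(), (x,), _functionalize_runtime_assertions=True)

# After this change the flags are gone: inline-constraint assertions are always
# added by _AddRuntimeAssertionsForInlineConstraintsPass, and input range and
# equality constraints are checked whenever the exported program is called.
ep = export(M(), (x,), constraints=[dynamic_dim(x, 1) >= 2, dynamic_dim(x, 1) <= 6])
ep(x)  # dim 1 == 4 satisfies the dynamic range [2, 6], so this runs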


@@ -14,7 +14,6 @@ from torch.testing import FileCheck
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export, dynamic_dim
from torch._export.constraints import constrain_as_value, constrain_as_size
from torch._export.exported_program import ExportGraphSignature
from torch._export.passes import (
ReplaceViewOpsWithViewCopyOpsPass,
)
@@ -27,8 +26,7 @@ from torch._export.passes.functionalize_side_effectful_ops_pass import (
)
from functorch.experimental.control_flow import cond
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.infra.partitioner import Partition
from torch.utils._pytree import tree_flatten
@@ -94,12 +92,6 @@ class TestPasses(TestCase):
ep = export(M(), (x,), constraints=[dynamic_dim(x, 1) >= 2, dynamic_dim(x, 1) <= 6])
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(2, 7, 3))
@@ -125,12 +117,6 @@ class TestPasses(TestCase):
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
@@ -156,13 +142,6 @@ class TestPasses(TestCase):
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# there are 3 asserts from y and 2 from dynamic x dims and 1 from static x dim
self.assertEqual(num_assert, 6)
self.assertEqual(num_scalar_tensor, 6)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
@@ -194,13 +173,6 @@ class TestPasses(TestCase):
ep = export(M(), (x, y), constraints=constraints)
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# there are 4 asserts from y and 3 from x
self.assertEqual(num_assert, 7)
self.assertEqual(num_scalar_tensor, 7)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
@@ -278,12 +250,6 @@ class TestPasses(TestCase):
mod = M()
ep = export(mod, (x,))
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
# 1 constraint for shape of x, 2 constraints for b
self.assertEqual(num_assert, 3)
self.assertEqual(num_scalar_tensor, 3)
with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense_default is outside of inline constraint \[2, 5\]."):
ep(torch.tensor([6]))
@@ -344,6 +310,8 @@ class TestPasses(TestCase):
y = torch.tensor([5])
mod = M()
ep = export(mod, (torch.tensor(True), x, y))
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))
@@ -409,239 +377,6 @@ class TestPasses(TestCase):
"torch.ops.aten.sym_constrain_range.default", 0, exactly=True
).run(gm.code)
dep_token_node = next(n for n in gm.graph.nodes if n.name == "dep_token3")
constrain_node = next(
n
for n in gm.graph.nodes
if n.target == torch.ops.aten._functional_sym_constrain_range
)
self.assertEqual(constrain_node.kwargs["dep_token"], dep_token_node)
def test_functionalize_input_constraints(self) -> None:
def f(x):
return x * 2
inp = torch.zeros(4, 8)
ep = torch._export.export(
f,
(inp,),
constraints=[
dynamic_dim(inp, 0) < 10,
dynamic_dim(inp, 0) >= 3,
],
)
FileCheck().check_count(
"torch.ops.aten._assert_async.msg", 3, exactly=True
).run(ep.graph_module.code)
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
with self.assertRaisesRegex(
RuntimeError,
r"Input arg0_1.shape\[0\] is outside of specified dynamic range \[3, 9\]",
):
gm(torch.ones(11, 8))
inp = torch.ones(6, 8)
self.assertEqual(gm(inp)[0], f(inp))
FileCheck().check_count(
"torch.ops.aten._functional_assert_async.msg", 3, exactly=True
).run(gm.code)
FileCheck().check_count(
"torch.ops.aten._assert_async.msg", 0, exactly=True
).run(gm.code)
def test_functionalization(self) -> None:
def f(x, y):
a = x.item()
constrain_as_size(a, 4, 7)
return x + 4, x + y * 2
inps = (torch.tensor([5]), torch.zeros((3, 4)))
ep = torch._export.export(
f,
inps,
constraints=[dynamic_dim(inps[1], 1) < 6],
_functionalize_runtime_assertions=True,
)
FileCheck().check_count(
"torch.ops.aten._functional_sym_constrain_range", 1, exactly=True
).run(ep.graph_module.code)
inps = (torch.tensor([7]), torch.ones((3, 5)))
self.assertTrue(torch._dynamo.utils.same(ep(*inps), f(*inps)))
def test_functionalization_with_native_python_assertion(self) -> None:
def f(x):
b = x.sin()
assert x[0] == 3
return x.cos() + b
inp = torch.Tensor([3, 4, 5])
ep = torch._export.export(f, (inp,), _functionalize_runtime_assertions=True)
# Check native assertion has corresponding functional assertion nodes generated.
select_int_node = next(
n
for n in ep.graph_module.graph.nodes
if n.target == torch.ops.aten.select.int
)
equal_scalar_node = select_int_node.next
dep_token_node = next(
n
for n in ep.graph_module.graph.nodes
if (
n.target == torch.ops.aten._functional_assert_async.msg
and n.args[0] == equal_scalar_node
)
)
self.assertIn(
"call_function[target=torch.ops.aten._functional_assert_async.msg]"
"(args = (%eq_scalar, assertion error), kwargs = {dep_token: %dep_token1}",
dep_token_node.format_node(),
)
def test_functionalization_with_mutated_buffer(self) -> None:
buf = torch.ones(6, 2)
weight = 0.01
bias = 0.2
d_in = 3
d_out = 4
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buf", buf)
self.linear = torch.nn.Linear(d_in, d_out)
self.linear.weight.data.fill_(weight)
self.linear.bias.data.fill_(bias)
def forward(self, x):
self.buf.add_(5)
return self.linear(x).cos() + self.buf.sum()
inp = torch.ones(4, 3)
ep = torch._export.export(
Foo(),
(inp,),
constraints=[dynamic_dim(inp, 0) >= 3],
_functionalize_runtime_assertions=True,
)
gs = ep.graph_signature
self.assertEqual(
gs,
ExportGraphSignature(
parameters=["L__self___linear.weight", "L__self___linear.bias"],
buffers=["L__self___buf"],
user_inputs=["arg3_1"],
user_outputs=["add_tensor_1"],
inputs_to_parameters={
"arg0_1": "L__self___linear.weight",
"arg1_1": "L__self___linear.bias",
},
inputs_to_buffers={"arg2_1": "L__self___buf"},
buffers_to_mutate={"add_tensor": "L__self___buf"},
backward_signature=None,
assertion_dep_token={2: "dep_token7"},
),
)
outputs = next(n for n in ep.graph.nodes if n.op == "output").args[0]
self.assertEqual(
[str(o) for o in outputs],
["add_tensor", "add_tensor_1", "dep_token7"],
)
self.assertEqual(
len(outputs), len(gs.buffers_to_mutate) + len(gs.user_outputs) + 1,
)
inp = torch.randn(5, 3)
self.assertTrue(
torch._dynamo.utils.same(
# Directly check run output of `ep.graph_module` which is
# functionalized.
ep.graph_module(
torch.full((d_out, d_in), weight),
torch.full((d_out,), bias),
buf.clone(),
inp,
),
(buf.add(5), Foo()(inp), torch.empty(0)),
)
)
self.assertTrue(torch._dynamo.utils.same(ep(inp), Foo()(inp)))
def test_graph_partition_after_assertion_functionalization(self) -> None:
def f1(a, b):
add = a + b
add_1 = add + b
add_2 = add_1 + add
relu_1 = add_2.relu() # blocked by this
add_3 = add_2 + relu_1
add_4 = add_2 + add_3
return add_4, add_2
partitioner1 = CapabilityBasedPartitioner(
graph_module=symbolic_trace(f1),
operator_support=_AddOperatorSupport(),
)
partitions1 = partitioner1.propose_partitions()
self.assertEqual(
_to_partition_names(partitions1),
[{"add_3", "add_4"}, {"add", "add_1", "add_2"}],
)
def f2(a, b):
add = a + b
add_1 = add + b
add_2 = add_1 + add
assert add_1[0] == 5
relu_1 = add_2.relu() # blocked by this
add_3 = add_2 + relu_1
add_4 = add_2 + add_3
return add_4, add_2
inps = (torch.tensor([1, 3, 2]), torch.tensor([2, 3, 4]))
gm = export(
f2,
inps,
constraints=[dynamic_dim(inps[0], 0) == dynamic_dim(inps[1], 0)],
_functionalize_runtime_assertions=True,
).graph_module
partitioner2 = CapabilityBasedPartitioner(
graph_module=gm,
operator_support=_AtenAddOperatorSupport(),
)
partitions2 = partitioner2.propose_partitions()
self.assertEqual(
_to_partition_names(partitions2),
[
{"add_tensor_3", "add_tensor_4"},
{"add_tensor_1", "add_tensor_2", "add_tensor"},
]
)
fused_gm1 = partitioner1.fuse_partitions(partitions1)
fused_gm2 = partitioner2.fuse_partitions(partitions2)
inps = (torch.tensor([1, 4, 6]), torch.tensor([2, 4, 6]))
self.assertTrue(
torch._dynamo.utils.same(fused_gm1(*inps)[0], fused_gm2(*inps)[0]),
)
# Sub-module `fused_1` is for logic `add = ..., ..., add_2 = ...`
output_names1 = _get_output_names(fused_gm1.get_submodule("fused_1"))
output_names2 = _get_output_names(fused_gm2.get_submodule("fused_1"))
self.assertEqual(output_names1, ["add_2"])
# The extra output `add_tensor_1` is consumed by assertion.
self.assertEqual(output_names2, ["add_tensor_1", "add_tensor_2"])
if __name__ == '__main__':
run_tests()


@@ -199,39 +199,34 @@ class TestDeserialize(TestCase):
self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes))
for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
# Check "val" metadata
val1 = node1.meta.get("val", None)
val2 = node2.meta.get("val", None)
self.assertEqual(node1.op, node2.op)
if node1.op == "call_function":
# Check "val" metadata
val1 = node1.meta.get("val", None)
val2 = node2.meta.get("val", None)
if val1 is None or val2 is None:
# Either both are None
self.assertEqual(val1, val2)
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
# Or both are fake tensors with the same shape/dtype
self.assertEqual(len(val1.shape), len(val2.shape))
for s1, s2 in zip(val1.shape, val2.shape):
if is_concrete_int(s1) and is_concrete_int(s2):
self.assertEqual(s1, s2)
else:
self.assertEqual(str(s1), str(s2))
self.assertEqual(val1.dtype, val2.dtype)
elif isinstance(val1, list) and isinstance(val2, list):
# Or both are fake tensor lists with one element and with the
# same shape/dtype
self.assertTrue(len(val1) == 1 and len(val2) == 1)
self.assertEqual(val1[0].shape, val2[0].shape)
self.assertEqual(val1[0].dtype, val2[0].dtype)
else:
# Expressions like 's0 < 10' can only be compared through their string form
self.assertEqual(str(val1), str(val2))
if val1 is None or val2 is None:
# Either both are None
self.assertEqual(val1, val2)
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
# Or both are fake tensors with the same shape/dtype
self.assertEqual(len(val1.shape), len(val2.shape))
for s1, s2 in zip(val1.shape, val2.shape):
if is_concrete_int(s1) and is_concrete_int(s2):
self.assertEqual(s1, s2)
else:
self.assertEqual(str(s1), str(s2))
self.assertEqual(val1.dtype, val2.dtype)
elif isinstance(val1, list) and isinstance(val2, list):
# Or both are fake tensor lists with one element and with the
# same shape/dtype
self.assertTrue(len(val1) == 1 and len(val2) == 1)
self.assertEqual(val1[0].shape, val2[0].shape)
self.assertEqual(val1[0].dtype, val2[0].dtype)
else:
# Expressions like 's0 < 10' can only be compared through their string form
self.assertEqual(str(val1), str(val2))
# Check "stack_trace" metadata
if "None" in node1.meta.get("stack_trace"):
self.assertTrue(
node2.meta.get("stack_trace") is None
or "None" in node2.meta.get("stack_trace")
)
else:
# Check "stack_trace" metadata
self.assertEqual(
node1.meta.get("stack_trace", None),
node2.meta.get("stack_trace", None),


@@ -117,7 +117,7 @@ def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Opti
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
inputs = (torch.ones([2, 3]) * 4, 2.)
ep = export(fn, inputs, {}, [], _add_runtime_assertions=False)
ep = export(fn, inputs, {}, [])
compiler_opset_version = {"aten": 4}
model_opset_version = {"aten": 3}
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)


@@ -59,7 +59,6 @@ class TestOutDtypeOp(TestCase):
ep = torch._export.export(
m,
(x,),
_add_runtime_assertions=False,
)
FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
self.assertTrue(torch.allclose(m(x), ep(x)))


@@ -40,7 +40,7 @@ from .exported_program import (
ExportGraphSignature,
)
from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAssertionsForInlineConstraintsPass
# Note - [On Export Dynamic Dimension UX]
#
@@ -127,9 +127,6 @@ def export(
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
constraints: Optional[List[Constraint]] = None,
*,
_add_runtime_assertions=True,
_functionalize_runtime_assertions=False,
) -> ExportedProgram:
"""
Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -324,11 +321,9 @@ def export(
equality_constraints,
)
if _add_runtime_assertions:
exported_program = exported_program._add_runtime_assertions(
functionalize=_functionalize_runtime_assertions,
)
exported_program = exported_program.transform(
_AddRuntimeAssertionsForInlineConstraintsPass(range_constraints, equality_constraints)
)
return exported_program.transform(_ReplaceSymSizeOpPass())
except (ConstraintViolationError, ValueRangeError) as e:


@@ -1,16 +1,19 @@
from collections import defaultdict
import copy
import dataclasses
import sympy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName
import sympy
import torch
from torch.fx.passes.infra.pass_manager import PassManager
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx.experimental.symbolic_shapes import SymInt
from torch import fx
from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import SymInt
from torch.fx.passes.infra.pass_manager import PassManager
from . import error
from .pass_base import PassType
from .passes.add_runtime_assertions_for_constraints_pass import (
@@ -122,7 +125,8 @@ class ExportedProgram:
f"{received_spec}"
)
param_buffer_values = (value for _, value in self.state_dict.items())
param_buffer_values = tuple(value for _, value in self.state_dict.items())
self._check_input_constraints(*param_buffer_values, *args)
with torch.no_grad():
res = torch.fx.Interpreter(self.graph_module).run(
@@ -208,6 +212,24 @@ class ExportedProgram:
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep
def _check_input_constraints(self, *args):
# TODO(zhxchen17) Remove _add_runtime_assertions.
# TODO(zhxchen17) Don't generate a runtime graph on the fly.
_assertion_graph = fx.GraphModule({}, fx.Graph())
for p in self.graph.nodes:
if p.op != "placeholder":
continue
new_p = _assertion_graph.graph.placeholder(p.name)
new_p.meta = p.meta
_assertion_graph.graph.output(())
_assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
self.range_constraints,
self.equality_constraints,
)(_assertion_graph)
assert _assertion_graph_res is not None
_assertion_graph = _assertion_graph_res.graph_module
_assertion_graph(*args)
def _add_runtime_assertions(
self,
functionalize: bool,


@@ -50,7 +50,7 @@ def _convert_range_to_int(range: RangeConstraint):
return min_val, max_val
class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
class _AddRuntimeAssertionsForInlineConstraintsPass(ExportPassBase):
def __init__(
self,
range_constraints: Dict[sympy.Symbol, RangeConstraint],
@@ -60,6 +60,101 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
if lower > -math.inf:
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
if upper < math.inf:
self._insert_assert_async(operator.le, proxy, upper, assert_msg)
def _insert_assert_async(self, operator, lower, upper, assert_msg):
"""
Inserts assert_async call_function nodes in the graph. This function is
called **during** the interpreter-based pass.
"""
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
super().call_operator(
torch.ops.aten._assert_async.msg,
(cmp_tensor, assert_msg),
{},
self._create_dummy_node_metadata(),
)
def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
ret = super().call_operator(op, args, kwargs, meta)
if "val" not in meta:
return ret
val = meta["val"]
# In general, we may have to handle cases such as ret[1].shape[0]:
# we first need to find out which symbols require assertions, then follow the path
# from ret to each symbol, constructing the proxies along the way and building the messages
# piece-wise at the same time.
#
# We use a post-order traversal to collect all the proxy callbacks needed, construct
# the error-message callbacks, and execute all the callbacks at the top level of the traversal.
# We need the callbacks because, in order to call the function that creates a proxy for shape[0], we
# need the proxy for shape, which in turn requires the proxy for ret[1], etc.
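# Hypothetical example: if ret is a tensor of shape (i0, 4), where i0 is an
# unbacked symint with inline constraint [2, 5], we first build a
# sym_size.int proxy for dim 0 of ret and then assert
# "<ret node>.shape[0] is outside of inline constraint [2, 5]." on that proxy.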
def add_assertions(val):
call_backs = []
messages = []
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
symbol = val.node._expr
if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"):
# We only care about unbacked symints for these inline
# constraints, which are prefixed with 'i'
constraint = self.range_constraints[symbol]
min_val, max_val = _convert_range_to_int(constraint)
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
call_backs.append(
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
)
messages.append(assert_msg)
elif isinstance(val, torch.Tensor):
for i, sym in enumerate(val.shape):
cbs, msgs = add_assertions(sym)
for cb, msg in zip(cbs, msgs):
def sym_size_cb(proxy, assert_msg, dim):
dim_proxy = super(
_AddRuntimeAssertionsForInlineConstraintsPass,
self
).call_operator(
torch.ops.aten.sym_size.int,
(proxy, dim),
{},
self._create_dummy_node_metadata(),
)
cb(proxy=dim_proxy, assert_msg=assert_msg)
call_backs.append(partial(sym_size_cb, dim=i))
messages.append(f".shape[{i}]" + msg)
return call_backs, messages
callbacks, messages = add_assertions(val)
for cb, msg in zip(callbacks, messages):
cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
return ret
def call(self, graph_module):
# Add runtime asserts for inline constraints
val = super().call(graph_module)
# Populate the stack trace with dummy vals to respect IR
for node in val.graph_module.graph.nodes:
if not hasattr(node.meta, "stack_trace"):
node.meta["stack_trace"] = traceback.format_exc(-1)
return PassResult(val.graph_module, val.modified)
class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass):
def __init__(
self,
range_constraints: Dict[sympy.Symbol, RangeConstraint],
equality_constraints: List[Tuple[InputDim, InputDim]],
):
super().__init__(range_constraints, equality_constraints)
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module = copy.deepcopy(graph_module)
graph = graph_module.graph
@@ -98,11 +193,10 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
if isinstance(shape, SymInt):
# If the shape is dynamic, add range assertions
symbol = shape.node._expr
assert symbol in self.range_constraints
self._insert_range_assert_inplace(
graph, input_dim, dim_node, self.range_constraints[symbol]
)
if symbol in self.range_constraints:
self._insert_range_assert_inplace(
graph, input_dim, dim_node, self.range_constraints[symbol]
)
else:
# If no dynamism is specified, we assume all dimensions
# are specialized
@@ -112,20 +206,13 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
)
# Add runtime assertions on equality constraints on the inputs
with graph.inserting_after(
list(inputdim_to_node.values())[-1]
):
self._insert_equality_assert_inplace(graph, inputdim_to_node)
if len(inputdim_to_node) > 0:
with graph.inserting_after(
list(inputdim_to_node.values())[-1]
):
self._insert_equality_assert_inplace(graph, inputdim_to_node)
# Add runtime asserts for inline constraints
val = super().call(graph_module)
# Populate the stack trace with dummy vals to respect IR
for node in val.graph_module.graph.nodes:
if not hasattr(node.meta, "stack_trace"):
node.meta["stack_trace"] = traceback.format_exc(-1)
return PassResult(val.graph_module, val.modified)
return super().call(graph_module)
def _insert_specialized_shape_assert_inplace(
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int,
@@ -202,75 +289,3 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
_ = graph.call_function(
torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg)
)
def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
ret = super().call_operator(op, args, kwargs, meta)
if "val" not in meta:
return ret
val = meta["val"]
# In general, we may have to handle cases such as ret[1].shape[0]:
# we first need to find out which symbols require assertions, then follow the path
# from ret to each symbol, constructing the proxies along the way and building the messages
# piece-wise at the same time.
#
# We use a post-order traversal to collect all the proxy callbacks needed, construct
# the error-message callbacks, and execute all the callbacks at the top level of the traversal.
# We need the callbacks because, in order to call the function that creates a proxy for shape[0], we
# need the proxy for shape, which in turn requires the proxy for ret[1], etc.
def add_assertions(val):
call_backs = []
messages = []
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
symbol = val.node._expr
if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"):
# We only care about unbacked symints for these inline
# constraints, which are prefixed with 'i'
constraint = self.range_constraints[symbol]
min_val, max_val = _convert_range_to_int(constraint)
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
call_backs.append(
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
)
messages.append(assert_msg)
elif isinstance(val, torch.Tensor):
for i, sym in enumerate(val.shape):
cbs, msgs = add_assertions(sym)
for cb, msg in zip(cbs, msgs):
def sym_size_cb(proxy, assert_msg, dim):
dim_proxy = super(_AddRuntimeAssertionsForConstraintsPass, self).call_operator(
torch.ops.aten.sym_size.int,
(proxy, dim),
{},
self._create_dummy_node_metadata(),
)
cb(proxy=dim_proxy, assert_msg=assert_msg)
call_backs.append(partial(sym_size_cb, dim=i))
messages.append(f".shape[{i}]" + msg)
return call_backs, messages
callbacks, messages = add_assertions(val)
for cb, msg in zip(callbacks, messages):
cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
return ret
def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
if lower > -math.inf:
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
if upper < math.inf:
self._insert_assert_async(operator.le, proxy, upper, assert_msg)
def _insert_assert_async(self, operator, lower, upper, assert_msg):
"""
Inserts assert_async call_function nodes in the graph. This function is
called **during** the interpreter-based pass.
"""
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
super().call_operator(
torch.ops.aten._assert_async.msg,
(cmp_tensor, assert_msg),
{},
self._create_dummy_node_metadata(),
)


@@ -196,7 +196,7 @@ class GraphModuleOpUpgrader:
upgraded_program = exported_program.transform(_pass)
# NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn off
# _add_runtime_assertions because dynamo is not happy with sym_size.int.
exported_program = export(upgraded_program.graph_module, inputs, {}, _add_runtime_assertions=False)
exported_program = export(upgraded_program.graph_module, inputs, {})
exported_program.call_spec = upgraded_program.call_spec
return exported_program
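The inline-constraint assertions that the removed _functionalize_runtime_assertions flag used to optionally functionalize are likewise always emitted now. A small sketch based on the deleted test_functionalization test (the violating input value is illustrative):

import torch
from torch._export import export
from torch._export.constraints import constrain_as_size

def f(x, y):
    a = x.item()
    constrain_as_size(a, 4, 7)  # inline constraint on a data-dependent value
    return x + 4, x + y * 2

ep = export(f, (torch.tensor([5]), torch.zeros(3, 4)))

# A value outside [4, 7] trips the aten._assert_async.msg node inserted by
# _AddRuntimeAssertionsForInlineConstraintsPass and raises a RuntimeError whose
# message ends with "is outside of inline constraint [4, 7].".
try:
    ep(torch.tensor([8]), torch.zeros(3, 4))
except RuntimeError as e:
    print(e)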