mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
[export] Remove experimental runtime assertion configs from export API. (#105043)
Test Plan: CI Differential Revision: D47390794 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105043 Approved by: https://github.com/larryliu0820
This commit is contained in:
parent
8af25cfc24
commit
2dbadd1eae
8 changed files with 170 additions and 409 deletions
|
|
@ -14,7 +14,6 @@ from torch.testing import FileCheck
|
|||
from torch._dynamo.eval_frame import is_dynamo_supported
|
||||
from torch._export import export, dynamic_dim
|
||||
from torch._export.constraints import constrain_as_value, constrain_as_size
|
||||
from torch._export.exported_program import ExportGraphSignature
|
||||
from torch._export.passes import (
|
||||
ReplaceViewOpsWithViewCopyOpsPass,
|
||||
)
|
||||
|
|
@ -27,8 +26,7 @@ from torch._export.passes.functionalize_side_effectful_ops_pass import (
|
|||
)
|
||||
from functorch.experimental.control_flow import cond
|
||||
from torch.fx.passes.operator_support import OperatorSupport
|
||||
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
|
||||
from torch.fx._symbolic_trace import symbolic_trace
|
||||
from torch.fx.passes.infra.partitioner import Partition
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
|
||||
|
|
@ -94,12 +92,6 @@ class TestPasses(TestCase):
|
|||
|
||||
ep = export(M(), (x,), constraints=[dynamic_dim(x, 1) >= 2, dynamic_dim(x, 1) <= 6])
|
||||
|
||||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
|
||||
self.assertEqual(num_assert, 3)
|
||||
self.assertEqual(num_scalar_tensor, 3)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
|
||||
ep(torch.zeros(2, 7, 3))
|
||||
|
||||
|
|
@ -125,12 +117,6 @@ class TestPasses(TestCase):
|
|||
|
||||
ep = export(M(), (x, y), constraints=constraints)
|
||||
|
||||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
|
||||
self.assertEqual(num_assert, 6)
|
||||
self.assertEqual(num_scalar_tensor, 6)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
|
||||
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
|
|
@ -156,13 +142,6 @@ class TestPasses(TestCase):
|
|||
|
||||
ep = export(M(), (x, y), constraints=constraints)
|
||||
|
||||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
|
||||
# there are 3 asserts from y and 2 from dynamic x dims and 1 from static x dim
|
||||
self.assertEqual(num_assert, 6)
|
||||
self.assertEqual(num_scalar_tensor, 6)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
|
||||
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
|
|
@ -194,13 +173,6 @@ class TestPasses(TestCase):
|
|||
|
||||
ep = export(M(), (x, y), constraints=constraints)
|
||||
|
||||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
|
||||
# there are 4 asserts from y and 3 from x
|
||||
self.assertEqual(num_assert, 7)
|
||||
self.assertEqual(num_scalar_tensor, 7)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
|
||||
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
|
|
@ -278,12 +250,6 @@ class TestPasses(TestCase):
|
|||
mod = M()
|
||||
ep = export(mod, (x,))
|
||||
|
||||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
# 1 constraint for shape of x, 2 constraints for b
|
||||
self.assertEqual(num_assert, 3)
|
||||
self.assertEqual(num_scalar_tensor, 3)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense_default is outside of inline constraint \[2, 5\]."):
|
||||
ep(torch.tensor([6]))
|
||||
|
||||
|
|
@ -344,6 +310,8 @@ class TestPasses(TestCase):
|
|||
y = torch.tensor([5])
|
||||
mod = M()
|
||||
ep = export(mod, (torch.tensor(True), x, y))
|
||||
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
|
||||
ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))
|
||||
|
||||
|
|
@ -409,239 +377,6 @@ class TestPasses(TestCase):
|
|||
"torch.ops.aten.sym_constrain_range.default", 0, exactly=True
|
||||
).run(gm.code)
|
||||
|
||||
dep_token_node = next(n for n in gm.graph.nodes if n.name == "dep_token3")
|
||||
constrain_node = next(
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if n.target == torch.ops.aten._functional_sym_constrain_range
|
||||
)
|
||||
self.assertEqual(constrain_node.kwargs["dep_token"], dep_token_node)
|
||||
|
||||
def test_functionalize_input_constraints(self) -> None:
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
inp = torch.zeros(4, 8)
|
||||
ep = torch._export.export(
|
||||
f,
|
||||
(inp,),
|
||||
constraints=[
|
||||
dynamic_dim(inp, 0) < 10,
|
||||
dynamic_dim(inp, 0) >= 3,
|
||||
],
|
||||
)
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten._assert_async.msg", 3, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
gm = ep.transform(_FunctionalizeSideEffectfulOpsPass()).graph_module
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Input arg0_1.shape\[0\] is outside of specified dynamic range \[3, 9\]",
|
||||
):
|
||||
gm(torch.ones(11, 8))
|
||||
|
||||
inp = torch.ones(6, 8)
|
||||
self.assertEqual(gm(inp)[0], f(inp))
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten._functional_assert_async.msg", 3, exactly=True
|
||||
).run(gm.code)
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten._assert_async.msg", 0, exactly=True
|
||||
).run(gm.code)
|
||||
|
||||
def test_functionalization(self) -> None:
|
||||
def f(x, y):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
return x + 4, x + y * 2
|
||||
|
||||
inps = (torch.tensor([5]), torch.zeros((3, 4)))
|
||||
ep = torch._export.export(
|
||||
f,
|
||||
inps,
|
||||
constraints=[dynamic_dim(inps[1], 1) < 6],
|
||||
_functionalize_runtime_assertions=True,
|
||||
)
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten._functional_sym_constrain_range", 1, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
inps = (torch.tensor([7]), torch.ones((3, 5)))
|
||||
self.assertTrue(torch._dynamo.utils.same(ep(*inps), f(*inps)))
|
||||
|
||||
def test_functionalization_with_native_python_assertion(self) -> None:
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
assert x[0] == 3
|
||||
return x.cos() + b
|
||||
|
||||
inp = torch.Tensor([3, 4, 5])
|
||||
ep = torch._export.export(f, (inp,), _functionalize_runtime_assertions=True)
|
||||
|
||||
# Check native assertion has corresponding functional assertion nodes generated.
|
||||
select_int_node = next(
|
||||
n
|
||||
for n in ep.graph_module.graph.nodes
|
||||
if n.target == torch.ops.aten.select.int
|
||||
)
|
||||
equal_scalar_node = select_int_node.next
|
||||
dep_token_node = next(
|
||||
n
|
||||
for n in ep.graph_module.graph.nodes
|
||||
if (
|
||||
n.target == torch.ops.aten._functional_assert_async.msg
|
||||
and n.args[0] == equal_scalar_node
|
||||
)
|
||||
)
|
||||
self.assertIn(
|
||||
"call_function[target=torch.ops.aten._functional_assert_async.msg]"
|
||||
"(args = (%eq_scalar, assertion error), kwargs = {dep_token: %dep_token1}",
|
||||
dep_token_node.format_node(),
|
||||
)
|
||||
|
||||
def test_functionalization_with_mutated_buffer(self) -> None:
|
||||
buf = torch.ones(6, 2)
|
||||
weight = 0.01
|
||||
bias = 0.2
|
||||
d_in = 3
|
||||
d_out = 4
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer("buf", buf)
|
||||
|
||||
self.linear = torch.nn.Linear(d_in, d_out)
|
||||
self.linear.weight.data.fill_(weight)
|
||||
self.linear.bias.data.fill_(bias)
|
||||
|
||||
def forward(self, x):
|
||||
self.buf.add_(5)
|
||||
return self.linear(x).cos() + self.buf.sum()
|
||||
|
||||
inp = torch.ones(4, 3)
|
||||
ep = torch._export.export(
|
||||
Foo(),
|
||||
(inp,),
|
||||
constraints=[dynamic_dim(inp, 0) >= 3],
|
||||
_functionalize_runtime_assertions=True,
|
||||
)
|
||||
|
||||
gs = ep.graph_signature
|
||||
self.assertEqual(
|
||||
gs,
|
||||
ExportGraphSignature(
|
||||
parameters=["L__self___linear.weight", "L__self___linear.bias"],
|
||||
buffers=["L__self___buf"],
|
||||
user_inputs=["arg3_1"],
|
||||
user_outputs=["add_tensor_1"],
|
||||
inputs_to_parameters={
|
||||
"arg0_1": "L__self___linear.weight",
|
||||
"arg1_1": "L__self___linear.bias",
|
||||
},
|
||||
inputs_to_buffers={"arg2_1": "L__self___buf"},
|
||||
buffers_to_mutate={"add_tensor": "L__self___buf"},
|
||||
backward_signature=None,
|
||||
assertion_dep_token={2: "dep_token7"},
|
||||
),
|
||||
)
|
||||
outputs = next(n for n in ep.graph.nodes if n.op == "output").args[0]
|
||||
self.assertEqual(
|
||||
[str(o) for o in outputs],
|
||||
["add_tensor", "add_tensor_1", "dep_token7"],
|
||||
)
|
||||
self.assertEqual(
|
||||
len(outputs), len(gs.buffers_to_mutate) + len(gs.user_outputs) + 1,
|
||||
)
|
||||
inp = torch.randn(5, 3)
|
||||
self.assertTrue(
|
||||
torch._dynamo.utils.same(
|
||||
# Directly check run output of `ep.graph_module` which is
|
||||
# functionalized.
|
||||
ep.graph_module(
|
||||
torch.full((d_out, d_in), weight),
|
||||
torch.full((d_out,), bias),
|
||||
buf.clone(),
|
||||
inp,
|
||||
),
|
||||
(buf.add(5), Foo()(inp), torch.empty(0)),
|
||||
)
|
||||
)
|
||||
self.assertTrue(torch._dynamo.utils.same(ep(inp), Foo()(inp)))
|
||||
|
||||
def test_graph_partition_after_assertion_functionalization(self) -> None:
|
||||
def f1(a, b):
|
||||
add = a + b
|
||||
add_1 = add + b
|
||||
add_2 = add_1 + add
|
||||
|
||||
relu_1 = add_2.relu() # blocked by this
|
||||
|
||||
add_3 = add_2 + relu_1
|
||||
add_4 = add_2 + add_3
|
||||
return add_4, add_2
|
||||
|
||||
partitioner1 = CapabilityBasedPartitioner(
|
||||
graph_module=symbolic_trace(f1),
|
||||
operator_support=_AddOperatorSupport(),
|
||||
)
|
||||
partitions1 = partitioner1.propose_partitions()
|
||||
|
||||
self.assertEqual(
|
||||
_to_partition_names(partitions1),
|
||||
[{"add_3", "add_4"}, {"add", "add_1", "add_2"}],
|
||||
)
|
||||
|
||||
def f2(a, b):
|
||||
add = a + b
|
||||
add_1 = add + b
|
||||
add_2 = add_1 + add
|
||||
|
||||
assert add_1[0] == 5
|
||||
|
||||
relu_1 = add_2.relu() # blocked by this
|
||||
|
||||
add_3 = add_2 + relu_1
|
||||
add_4 = add_2 + add_3
|
||||
return add_4, add_2
|
||||
|
||||
inps = (torch.tensor([1, 3, 2]), torch.tensor([2, 3, 4]))
|
||||
gm = export(
|
||||
f2,
|
||||
inps,
|
||||
constraints=[dynamic_dim(inps[0], 0) == dynamic_dim(inps[1], 0)],
|
||||
_functionalize_runtime_assertions=True,
|
||||
).graph_module
|
||||
partitioner2 = CapabilityBasedPartitioner(
|
||||
graph_module=gm,
|
||||
operator_support=_AtenAddOperatorSupport(),
|
||||
)
|
||||
partitions2 = partitioner2.propose_partitions()
|
||||
|
||||
self.assertEqual(
|
||||
_to_partition_names(partitions2),
|
||||
[
|
||||
{"add_tensor_3", "add_tensor_4"},
|
||||
{"add_tensor_1", "add_tensor_2", "add_tensor"},
|
||||
]
|
||||
)
|
||||
|
||||
fused_gm1 = partitioner1.fuse_partitions(partitions1)
|
||||
fused_gm2 = partitioner2.fuse_partitions(partitions2)
|
||||
|
||||
inps = (torch.tensor([1, 4, 6]), torch.tensor([2, 4, 6]))
|
||||
self.assertTrue(
|
||||
torch._dynamo.utils.same(fused_gm1(*inps)[0], fused_gm2(*inps)[0]),
|
||||
)
|
||||
|
||||
# Sub-module `fused_1` is for logic `add = ..., ..., add_2 = ...`
|
||||
output_names1 = _get_output_names(fused_gm1.get_submodule("fused_1"))
|
||||
output_names2 = _get_output_names(fused_gm2.get_submodule("fused_1"))
|
||||
|
||||
self.assertEqual(output_names1, ["add_2"])
|
||||
# The extra output `add_tensor_1` is consumed by assertion.
|
||||
self.assertEqual(output_names2, ["add_tensor_1", "add_tensor_2"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -199,39 +199,34 @@ class TestDeserialize(TestCase):
|
|||
|
||||
self.assertEqual(len(ep.graph.nodes), len(deserialized_ep.graph.nodes))
|
||||
for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
|
||||
# Check "val" metadata
|
||||
val1 = node1.meta.get("val", None)
|
||||
val2 = node2.meta.get("val", None)
|
||||
self.assertEqual(node1.op, node2.op)
|
||||
if node1.op == "call_function":
|
||||
# Check "val" metadata
|
||||
val1 = node1.meta.get("val", None)
|
||||
val2 = node2.meta.get("val", None)
|
||||
if val1 is None or val2 is None:
|
||||
# Either both are None
|
||||
self.assertEqual(val1, val2)
|
||||
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
|
||||
# Or both are fake tensors with the same shape/dtype
|
||||
self.assertEqual(len(val1.shape), len(val2.shape))
|
||||
for s1, s2 in zip(val1.shape, val2.shape):
|
||||
if is_concrete_int(s1) and is_concrete_int(s2):
|
||||
self.assertEqual(s1, s2)
|
||||
else:
|
||||
self.assertEqual(str(s1), str(s2))
|
||||
self.assertEqual(val1.dtype, val2.dtype)
|
||||
elif isinstance(val1, list) and isinstance(val2, list):
|
||||
# Or both are fake tensors lists with one element and with the
|
||||
# same shape/dtype
|
||||
self.assertTrue(len(val1) == 1 and len(val2) == 1)
|
||||
self.assertEqual(val1[0].shape, val2[0].shape)
|
||||
self.assertEqual(val1[0].dtype, val2[0].dtype)
|
||||
else:
|
||||
# For expressions like 's0 < 10' can only compare through string
|
||||
self.assertEqual(str(val1), str(val2))
|
||||
|
||||
if val1 is None or val2 is None:
|
||||
# Either both are None
|
||||
self.assertEqual(val1, val2)
|
||||
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
|
||||
# Or both are fake tensors with the same shape/dtype
|
||||
self.assertEqual(len(val1.shape), len(val2.shape))
|
||||
for s1, s2 in zip(val1.shape, val2.shape):
|
||||
if is_concrete_int(s1) and is_concrete_int(s2):
|
||||
self.assertEqual(s1, s2)
|
||||
else:
|
||||
self.assertEqual(str(s1), str(s2))
|
||||
self.assertEqual(val1.dtype, val2.dtype)
|
||||
elif isinstance(val1, list) and isinstance(val2, list):
|
||||
# Or both are fake tensors lists with one element and with the
|
||||
# same shape/dtype
|
||||
self.assertTrue(len(val1) == 1 and len(val2) == 1)
|
||||
self.assertEqual(val1[0].shape, val2[0].shape)
|
||||
self.assertEqual(val1[0].dtype, val2[0].dtype)
|
||||
else:
|
||||
# For expressions like 's0 < 10' can only compare through string
|
||||
self.assertEqual(str(val1), str(val2))
|
||||
|
||||
# Check "stack_trace" metadata
|
||||
if "None" in node1.meta.get("stack_trace"):
|
||||
self.assertTrue(
|
||||
node2.meta.get("stack_trace") is None
|
||||
or "None" in node2.meta.get("stack_trace")
|
||||
)
|
||||
else:
|
||||
# Check "stack_trace" metadata
|
||||
self.assertEqual(
|
||||
node1.meta.get("stack_trace", None),
|
||||
node2.meta.get("stack_trace", None),
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Opti
|
|||
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
|
||||
|
||||
inputs = (torch.ones([2, 3]) * 4, 2.)
|
||||
ep = export(fn, inputs, {}, [], _add_runtime_assertions=False)
|
||||
ep = export(fn, inputs, {}, [])
|
||||
compiler_opset_version = {"aten": 4}
|
||||
model_opset_version = {"aten": 3}
|
||||
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ class TestOutDtypeOp(TestCase):
|
|||
ep = torch._export.export(
|
||||
m,
|
||||
(x,),
|
||||
_add_runtime_assertions=False,
|
||||
)
|
||||
FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
|
||||
self.assertTrue(torch.allclose(m(x), ep(x)))
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ from .exported_program import (
|
|||
ExportGraphSignature,
|
||||
)
|
||||
from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
|
||||
|
||||
from .passes.add_runtime_assertions_for_constraints_pass import _AddRuntimeAssertionsForInlineConstraintsPass
|
||||
|
||||
# Note - [On Export Dynamic Dimension UX]
|
||||
#
|
||||
|
|
@ -127,9 +127,6 @@ def export(
|
|||
args: Tuple[Any],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
constraints: Optional[List[Constraint]] = None,
|
||||
*,
|
||||
_add_runtime_assertions=True,
|
||||
_functionalize_runtime_assertions=False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
Traces either an nn.Module's forward function or just a callable with PyTorch
|
||||
|
|
@ -324,11 +321,9 @@ def export(
|
|||
equality_constraints,
|
||||
)
|
||||
|
||||
if _add_runtime_assertions:
|
||||
exported_program = exported_program._add_runtime_assertions(
|
||||
functionalize=_functionalize_runtime_assertions,
|
||||
)
|
||||
|
||||
exported_program = exported_program.transform(
|
||||
_AddRuntimeAssertionsForInlineConstraintsPass(range_constraints, equality_constraints)
|
||||
)
|
||||
return exported_program.transform(_ReplaceSymSizeOpPass())
|
||||
|
||||
except (ConstraintViolationError, ValueRangeError) as e:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,19 @@
|
|||
from collections import defaultdict
|
||||
import copy
|
||||
import dataclasses
|
||||
import sympy
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
import torch.fx._pytree as fx_pytree
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx.experimental.symbolic_shapes import SymInt
|
||||
from torch import fx
|
||||
from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.fx.experimental.symbolic_shapes import SymInt
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
|
||||
from . import error
|
||||
from .pass_base import PassType
|
||||
from .passes.add_runtime_assertions_for_constraints_pass import (
|
||||
|
|
@ -122,7 +125,8 @@ class ExportedProgram:
|
|||
f"{received_spec}"
|
||||
)
|
||||
|
||||
param_buffer_values = (value for _, value in self.state_dict.items())
|
||||
param_buffer_values = tuple(value for _, value in self.state_dict.items())
|
||||
self._check_input_constraints(*param_buffer_values, *args)
|
||||
|
||||
with torch.no_grad():
|
||||
res = torch.fx.Interpreter(self.graph_module).run(
|
||||
|
|
@ -208,6 +212,24 @@ class ExportedProgram:
|
|||
transformed_ep.graph_module.meta.update(res.graph_module.meta)
|
||||
return transformed_ep
|
||||
|
||||
def _check_input_constraints(self, *args):
|
||||
# TODO(zhxchen17) Remove _add_runtime_assertions.
|
||||
# TODO(zhxchen17) Don't generate a runtime graph on the fly.
|
||||
_assertion_graph = fx.GraphModule({}, fx.Graph())
|
||||
for p in self.graph.nodes:
|
||||
if p.op != "placeholder":
|
||||
continue
|
||||
new_p = _assertion_graph.graph.placeholder(p.name)
|
||||
new_p.meta = p.meta
|
||||
_assertion_graph.graph.output(())
|
||||
_assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
|
||||
self.range_constraints,
|
||||
self.equality_constraints,
|
||||
)(_assertion_graph)
|
||||
assert _assertion_graph_res is not None
|
||||
_assertion_graph = _assertion_graph_res.graph_module
|
||||
_assertion_graph(*args)
|
||||
|
||||
def _add_runtime_assertions(
|
||||
self,
|
||||
functionalize: bool,
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def _convert_range_to_int(range: RangeConstraint):
|
|||
return min_val, max_val
|
||||
|
||||
|
||||
class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
|
||||
class _AddRuntimeAssertionsForInlineConstraintsPass(ExportPassBase):
|
||||
def __init__(
|
||||
self,
|
||||
range_constraints: Dict[sympy.Symbol, RangeConstraint],
|
||||
|
|
@ -60,6 +60,101 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
|
|||
self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
|
||||
self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
|
||||
|
||||
def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
|
||||
if lower > -math.inf:
|
||||
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
|
||||
|
||||
if upper < math.inf:
|
||||
self._insert_assert_async(operator.le, proxy, upper, assert_msg)
|
||||
|
||||
def _insert_assert_async(self, operator, lower, upper, assert_msg):
|
||||
"""
|
||||
Inserts assert_async call_function nodes in the graph. This function is
|
||||
called **during** the interpreter-based pass.
|
||||
"""
|
||||
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
|
||||
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
|
||||
super().call_operator(
|
||||
torch.ops.aten._assert_async.msg,
|
||||
(cmp_tensor, assert_msg),
|
||||
{},
|
||||
self._create_dummy_node_metadata(),
|
||||
)
|
||||
|
||||
def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
|
||||
ret = super().call_operator(op, args, kwargs, meta)
|
||||
if "val" not in meta:
|
||||
return ret
|
||||
|
||||
val = meta["val"]
|
||||
|
||||
# In general, we may have to deal the case such as: ret[1].shape[0].
|
||||
# We need first find out what symbols require assertion, then we need to follow the path
|
||||
# from ret to the symbol, construct the proxies along the way and construct the messages
|
||||
# piece-wise at the same time.
|
||||
#
|
||||
# We use post-order traversal to collect all the proxies callbacks needed, construct
|
||||
# the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
|
||||
# We need the callbacks because, in order to call the function to create a proxy for shape[0], we
|
||||
# need the proxy for shape, which further requries the proxy for ret[1], etc.
|
||||
def add_assertions(val):
|
||||
call_backs = []
|
||||
messages = []
|
||||
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
||||
symbol = val.node._expr
|
||||
if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"):
|
||||
# We only care about unbacked symints for these inline
|
||||
# constraints, which are prefixed with 'i'
|
||||
constraint = self.range_constraints[symbol]
|
||||
min_val, max_val = _convert_range_to_int(constraint)
|
||||
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
|
||||
call_backs.append(
|
||||
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
|
||||
)
|
||||
messages.append(assert_msg)
|
||||
elif isinstance(val, torch.Tensor):
|
||||
for i, sym in enumerate(val.shape):
|
||||
cbs, msgs = add_assertions(sym)
|
||||
for cb, msg in zip(cbs, msgs):
|
||||
def sym_size_cb(proxy, assert_msg, dim):
|
||||
dim_proxy = super(
|
||||
_AddRuntimeAssertionsForInlineConstraintsPass,
|
||||
self
|
||||
).call_operator(
|
||||
torch.ops.aten.sym_size.int,
|
||||
(proxy, dim),
|
||||
{},
|
||||
self._create_dummy_node_metadata(),
|
||||
)
|
||||
cb(proxy=dim_proxy, assert_msg=assert_msg)
|
||||
call_backs.append(partial(sym_size_cb, dim=i))
|
||||
messages.append(f".shape[{i}]" + msg)
|
||||
return call_backs, messages
|
||||
callbacks, messages = add_assertions(val)
|
||||
for cb, msg in zip(callbacks, messages):
|
||||
cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
|
||||
return ret
|
||||
|
||||
def call(self, graph_module):
|
||||
# Add runtime asserts for inline constraints
|
||||
val = super().call(graph_module)
|
||||
|
||||
# Populate the stack trace with dummy vals to respect IR
|
||||
for node in val.graph_module.graph.nodes:
|
||||
if not hasattr(node.meta, "stack_trace"):
|
||||
node.meta["stack_trace"] = traceback.format_exc(-1)
|
||||
|
||||
return PassResult(val.graph_module, val.modified)
|
||||
|
||||
|
||||
class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass):
|
||||
def __init__(
|
||||
self,
|
||||
range_constraints: Dict[sympy.Symbol, RangeConstraint],
|
||||
equality_constraints: List[Tuple[InputDim, InputDim]],
|
||||
):
|
||||
super().__init__(range_constraints, equality_constraints)
|
||||
|
||||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
||||
graph_module = copy.deepcopy(graph_module)
|
||||
graph = graph_module.graph
|
||||
|
|
@ -98,11 +193,10 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
|
|||
if isinstance(shape, SymInt):
|
||||
# If the shape is dynamic, add range assertions
|
||||
symbol = shape.node._expr
|
||||
assert symbol in self.range_constraints
|
||||
|
||||
self._insert_range_assert_inplace(
|
||||
graph, input_dim, dim_node, self.range_constraints[symbol]
|
||||
)
|
||||
if symbol in self.range_constraints:
|
||||
self._insert_range_assert_inplace(
|
||||
graph, input_dim, dim_node, self.range_constraints[symbol]
|
||||
)
|
||||
else:
|
||||
# If no dynamism is specified, we assume all dimensions #
|
||||
# are specialized
|
||||
|
|
@ -112,20 +206,13 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
|
|||
)
|
||||
|
||||
# Add runtime assertions on equality constraints on the inputs
|
||||
with graph.inserting_after(
|
||||
list(inputdim_to_node.values())[-1]
|
||||
):
|
||||
self._insert_equality_assert_inplace(graph, inputdim_to_node)
|
||||
if len(inputdim_to_node) > 0:
|
||||
with graph.inserting_after(
|
||||
list(inputdim_to_node.values())[-1]
|
||||
):
|
||||
self._insert_equality_assert_inplace(graph, inputdim_to_node)
|
||||
|
||||
# Add runtime asserts for inline constraints
|
||||
val = super().call(graph_module)
|
||||
|
||||
# Populate the stack trace with dummy vals to respect IR
|
||||
for node in val.graph_module.graph.nodes:
|
||||
if not hasattr(node.meta, "stack_trace"):
|
||||
node.meta["stack_trace"] = traceback.format_exc(-1)
|
||||
|
||||
return PassResult(val.graph_module, val.modified)
|
||||
return super().call(graph_module)
|
||||
|
||||
def _insert_specialized_shape_assert_inplace(
|
||||
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int,
|
||||
|
|
@ -202,75 +289,3 @@ class _AddRuntimeAssertionsForConstraintsPass(ExportPassBase):
|
|||
_ = graph.call_function(
|
||||
torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg)
|
||||
)
|
||||
|
||||
def call_operator(self, op, args, kwargs, meta) -> ProxyValue:
|
||||
ret = super().call_operator(op, args, kwargs, meta)
|
||||
if "val" not in meta:
|
||||
return ret
|
||||
|
||||
val = meta["val"]
|
||||
|
||||
# In general, we may have to deal the case such as: ret[1].shape[0].
|
||||
# We need first find out what symbols require assertion, then we need to follow the path
|
||||
# from ret to the symbol, construct the proxies along the way and construct the messages
|
||||
# piece-wise at the same time.
|
||||
#
|
||||
# We use post-order traversal to collect all the proxies callbacks needed, construct
|
||||
# the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
|
||||
# We need the callbacks because, in order to call the function to create a proxy for shape[0], we
|
||||
# need the proxy for shape, which further requries the proxy for ret[1], etc.
|
||||
def add_assertions(val):
|
||||
call_backs = []
|
||||
messages = []
|
||||
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
||||
symbol = val.node._expr
|
||||
if isinstance(symbol, sympy.Symbol) and symbol.name.startswith("i"):
|
||||
# We only care about unbacked symints for these inline
|
||||
# constraints, which are prefixed with 'i'
|
||||
constraint = self.range_constraints[symbol]
|
||||
min_val, max_val = _convert_range_to_int(constraint)
|
||||
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
|
||||
call_backs.append(
|
||||
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
|
||||
)
|
||||
messages.append(assert_msg)
|
||||
elif isinstance(val, torch.Tensor):
|
||||
for i, sym in enumerate(val.shape):
|
||||
cbs, msgs = add_assertions(sym)
|
||||
for cb, msg in zip(cbs, msgs):
|
||||
def sym_size_cb(proxy, assert_msg, dim):
|
||||
dim_proxy = super(_AddRuntimeAssertionsForConstraintsPass, self).call_operator(
|
||||
torch.ops.aten.sym_size.int,
|
||||
(proxy, dim),
|
||||
{},
|
||||
self._create_dummy_node_metadata(),
|
||||
)
|
||||
cb(proxy=dim_proxy, assert_msg=assert_msg)
|
||||
call_backs.append(partial(sym_size_cb, dim=i))
|
||||
messages.append(f".shape[{i}]" + msg)
|
||||
return call_backs, messages
|
||||
callbacks, messages = add_assertions(val)
|
||||
for cb, msg in zip(callbacks, messages):
|
||||
cb(proxy=ret, assert_msg=f"{ret.node}" + msg)
|
||||
return ret
|
||||
|
||||
def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
|
||||
if lower > -math.inf:
|
||||
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
|
||||
|
||||
if upper < math.inf:
|
||||
self._insert_assert_async(operator.le, proxy, upper, assert_msg)
|
||||
|
||||
def _insert_assert_async(self, operator, lower, upper, assert_msg):
|
||||
"""
|
||||
Inserts assert_async call_function nodes in the graph. This function is
|
||||
called **during** the interpreter-based pass.
|
||||
"""
|
||||
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
|
||||
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
|
||||
super().call_operator(
|
||||
torch.ops.aten._assert_async.msg,
|
||||
(cmp_tensor, assert_msg),
|
||||
{},
|
||||
self._create_dummy_node_metadata(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ class GraphModuleOpUpgrader:
|
|||
upgraded_program = exported_program.transform(_pass)
|
||||
# NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn of
|
||||
# _add_runtime_assertions because dynamo is not happy with sym_size.int.
|
||||
exported_program = export(upgraded_program.graph_module, inputs, {}, _add_runtime_assertions=False)
|
||||
exported_program = export(upgraded_program.graph_module, inputs, {})
|
||||
exported_program.call_spec = upgraded_program.call_spec
|
||||
|
||||
return exported_program
|
||||
|
|
|
|||
Loading…
Reference in a new issue