[export][reland] Remove runtime assertion pass (#115597)

Summary:
Reland of https://github.com/pytorch/pytorch/pull/115196
D52054112 to fix internal failures.

Test Plan: CI

Differential Revision: D52054110

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115597
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
This commit is contained in:
Angela Yi 2023-12-15 03:21:59 +00:00 committed by PyTorch MergeBot
parent 7d4ccd7b9e
commit 8e2d63cbc3
7 changed files with 107 additions and 218 deletions

View file

@ -2399,7 +2399,7 @@ def forward(self, x):
example_inputs = (copy(x), y)
ep = torch._export._export(foo, example_inputs, constraints=constraints)
with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"):
with self.assertRaisesRegex(RuntimeError, "input.*shape.*to be equal to 2"):
ep(torch.randn(3), y)
dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")

View file

@ -245,7 +245,7 @@ class TestExport(TestCase):
torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
em = torch.export.export(m, (a,))
x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "shape\[1\] to be equal to 4, but got 5"):
em(x)
def test_not_correct_dim(self):
@ -1206,13 +1206,13 @@ class TestExport(TestCase):
torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
)
with self.assertRaisesRegex(
RuntimeError, "is specialized to be 5 at tracing time"
RuntimeError, "Expected input arg1 to be equal to 5, but got 6"
):
_ = exported(torch.ones(8, 5), 6)
exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
with self.assertRaisesRegex(
RuntimeError, "is specialized to be 5.0 at tracing time"
RuntimeError, "Expected input arg1 to be equal to 5.0, but got 6.0"
):
_ = exported(torch.ones(7, 5), 6.0)
@ -1225,7 +1225,7 @@ class TestExport(TestCase):
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = export(g, inps)
with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"):
with self.assertRaisesRegex(RuntimeError, "to be equal to trunc, but got floor"):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))
@ -1306,7 +1306,7 @@ class TestExport(TestCase):
dim0_x = torch.export.Dim("dim0_x")
exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})
reexported = torch.export.export(exported, (inp,))
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 5, but got 7"):
reexported(torch.ones(7, 5))
reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},))
@ -1315,7 +1315,7 @@ class TestExport(TestCase):
# can't retrace with invalid inputs with respect to the original ExportedProgram
dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}})
with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be >= 3, but got 2"):
torch.export.export(exported_v2, (torch.randn(2, 2),))
@testing.expectedFailureSerDer
@ -1453,7 +1453,7 @@ class TestExport(TestCase):
self.assertEqual(len(ep.state_dict), 1)
self.assertEqual(len(ep.tensor_constants), 2)
inp = (torch.randn(1),)
inp = (torch.tensor(5),)
self.assertTrue(torch.allclose(ep(*inp), Foo()(*inp)))
transform = ep.run_decompositions()
@ -1620,7 +1620,7 @@ def forward(self, l_x_):
self.assertEqual(ep(*test_inp), foo(*test_inp))
ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 4, but got 7"):
ep_v2(*test_inp)
def test_constant_output(self):
@ -1693,8 +1693,7 @@ def forward(self, l_x_):
test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
with self.assertRaisesRegex(
RuntimeError,
"shape\[0\] is outside of specified dynamic range \[3, inf\]"
RuntimeError, "shape\[0\] to be >= 3, but got 2"
):
ep(*test_inp)
@ -1724,10 +1723,10 @@ def forward(self, l_x_):
inp = torch.randn(4, 4)
gm = capture_pre_autograd_graph(Foo(), (inp,), constraints=[dynamic_dim(inp, 0) >= 3])
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
gm(torch.randn(2, 2))
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
torch.export.export(gm, (torch.randn(2, 2),))
ep = torch.export.export(gm, (torch.randn(5, 4),), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},))

View file

@ -76,7 +76,7 @@ class TestPasses(TestCase):
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(2, 7, 3))
self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@ -99,10 +99,10 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
)
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_y_.shape\[0\] to be >= 3, but got 2"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
def test_runtime_assert_some_dims_not_specified(self) -> None:
@ -123,12 +123,12 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
)
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"shape\[0\] is specialized at 5"
RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
@ -152,12 +152,12 @@ class TestPasses(TestCase):
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})
with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"):
with self.assertRaisesRegex(RuntimeError, r"shape\[1\] to be equal to 2"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"shape\[0\] is specialized at 5"
RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

View file

@ -283,11 +283,11 @@ class TestUnflatten(TestCase):
return a
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
export_module(torch.randn(6, 6))
unflattened = export_module.module(flat=False)
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
unflattened(torch.randn(6, 6))
def test_unflatten_with_inplace_compile(self):

View file

@ -1,22 +1,18 @@
import copy
import math
import operator
import traceback
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Set, Tuple
from typing import Callable, Dict, List, NamedTuple, Set, Tuple
import sympy
import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import SymInt
from torch._export.pass_base import _ExportPassBase, ProxyValue, PassResult
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._sympy.value_ranges import ValueRanges
__all__ = ["_AddRuntimeAssertionsForConstraintsPass", "InputDim"]
__all__ = ["InputDim"]
class InputDim(NamedTuple):
@ -150,163 +146,3 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBase):
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
return PassResult(val.graph_module, val.modified)
class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass):
def __init__(
self,
range_constraints: Dict[sympy.Symbol, ValueRanges],
equality_constraints: List[Tuple[InputDim, InputDim]],
):
super().__init__(range_constraints, equality_constraints)
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module = copy.deepcopy(graph_module)
graph = graph_module.graph
insert_loc = None
for node in graph.nodes:
if node.op != "placeholder":
continue
insert_loc = node
if insert_loc is None:
return super().call(graph_module)
# Add runtime asserts for input shape constraints. We do this after all
# placeholder nodes so that we can handle both (unary) predicates and
# (binary) relations.
inputdim_to_node: Dict[InputDim, torch.fx.Node] = OrderedDict()
for node in graph.nodes:
if node.op != "placeholder":
continue
if (
"val" not in node.meta or node.meta["val"] is None
):
continue
if not isinstance(node.meta["val"], FakeTensor):
# it has to be a prim value
self._insert_prim_assert_inplace(graph, node, node.meta["val"])
else:
fake_tensor_shape = node.meta["val"].shape
for dim, shape in enumerate(fake_tensor_shape):
with graph.inserting_after(insert_loc):
dim_node = graph.call_function(
torch.ops.aten.sym_size.int, (node, dim)
)
input_dim = InputDim(node.name, dim)
inputdim_to_node[input_dim] = dim_node
insert_loc = dim_node
if isinstance(shape, SymInt):
# If the shape is dynamic, add range assertions
symbol = shape.node._expr
if symbol in self.range_constraints:
self._insert_range_assert_inplace(
graph, input_dim, dim_node, self.range_constraints[symbol]
)
else:
# If no dynamism is specified, we assume all dimensions #
# are specialized
assert isinstance(shape, int)
self._insert_specialized_shape_assert_inplace(
graph, input_dim, dim_node, shape,
)
# Add runtime assertions on equality constraints on the inputs
if len(inputdim_to_node) > 0:
with graph.inserting_after(
list(inputdim_to_node.values())[-1]
):
self._insert_equality_assert_inplace(graph, inputdim_to_node)
return super().call(graph_module)
def _insert_specialized_shape_assert_inplace(
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int,
):
assert_msg = f"Input {input_dim.input_name}.shape[{input_dim.dim}] is specialized at {shape}"
with graph.inserting_after(dim_node):
eq_node = graph.call_function(operator.eq, (dim_node, shape))
with graph.inserting_after(eq_node):
tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
with graph.inserting_after(tensor_eq_node):
_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))
def _insert_prim_assert_inplace(self, graph, node: torch.fx.Node, value: Any):
assert_msg = (
f"Input {node.name} is specialized to be {value} at tracing time,"
f"it is not supported to pass in a different value at run time."
)
with graph.inserting_after(node):
eq_node = graph.call_function(operator.eq, (node, value))
with graph.inserting_after(eq_node):
tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
with graph.inserting_after(tensor_eq_node):
_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))
def _insert_range_assert_inplace(
self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, range: ValueRanges
):
"""
Add runtime asserts for user-specified range constraints for
each placeholder's dynamic dimension.
"""
min_val, max_val = _convert_range_to_int(range)
assert_msg = (
f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
f"outside of specified dynamic range [{min_val}, {max_val}]"
)
# TODO (tmanlaibaatar) we are making an assumption that graph generated for
# input dim N >=2 generalizes to N < 2. Ideally we should check that:
# 1. if we can generalize to N < 2, not add any assertion saying N >= 2
# 2. If we can't generalize to N < 2, add an assertion saying N >= 2
# Above can be achieved via a separate pass.
with graph.inserting_after(dim_node):
if min_val > 2:
self._insert_assert_async_inplace(
graph, operator.ge, (dim_node, min_val), assert_msg,
)
if max_val < math.inf:
self._insert_assert_async_inplace(
graph, operator.le, (dim_node, max_val), assert_msg,
)
def _insert_equality_assert_inplace(
self,
graph: torch.fx.Graph,
inputdim_to_node: Dict[InputDim, torch.fx.Node],
):
for input_dim, other_input_dim in self.equality_constraints:
dim_node = inputdim_to_node[input_dim]
assert_msg = (
f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
f"not equal to input {other_input_dim.input_name}.shape[{other_input_dim.dim}]"
)
other_dim_node = inputdim_to_node[other_input_dim]
self._insert_assert_async_inplace(
graph,
operator.eq,
(dim_node, other_dim_node),
assert_msg
)
def _insert_assert_async_inplace(self, graph, operator, args, assert_msg):
"""
Inserts assert_async call_function nodes in the graph. This function is
called before we run the interpreter-based pass and does an inplace
insertion.
"""
cmp_node = graph.call_function(operator, args)
with graph.inserting_after(cmp_node):
cmp_tensor_node = graph.call_function(
torch.ops.aten.scalar_tensor.default, (cmp_node,)
)
with graph.inserting_after(cmp_tensor_node):
_ = graph.call_function(
torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg)
)

View file

@ -1,11 +1,11 @@
import dataclasses
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
import torch
from torch._export import ExportedProgram
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import (
_register_pytree_node,
Context,
@ -25,37 +25,87 @@ SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
def _check_input_constraints_pre_hook(self, *args, **kwargs):
flat_args, _ = tree_flatten(args)
return _check_input_constraints_for_graph(
self.graph,
range_constraints=self.range_constraints,
equality_constraints=self.equality_constraints,
)(*flat_args)
[node for node in self.graph.nodes if node.op == "placeholder"],
flat_args,
self.range_constraints,
)
def _check_input_constraints_for_graph(
graph: torch.fx.Graph, range_constraints, equality_constraints
input_placeholders: List[torch.fx.Node], args, range_constraints
):
def check(cond, msg):
if not cond:
# TODO(avik): maybe add more context, e.g., graph signature
raise RuntimeError(msg)
import sympy
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_AddRuntimeAssertionsForConstraintsPass,
_convert_range_to_int,
)
def inner(*args):
# TODO(zhxchen17) Don't generate a runtime graph on the fly.
_assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
for p in graph.nodes:
if p.op != "placeholder":
continue
new_p = _assertion_graph.graph.placeholder(p.name)
new_p.meta = p.meta
_assertion_graph.graph.output(())
_assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
range_constraints,
equality_constraints,
)(_assertion_graph)
assert _assertion_graph_res is not None
_assertion_graph = _assertion_graph_res.graph_module
_assertion_graph(*args)
check(
len(args) == len(input_placeholders),
"Unexpected number of inputs "
f"(expected {len(input_placeholders)}, got {len(args)})",
)
# NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints.
unification_map: "Dict[sympy.Symbol, Any]" = {}
for arg, node in zip(args, input_placeholders):
node_val = node.meta["val"]
if isinstance(node_val, FakeTensor):
check(
isinstance(arg, torch.Tensor),
f"Expected input {node.name} to be a tensor, but got {type(arg)}",
)
check(
len(node_val.shape) == len(arg.shape),
f"Unexpected number of dimensions in input {node.name}.shape "
f"(expected {node_val.shape}, got {arg.shape})",
)
for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
if isinstance(node_dim, torch.SymInt):
if node_dim.node.expr in unification_map:
existing_dim = unification_map[node_dim.node.expr]
check(
arg_dim == existing_dim,
f"Expected input {node.name}.shape[{j}] to be equal to "
f"{existing_dim}, but got {arg_dim}",
)
else:
unification_map[node_dim.node.expr] = arg_dim
return inner
if node_dim.node.expr in range_constraints:
min_val, max_val = _convert_range_to_int(
range_constraints[node_dim.node.expr]
)
# NOTE: we allow dimensions to be 0/1 at runtime
if min_val > 2:
check(
arg_dim >= min_val,
f"Expected input {node.name}.shape[{j}] to be >= "
f"{min_val}, but got {arg_dim}",
)
if max_val < math.inf:
check(
arg_dim <= max_val,
f"Expected input {node.name}.shape[{j}] to be <= "
f"{max_val}, but got {arg_dim}",
)
else:
check(
arg_dim == node_dim,
f"Expected input {node.name}.shape[{j}] to be equal to "
f"{node_dim}, but got {arg_dim}",
)
elif isinstance(node_val, (int, float, str)):
check(
type(arg) == type(node_val) and arg == node_val,
f"Expected input {node.name} to be equal to {node_val}, but got {arg}",
)
def register_dataclass_as_pytree_node(

View file

@ -272,9 +272,7 @@ class ExportedProgram:
)
else:
ordered_tensor_constants = ()
self._check_input_constraints(
*ordered_params, *ordered_buffers, *ordered_tensor_constants, *args
)
self._check_input_constraints(*args)
# NOTE: calling convention is first params, then buffers, then args as user supplied them.
# See: torch/_functorch/aot_autograd.py#L1034
@ -567,9 +565,15 @@ class ExportedProgram:
def _check_input_constraints(self, *args):
from torch._export.utils import _check_input_constraints_for_graph
placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
input_placeholders = [
p
for p, s in zip(placeholders, self.graph_signature.input_specs)
if s.kind == InputKind.USER_INPUT
]
_check_input_constraints_for_graph(
self.graph, self.range_constraints, self.equality_constraints
)(*args)
input_placeholders, args, self.range_constraints
)
def _validate(self):
self.verifier().check(self)