[export][reland] Remove runtime assertion pass (#115597)
Summary: Reland of https://github.com/pytorch/pytorch/pull/115196 D52054112 to fix internal failures.

Test Plan: CI

Differential Revision: D52054110

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115597
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
Parent: 7d4ccd7b9e
Commit: 8e2d63cbc3
7 changed files with 107 additions and 218 deletions
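The user-visible effect of this change shows up in the error messages asserted throughout the test diffs below: input-constraint violations are now reported by a plain Python check run from a pre-hook, phrased as "Expected input ... to be ..., but got ...", instead of the "... is specialized at ..." / "... is outside of specified dynamic range ..." messages produced by the removed graph-assertion pass. A minimal sketch of that behavior, following the convention used in these tests of calling the ExportedProgram directly (the module, shapes, and exact placeholder name in the message are illustrative):

    import torch

    class Foo(torch.nn.Module):
        def forward(self, x):
            return x.sum()

    # Export with a fully static input; every dimension is specialized to its traced size.
    ep = torch.export.export(Foo(), (torch.randn(4, 4),))

    try:
        ep(torch.randn(7, 4))  # dim 0 no longer matches the traced size of 4
    except RuntimeError as e:
        # e.g. "Expected input ... .shape[0] to be equal to 4, but got 7"
        print(e)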
@@ -2399,7 +2399,7 @@ def forward(self, x):

example_inputs = (copy(x), y)
ep = torch._export._export(foo, example_inputs, constraints=constraints)
-with self.assertRaisesRegex(RuntimeError, "Input.*shape.*specialized at 2"):
+with self.assertRaisesRegex(RuntimeError, "input.*shape.*to be equal to 2"):
ep(torch.randn(3), y)

dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
@@ -245,7 +245,7 @@ class TestExport(TestCase):
torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
em = torch.export.export(m, (a,))
x = torch.randn(3, 5)
-with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
+with self.assertRaisesRegex(RuntimeError, "shape\[1\] to be equal to 4, but got 5"):
em(x)

def test_not_correct_dim(self):
@@ -1206,13 +1206,13 @@ class TestExport(TestCase):
torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
)
with self.assertRaisesRegex(
-RuntimeError, "is specialized to be 5 at tracing time"
+RuntimeError, "Expected input arg1 to be equal to 5, but got 6"
):
_ = exported(torch.ones(8, 5), 6)

exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
with self.assertRaisesRegex(
-RuntimeError, "is specialized to be 5.0 at tracing time"
+RuntimeError, "Expected input arg1 to be equal to 5.0, but got 6.0"
):
_ = exported(torch.ones(7, 5), 6.0)
@@ -1225,7 +1225,7 @@ class TestExport(TestCase):

inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = export(g, inps)
-with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"):
+with self.assertRaisesRegex(RuntimeError, "to be equal to trunc, but got floor"):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))
@@ -1306,7 +1306,7 @@ class TestExport(TestCase):
dim0_x = torch.export.Dim("dim0_x")
exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})
reexported = torch.export.export(exported, (inp,))
-with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"):
+with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 5, but got 7"):
reexported(torch.ones(7, 5))

reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},))
@@ -1315,7 +1315,7 @@ class TestExport(TestCase):
# can't retrace with invalid inputs with respect to the original ExportedProgram
dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}})
-with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
+with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be >= 3, but got 2"):
torch.export.export(exported_v2, (torch.randn(2, 2),))

@testing.expectedFailureSerDer
@@ -1453,7 +1453,7 @@ class TestExport(TestCase):
self.assertEqual(len(ep.state_dict), 1)
self.assertEqual(len(ep.tensor_constants), 2)

-inp = (torch.randn(1),)
+inp = (torch.tensor(5),)
self.assertTrue(torch.allclose(ep(*inp), Foo()(*inp)))

transform = ep.run_decompositions()
@@ -1620,7 +1620,7 @@ def forward(self, l_x_):
self.assertEqual(ep(*test_inp), foo(*test_inp))

ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
-with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
+with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be equal to 4, but got 7"):
ep_v2(*test_inp)

def test_constant_output(self):
@@ -1693,8 +1693,7 @@ def forward(self, l_x_):

test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
with self.assertRaisesRegex(
-RuntimeError,
-"shape\[0\] is outside of specified dynamic range \[3, inf\]"
+RuntimeError, "shape\[0\] to be >= 3, but got 2"
):
ep(*test_inp)
@@ -1724,10 +1723,10 @@ def forward(self, l_x_):
inp = torch.randn(4, 4)
gm = capture_pre_autograd_graph(Foo(), (inp,), constraints=[dynamic_dim(inp, 0) >= 3])

-with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
gm(torch.randn(2, 2))

-with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+with self.assertRaisesRegex(RuntimeError, "Expected input arg0_1.shape\[0\] to be >= 3, but got 2"):
torch.export.export(gm, (torch.randn(2, 2),))

ep = torch.export.export(gm, (torch.randn(5, 4),), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},))
@@ -76,7 +76,7 @@ class TestPasses(TestCase):
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

-with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
+with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(2, 7, 3))

self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@@ -99,10 +99,10 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
)

-with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
+with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

-with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
+with self.assertRaisesRegex(RuntimeError, r"Expected input l_y_.shape\[0\] to be >= 3, but got 2"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -123,12 +123,12 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
)

-with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
+with self.assertRaisesRegex(RuntimeError, r"Expected input l_x_.shape\[1\] to be <= 6, but got 7"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
-RuntimeError, r"shape\[0\] is specialized at 5"
+RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
@@ -152,12 +152,12 @@ class TestPasses(TestCase):
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})

-with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"):
+with self.assertRaisesRegex(RuntimeError, r"shape\[1\] to be equal to 2"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

# y is specialized to 5
with self.assertRaisesRegex(
-RuntimeError, r"shape\[0\] is specialized at 5"
+RuntimeError, r"Expected input l_y_.shape\[0\] to be equal to 5, but got 2"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
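The range checks exercised above pair a torch.export.Dim(..., min=..., max=...) specification with out-of-range inputs. A minimal sketch of that flow, mirroring the first TestPasses case (the module body is illustrative, and the call follows the era's convention of invoking the ExportedProgram directly; the error text follows the new wording asserted above):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # dim 1 may vary between 2 and 6; dims 0 and 2 stay specialized.
    dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
    ep = torch.export.export(M(), (torch.zeros(2, 4, 3),), dynamic_shapes={"x": {1: dim1_x}})

    try:
        ep(torch.zeros(2, 7, 3))  # dim 1 exceeds the declared max of 6
    except RuntimeError as e:
        # e.g. "Expected input l_x_.shape[1] to be <= 6, but got 7"
        print(e)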
@@ -283,11 +283,11 @@ class TestUnflatten(TestCase):
return a

export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
-with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
+with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
export_module(torch.randn(6, 6))

unflattened = export_module.module(flat=False)
-with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
+with self.assertRaisesRegex(RuntimeError, "Expected input l_x_.shape\[0\] to be equal to 2, but got 6"):
unflattened(torch.randn(6, 6))

def test_unflatten_with_inplace_compile(self):
@@ -1,22 +1,18 @@
-import copy
import math
import operator
import traceback
-from collections import OrderedDict
from functools import partial
-from typing import Any, Callable, Dict, List, NamedTuple, Set, Tuple
+from typing import Callable, Dict, List, NamedTuple, Set, Tuple

import sympy

import torch
import torch.fx
-from torch.fx.experimental.symbolic_shapes import SymInt
from torch._export.pass_base import _ExportPassBase, ProxyValue, PassResult
-from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._sympy.value_ranges import ValueRanges


-__all__ = ["_AddRuntimeAssertionsForConstraintsPass", "InputDim"]
+__all__ = ["InputDim"]


class InputDim(NamedTuple):
@@ -150,163 +146,3 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBase):
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))

return PassResult(val.graph_module, val.modified)
-
-
-class _AddRuntimeAssertionsForConstraintsPass(_AddRuntimeAssertionsForInlineConstraintsPass):
-def __init__(
-self,
-range_constraints: Dict[sympy.Symbol, ValueRanges],
-equality_constraints: List[Tuple[InputDim, InputDim]],
-):
-super().__init__(range_constraints, equality_constraints)
-
-def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-graph_module = copy.deepcopy(graph_module)
-graph = graph_module.graph
-
-insert_loc = None
-for node in graph.nodes:
-if node.op != "placeholder":
-continue
-insert_loc = node
-if insert_loc is None:
-return super().call(graph_module)
-
-# Add runtime asserts for input shape constraints. We do this after all
-# placeholder nodes so that we can handle both (unary) predicates and
-# (binary) relations.
-inputdim_to_node: Dict[InputDim, torch.fx.Node] = OrderedDict()
-for node in graph.nodes:
-if node.op != "placeholder":
-continue
-
-if (
-"val" not in node.meta or node.meta["val"] is None
-):
-continue
-
-if not isinstance(node.meta["val"], FakeTensor):
-# it has to be a prim value
-self._insert_prim_assert_inplace(graph, node, node.meta["val"])
-else:
-fake_tensor_shape = node.meta["val"].shape
-for dim, shape in enumerate(fake_tensor_shape):
-with graph.inserting_after(insert_loc):
-dim_node = graph.call_function(
-torch.ops.aten.sym_size.int, (node, dim)
-)
-input_dim = InputDim(node.name, dim)
-inputdim_to_node[input_dim] = dim_node
-insert_loc = dim_node
-
-if isinstance(shape, SymInt):
-# If the shape is dynamic, add range assertions
-symbol = shape.node._expr
-if symbol in self.range_constraints:
-self._insert_range_assert_inplace(
-graph, input_dim, dim_node, self.range_constraints[symbol]
-)
-else:
-# If no dynamism is specified, we assume all dimensions #
-# are specialized
-assert isinstance(shape, int)
-self._insert_specialized_shape_assert_inplace(
-graph, input_dim, dim_node, shape,
-)
-
-# Add runtime assertions on equality constraints on the inputs
-if len(inputdim_to_node) > 0:
-with graph.inserting_after(
-list(inputdim_to_node.values())[-1]
-):
-self._insert_equality_assert_inplace(graph, inputdim_to_node)
-
-return super().call(graph_module)
-
-def _insert_specialized_shape_assert_inplace(
-self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, shape: int,
-):
-assert_msg = f"Input {input_dim.input_name}.shape[{input_dim.dim}] is specialized at {shape}"
-with graph.inserting_after(dim_node):
-eq_node = graph.call_function(operator.eq, (dim_node, shape))
-with graph.inserting_after(eq_node):
-tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
-with graph.inserting_after(tensor_eq_node):
-_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))
-
-def _insert_prim_assert_inplace(self, graph, node: torch.fx.Node, value: Any):
-assert_msg = (
-f"Input {node.name} is specialized to be {value} at tracing time,"
-f"it is not supported to pass in a different value at run time."
-)
-with graph.inserting_after(node):
-eq_node = graph.call_function(operator.eq, (node, value))
-with graph.inserting_after(eq_node):
-tensor_eq_node = graph.call_function(torch.ops.aten.scalar_tensor.default, (eq_node,))
-with graph.inserting_after(tensor_eq_node):
-_ = graph.call_function(torch.ops.aten._assert_async.msg, (tensor_eq_node, assert_msg))
-
-def _insert_range_assert_inplace(
-self, graph: torch.fx.Graph, input_dim: InputDim, dim_node: torch.fx.Node, range: ValueRanges
-):
-"""
-Add runtime asserts for user-specified range constraints for
-each placeholder's dynamic dimension.
-"""
-
-min_val, max_val = _convert_range_to_int(range)
-assert_msg = (
-f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
-f"outside of specified dynamic range [{min_val}, {max_val}]"
-)
-# TODO (tmanlaibaatar) we are making an assumption that graph generated for
-# input dim N >=2 generalizes to N < 2. Ideally we should check that:
-# 1. if we can generalize to N < 2, not add any assertion saying N >= 2
-# 2. If we can't generalize to N < 2, add an assertion saying N >= 2
-# Above can be achieved via a separate pass.
-with graph.inserting_after(dim_node):
-if min_val > 2:
-self._insert_assert_async_inplace(
-graph, operator.ge, (dim_node, min_val), assert_msg,
-)
-
-if max_val < math.inf:
-self._insert_assert_async_inplace(
-graph, operator.le, (dim_node, max_val), assert_msg,
-)
-
-def _insert_equality_assert_inplace(
-self,
-graph: torch.fx.Graph,
-inputdim_to_node: Dict[InputDim, torch.fx.Node],
-):
-for input_dim, other_input_dim in self.equality_constraints:
-dim_node = inputdim_to_node[input_dim]
-assert_msg = (
-f"Input {input_dim.input_name}.shape[{input_dim.dim}] is "
-f"not equal to input {other_input_dim.input_name}.shape[{other_input_dim.dim}]"
-)
-
-other_dim_node = inputdim_to_node[other_input_dim]
-self._insert_assert_async_inplace(
-graph,
-operator.eq,
-(dim_node, other_dim_node),
-assert_msg
-)
-
-def _insert_assert_async_inplace(self, graph, operator, args, assert_msg):
-"""
-Inserts assert_async call_function nodes in the graph. This function is
-called before we run the interpreter-based pass and does an inplace
-insertion.
-"""
-cmp_node = graph.call_function(operator, args)
-with graph.inserting_after(cmp_node):
-cmp_tensor_node = graph.call_function(
-torch.ops.aten.scalar_tensor.default, (cmp_node,)
-)
-with graph.inserting_after(cmp_tensor_node):
-_ = graph.call_function(
-torch.ops.aten._assert_async.msg, (cmp_tensor_node, assert_msg)
-)
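For reference, the node sequence this removed pass used to insert (eq, then scalar_tensor, then _assert_async.msg) is equivalent to the following eager-mode pattern. This is a sketch of the runtime behavior of those graph nodes, not code from the diff; the helper name and message are illustrative:

    import torch

    def assert_dim_equals(actual: int, expected: int, assert_msg: str) -> None:
        # Same ops the pass wired into the graph: compare, wrap in a tensor, assert async.
        cond = actual == expected
        cond_tensor = torch.scalar_tensor(cond)
        torch.ops.aten._assert_async.msg(cond_tensor, assert_msg)

    assert_dim_equals(4, 4, "ok")  # passes silently
    # assert_dim_equals(3, 4, "Input x.shape[0] is specialized at 4")  # would raise at runtime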
@@ -1,11 +1,11 @@
import dataclasses

+import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch

from torch._export import ExportedProgram
-
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import (
_register_pytree_node,
Context,
@@ -25,37 +25,87 @@ SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
def _check_input_constraints_pre_hook(self, *args, **kwargs):
flat_args, _ = tree_flatten(args)
return _check_input_constraints_for_graph(
-self.graph,
-range_constraints=self.range_constraints,
-equality_constraints=self.equality_constraints,
-)(*flat_args)
+[node for node in self.graph.nodes if node.op == "placeholder"],
+flat_args,
+self.range_constraints,
+)


def _check_input_constraints_for_graph(
-graph: torch.fx.Graph, range_constraints, equality_constraints
+input_placeholders: List[torch.fx.Node], args, range_constraints
):
+def check(cond, msg):
+if not cond:
+# TODO(avik): maybe add more context, e.g., graph signature
+raise RuntimeError(msg)
+
+import sympy
+
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
-_AddRuntimeAssertionsForConstraintsPass,
_convert_range_to_int,
)

-def inner(*args):
-# TODO(zhxchen17) Don't generate a runtime graph on the fly.
-_assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
-for p in graph.nodes:
-if p.op != "placeholder":
-continue
-new_p = _assertion_graph.graph.placeholder(p.name)
-new_p.meta = p.meta
-_assertion_graph.graph.output(())
-_assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
-range_constraints,
-equality_constraints,
-)(_assertion_graph)
-assert _assertion_graph_res is not None
-_assertion_graph = _assertion_graph_res.graph_module
-_assertion_graph(*args)
+check(
+len(args) == len(input_placeholders),
+"Unexpected number of inputs "
+f"(expected {len(input_placeholders)}, got {len(args)})",
+)
+# NOTE: export already guarantees that the same symbol is used in metadata
+# for all InputDims related by equality constraints, so we can just unify
+# symbols with given input dimension values to check equality constraints.
+unification_map: "Dict[sympy.Symbol, Any]" = {}
+for arg, node in zip(args, input_placeholders):
+node_val = node.meta["val"]
+if isinstance(node_val, FakeTensor):
+check(
+isinstance(arg, torch.Tensor),
+f"Expected input {node.name} to be a tensor, but got {type(arg)}",
+)
+check(
+len(node_val.shape) == len(arg.shape),
+f"Unexpected number of dimensions in input {node.name}.shape "
+f"(expected {node_val.shape}, got {arg.shape})",
+)
+for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
+if isinstance(node_dim, torch.SymInt):
+if node_dim.node.expr in unification_map:
+existing_dim = unification_map[node_dim.node.expr]
+check(
+arg_dim == existing_dim,
+f"Expected input {node.name}.shape[{j}] to be equal to "
+f"{existing_dim}, but got {arg_dim}",
+)
+else:
+unification_map[node_dim.node.expr] = arg_dim

-return inner
+if node_dim.node.expr in range_constraints:
+min_val, max_val = _convert_range_to_int(
+range_constraints[node_dim.node.expr]
+)
+# NOTE: we allow dimensions to be 0/1 at runtime
+if min_val > 2:
+check(
+arg_dim >= min_val,
+f"Expected input {node.name}.shape[{j}] to be >= "
+f"{min_val}, but got {arg_dim}",
+)
+if max_val < math.inf:
+check(
+arg_dim <= max_val,
+f"Expected input {node.name}.shape[{j}] to be <= "
+f"{max_val}, but got {arg_dim}",
+)
+else:
+check(
+arg_dim == node_dim,
+f"Expected input {node.name}.shape[{j}] to be equal to "
+f"{node_dim}, but got {arg_dim}",
+)
+elif isinstance(node_val, (int, float, str)):
+check(
+type(arg) == type(node_val) and arg == node_val,
+f"Expected input {node.name} to be equal to {node_val}, but got {arg}",
+)


def register_dataclass_as_pytree_node(
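The NOTE in the new check explains why no explicit equality-constraint list is needed: dimensions that must match share a SymInt symbol, so the check unifies each symbol with the first runtime value it sees and compares later occurrences against it. A toy illustration of that unification idea with hypothetical symbol names (not the actual export metadata):

    from typing import Dict, List, Optional, Tuple

    def check_unified_dims(
        symbols_per_input: List[List[Optional[str]]],
        shapes: List[Tuple[int, ...]],
    ) -> None:
        # First occurrence of a symbol binds its runtime value; later ones must agree.
        unification_map: Dict[str, int] = {}
        for name_idx, (symbols, shape) in enumerate(zip(symbols_per_input, shapes)):
            for j, (sym, dim) in enumerate(zip(symbols, shape)):
                if sym is None:  # static dim: nothing to unify here
                    continue
                if sym in unification_map and dim != unification_map[sym]:
                    raise RuntimeError(
                        f"Expected input {name_idx}.shape[{j}] to be equal to "
                        f"{unification_map[sym]}, but got {dim}"
                    )
                unification_map.setdefault(sym, dim)

    # "s0" labels dim 0 of both inputs, so those dims must match at runtime.
    check_unified_dims([["s0", None], ["s0"]], [(4, 3), (4,)])    # ok
    # check_unified_dims([["s0", None], ["s0"]], [(4, 3), (5,)])  # would raise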
@@ -272,9 +272,7 @@ class ExportedProgram:
)
else:
ordered_tensor_constants = ()
-self._check_input_constraints(
-*ordered_params, *ordered_buffers, *ordered_tensor_constants, *args
-)
+self._check_input_constraints(*args)

# NOTE: calling convention is first params, then buffers, then args as user supplied them.
# See: torch/_functorch/aot_autograd.py#L1034
@@ -567,9 +565,15 @@ class ExportedProgram:
def _check_input_constraints(self, *args):
from torch._export.utils import _check_input_constraints_for_graph

+placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
+input_placeholders = [
+p
+for p, s in zip(placeholders, self.graph_signature.input_specs)
+if s.kind == InputKind.USER_INPUT
+]
_check_input_constraints_for_graph(
-self.graph, self.range_constraints, self.equality_constraints
-)(*args)
+input_placeholders, args, self.range_constraints
+)

def _validate(self):
self.verifier().check(self)
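The zip-and-filter above works because graph placeholders and graph_signature.input_specs are index-aligned, so parameters and buffers can be dropped before the user-input check runs. A toy sketch of the same pattern with hypothetical stand-ins for the signature objects (not the real export types):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FakeSpec:
        kind: str  # stand-in for an export input spec's InputKind

    def user_input_placeholders(placeholders: List[str], specs: List[FakeSpec]) -> List[str]:
        # Keep only placeholders whose spec marks them as user inputs.
        return [p for p, s in zip(placeholders, specs) if s.kind == "USER_INPUT"]

    names = ["p_weight", "b_running_mean", "arg0_1"]
    specs = [FakeSpec("PARAMETER"), FakeSpec("BUFFER"), FakeSpec("USER_INPUT")]
    print(user_input_placeholders(names, specs))  # ['arg0_1']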