mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Support negative indices and fix bound checking in symbolic shape inference for Slice (#7401)
* Use positivity everywhere; handle negative index in Slice * limit positivity to inputs * make handle_negative_index private * strengthen sympy comparison * further strengthen compariso n and a minor refactoring * Add flip test * Fall through if -int_max in handle_negative_index() * minor fix for infer_Concat to include initializers * Add more tests * use simplify * more tests
This commit is contained in:
parent
8e3cdf0452
commit
d1cb8c9dc9
2 changed files with 133 additions and 12 deletions
|
|
@ -138,6 +138,7 @@ class SymbolicShapeInference:
|
|||
'Unsqueeze': self._infer_Unsqueeze,
|
||||
'Where': self._infer_symbolic_compute_ops,
|
||||
'ZipMap': self._infer_ZipMap,
|
||||
'Neg': self._infer_symbolic_compute_ops,
|
||||
# contrib ops:
|
||||
'Attention': self._infer_Attention,
|
||||
'BiasGelu': self._infer_BiasGelu,
|
||||
|
|
@ -292,7 +293,7 @@ class SymbolicShapeInference:
|
|||
for d in self._get_shape(node, idx):
|
||||
if type(d) == str:
|
||||
sympy_shape.append(self.symbolic_dims_[d] if d in
|
||||
self.symbolic_dims_ else sympy.Symbol(d, integer=True))
|
||||
self.symbolic_dims_ else sympy.Symbol(d, integer=True, nonnegative=True))
|
||||
else:
|
||||
assert None != d
|
||||
sympy_shape.append(d)
|
||||
|
|
@ -451,7 +452,7 @@ class SymbolicShapeInference:
|
|||
v = self.suggested_merge_[new_dim]
|
||||
new_dim = sympy.Integer(int(v)) if is_literal(v) else v
|
||||
else:
|
||||
self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True)
|
||||
self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True, nonnegative=True)
|
||||
return new_dim
|
||||
|
||||
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
|
||||
|
|
@ -587,7 +588,9 @@ class SymbolicShapeInference:
|
|||
'Sub':
|
||||
lambda l: l[0] - l[1],
|
||||
'Where':
|
||||
lambda l: l[1] if l[0] else l[2]
|
||||
lambda l: l[1] if l[0] else l[2],
|
||||
'Neg':
|
||||
lambda l: -l[0]
|
||||
}
|
||||
assert node.op_type in funcs
|
||||
self._compute_on_sympy_data(node, funcs[node.op_type])
|
||||
|
|
@ -621,7 +624,7 @@ class SymbolicShapeInference:
|
|||
output_shape))
|
||||
|
||||
def _infer_Concat(self, node):
|
||||
if any([i in self.sympy_data_ for i in node.input]):
|
||||
if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
|
||||
values = self._get_int_values(node)
|
||||
if all([v is not None for v in values]):
|
||||
assert 0 == get_attribute(node, 'axis')
|
||||
|
|
@ -1029,6 +1032,37 @@ class SymbolicShapeInference:
|
|||
helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
|
||||
|
||||
def _infer_Slice(self, node):
|
||||
def less_equal(x, y):
|
||||
try:
|
||||
return bool(x <= y)
|
||||
except TypeError:
|
||||
pass
|
||||
try:
|
||||
return bool(y >= x)
|
||||
except TypeError:
|
||||
pass
|
||||
try:
|
||||
return bool(-x >= -y)
|
||||
except TypeError:
|
||||
pass
|
||||
try:
|
||||
return bool(-y <= -x)
|
||||
except TypeError:
|
||||
# the last attempt; this may raise TypeError
|
||||
return bool(y - x >= 0)
|
||||
|
||||
def handle_negative_index(index, bound):
|
||||
""" normalizes a negative index to be in [0, bound) """
|
||||
try:
|
||||
if not less_equal(0, index):
|
||||
if is_literal(index) and index <= -self.int_max_:
|
||||
# this case is handled separately
|
||||
return index
|
||||
return bound + index
|
||||
except TypeError:
|
||||
print("Cannot determine if {} < 0".format(index))
|
||||
return index
|
||||
|
||||
if get_opset(self.out_mp_) <= 9:
|
||||
axes = get_attribute(node, 'axes')
|
||||
starts = get_attribute(node, 'starts')
|
||||
|
|
@ -1059,6 +1093,7 @@ class SymbolicShapeInference:
|
|||
new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
|
||||
else:
|
||||
for i, s, e, t in zip(axes, starts, ends, steps):
|
||||
e = handle_negative_index(e, new_sympy_shape[i])
|
||||
if is_literal(e):
|
||||
if e >= self.int_max_:
|
||||
e = new_sympy_shape[i]
|
||||
|
|
@ -1072,25 +1107,22 @@ class SymbolicShapeInference:
|
|||
if e > 0:
|
||||
e = sympy.Min(e, new_sympy_shape[i]
|
||||
) if e > 1 else e #special case for slicing first to make computation easier
|
||||
else:
|
||||
e = new_sympy_shape[i] + e
|
||||
else:
|
||||
if is_literal(new_sympy_shape[i]):
|
||||
e = sympy.Min(e, new_sympy_shape[i])
|
||||
else:
|
||||
try:
|
||||
if (e - new_sympy_shape[i]) >= 0:
|
||||
if not less_equal(e, new_sympy_shape[i]):
|
||||
e = new_sympy_shape[i]
|
||||
except Exception:
|
||||
print('Unable to determine if {} <= {}, treat as equal'.format(e, new_sympy_shape[i]))
|
||||
e = new_sympy_shape[i]
|
||||
|
||||
if is_literal(s) and int(s) < 0:
|
||||
s = new_sympy_shape[i] + s
|
||||
s = handle_negative_index(s, new_sympy_shape[i])
|
||||
if is_literal(new_sympy_shape[i]) and is_literal(s):
|
||||
s = max(0, min(s, new_sympy_shape[i]))
|
||||
|
||||
new_sympy_shape[i] = (e - s + t + (-1 if t > 0 else 1)) // t
|
||||
new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
|
||||
|
||||
self._update_computed_dims(new_sympy_shape)
|
||||
|
||||
|
|
@ -1302,8 +1334,8 @@ class SymbolicShapeInference:
|
|||
assert s_merge in self.symbolic_dims_
|
||||
self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
|
||||
else:
|
||||
self.symbolic_dims_[s] = sympy.Symbol(s, integer=True)
|
||||
|
||||
# Since inputs are not produced by other ops, we can assume positivity
|
||||
self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
|
||||
# create a temporary ModelProto for single node inference
|
||||
# note that we remove initializer to have faster inference
|
||||
# for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
|
||||
|
|
|
|||
|
|
@ -3,11 +3,16 @@
|
|||
|
||||
# -*- coding: UTF-8 -*-
|
||||
import onnx
|
||||
from onnx import AttributeProto, TensorProto, GraphProto
|
||||
import os
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
def unique_element(lst):
|
||||
assert len(lst) == 1
|
||||
return lst[0]
|
||||
|
||||
class TestSymbolicShapeInference(unittest.TestCase):
|
||||
def test_symbolic_shape_infer(self):
|
||||
cwd = os.getcwd()
|
||||
|
|
@ -22,5 +27,89 @@ class TestSymbolicShapeInference(unittest.TestCase):
|
|||
int_max=100000,
|
||||
guess_output_rank=True)
|
||||
|
||||
class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
|
||||
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
|
||||
_dimstrmap = {dim: f"dim{i}" for i, dim in enumerate(input_dims)}
|
||||
def dimstrmap(dim):
|
||||
return _dimstrmap.get(dim, dim)
|
||||
def get_initializer(name):
|
||||
valuemap = {"zero": 0, "one": 1, "two": 2, "ten": 10, "intmax": 2**32}
|
||||
value = -valuemap[name[4:]] if name.startswith("neg_") else valuemap[name]
|
||||
return onnx.helper.make_tensor(name, TensorProto.INT64, [1], [value])
|
||||
|
||||
initializers = [get_initializer(name) for name in ["zero", "one", "two", "ten", "intmax", "neg_intmax", "neg_one", "neg_ten"]]
|
||||
inputs = []
|
||||
nodes = []
|
||||
for i, dim in enumerate(input_dims):
|
||||
inputs.append(onnx.helper.make_tensor_value_info(f"t{i}", TensorProto.FLOAT, ["B", dim]))
|
||||
nodes.extend(
|
||||
[
|
||||
onnx.helper.make_node("Shape", [f"t{i}"], [f"shape{i}"]),
|
||||
onnx.helper.make_node(
|
||||
"Slice",
|
||||
[f"shape{i}", "one", "two", "zero", "one"],
|
||||
[f"dim{i}"]
|
||||
),
|
||||
onnx.helper.make_node("Neg", [f"dim{i}"], [f"neg_dim{i}"])
|
||||
]
|
||||
)
|
||||
|
||||
def make_concat_dims(concat_name, dims):
|
||||
dims = [
|
||||
f"neg_{dimstrmap(dim[1:])}" if dim.startswith("-") else dimstrmap(dim) for dim in dims
|
||||
]
|
||||
return onnx.helper.make_node("Concat", dims, [concat_name], axis=0)
|
||||
|
||||
nodes.extend(
|
||||
[
|
||||
onnx.helper.make_node("Concat", [inp.name for inp in inputs], ["concat"], axis=1),
|
||||
make_concat_dims("starts", ["zero", start]),
|
||||
make_concat_dims("ends", ["intmax", end]),
|
||||
make_concat_dims("axes", ["zero", "one"]),
|
||||
make_concat_dims("steps", ["one", step]),
|
||||
onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"])
|
||||
]
|
||||
)
|
||||
output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d1", "d2"])
|
||||
graph_def = onnx.helper.make_graph(
|
||||
nodes,
|
||||
"graph",
|
||||
inputs,
|
||||
[output],
|
||||
initializer=initializers
|
||||
)
|
||||
model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def))
|
||||
output = unique_element(model.graph.output)
|
||||
shape = [d.dim_param if d.dim_param else d.dim_value for d in output.type.tensor_type.shape.dim]
|
||||
self.assertEqual(shape, ["B", expected_output_dim])
|
||||
|
||||
def test_numeric_negative_indices_forward(self):
|
||||
self.check_slice_of_concat(["M"], "-ten", "-one", "one", 9)
|
||||
|
||||
def test_numeric_negative_indices_backward(self):
|
||||
self.check_slice_of_concat(["M"], "-one", "-ten", "-one", 9)
|
||||
|
||||
def test_symbolic_end_index(self):
|
||||
self.check_slice_of_concat(["M", "N"], "zero", "M", "one", "M")
|
||||
|
||||
def test_symbolic_negative_start_index(self):
|
||||
self.check_slice_of_concat(["M", "N"], "-N", "intmax", "one", "N")
|
||||
|
||||
def test_non_unit_step(self):
|
||||
self.check_slice_of_concat(["N", "N"], "zero", "intmax", "two", "N")
|
||||
|
||||
def test_symbolic_step(self):
|
||||
self.check_slice_of_concat(["N", "N"], "zero", "intmax", "N", "floor(-1/N) + 3")
|
||||
|
||||
def test_symbolic_negative_step(self):
|
||||
self.check_slice_of_concat(["N", "N"], "-one", "-intmax", "-N", "floor(-1/N) + 3")
|
||||
|
||||
def test_flip(self):
|
||||
self.check_slice_of_concat(["N"], "-one", "-intmax", "-one", "N")
|
||||
|
||||
def test_flip_of_concat(self):
|
||||
self.check_slice_of_concat(["N", "N", "N"], "-one", "-intmax", "-one", "3*N")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue