Support negative indices and fix bound checking in symbolic shape inference for Slice (#7401)

* Use positivity everywhere; handle negative index in Slice

* limit positivity to inputs

* make handle_negative_index private

* strengthen sympy comparison

* further strengthen compariso
n and a minor refactoring

* Add flip test

* Fall through if -int_max in handle_negative_index()

* minor fix for infer_Concat to include initializers

* Add more tests

* use simplify

* more tests
This commit is contained in:
Ryota Tomioka 2021-05-03 17:07:55 +01:00 committed by GitHub
parent 8e3cdf0452
commit d1cb8c9dc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 12 deletions

View file

@ -138,6 +138,7 @@ class SymbolicShapeInference:
'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops,
'ZipMap': self._infer_ZipMap,
'Neg': self._infer_symbolic_compute_ops,
# contrib ops:
'Attention': self._infer_Attention,
'BiasGelu': self._infer_BiasGelu,
@ -292,7 +293,7 @@ class SymbolicShapeInference:
for d in self._get_shape(node, idx):
if type(d) == str:
sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol(d, integer=True))
self.symbolic_dims_ else sympy.Symbol(d, integer=True, nonnegative=True))
else:
assert None != d
sympy_shape.append(d)
@ -451,7 +452,7 @@ class SymbolicShapeInference:
v = self.suggested_merge_[new_dim]
new_dim = sympy.Integer(int(v)) if is_literal(v) else v
else:
self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True)
self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True, nonnegative=True)
return new_dim
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
@ -587,7 +588,9 @@ class SymbolicShapeInference:
'Sub':
lambda l: l[0] - l[1],
'Where':
lambda l: l[1] if l[0] else l[2]
lambda l: l[1] if l[0] else l[2],
'Neg':
lambda l: -l[0]
}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
@ -621,7 +624,7 @@ class SymbolicShapeInference:
output_shape))
def _infer_Concat(self, node):
if any([i in self.sympy_data_ for i in node.input]):
if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
values = self._get_int_values(node)
if all([v is not None for v in values]):
assert 0 == get_attribute(node, 'axis')
@ -1029,6 +1032,37 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
def _infer_Slice(self, node):
def less_equal(x, y):
try:
return bool(x <= y)
except TypeError:
pass
try:
return bool(y >= x)
except TypeError:
pass
try:
return bool(-x >= -y)
except TypeError:
pass
try:
return bool(-y <= -x)
except TypeError:
# the last attempt; this may raise TypeError
return bool(y - x >= 0)
def handle_negative_index(index, bound):
""" normalizes a negative index to be in [0, bound) """
try:
if not less_equal(0, index):
if is_literal(index) and index <= -self.int_max_:
# this case is handled separately
return index
return bound + index
except TypeError:
print("Cannot determine if {} < 0".format(index))
return index
if get_opset(self.out_mp_) <= 9:
axes = get_attribute(node, 'axes')
starts = get_attribute(node, 'starts')
@ -1059,6 +1093,7 @@ class SymbolicShapeInference:
new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
else:
for i, s, e, t in zip(axes, starts, ends, steps):
e = handle_negative_index(e, new_sympy_shape[i])
if is_literal(e):
if e >= self.int_max_:
e = new_sympy_shape[i]
@ -1072,25 +1107,22 @@ class SymbolicShapeInference:
if e > 0:
e = sympy.Min(e, new_sympy_shape[i]
) if e > 1 else e #special case for slicing first to make computation easier
else:
e = new_sympy_shape[i] + e
else:
if is_literal(new_sympy_shape[i]):
e = sympy.Min(e, new_sympy_shape[i])
else:
try:
if (e - new_sympy_shape[i]) >= 0:
if not less_equal(e, new_sympy_shape[i]):
e = new_sympy_shape[i]
except Exception:
print('Unable to determine if {} <= {}, treat as equal'.format(e, new_sympy_shape[i]))
e = new_sympy_shape[i]
if is_literal(s) and int(s) < 0:
s = new_sympy_shape[i] + s
s = handle_negative_index(s, new_sympy_shape[i])
if is_literal(new_sympy_shape[i]) and is_literal(s):
s = max(0, min(s, new_sympy_shape[i]))
new_sympy_shape[i] = (e - s + t + (-1 if t > 0 else 1)) // t
new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
self._update_computed_dims(new_sympy_shape)
@ -1302,8 +1334,8 @@ class SymbolicShapeInference:
assert s_merge in self.symbolic_dims_
self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
else:
self.symbolic_dims_[s] = sympy.Symbol(s, integer=True)
# Since inputs are not produced by other ops, we can assume positivity
self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
# create a temporary ModelProto for single node inference
# note that we remove initializer to have faster inference
# for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways

View file

@ -3,11 +3,16 @@
# -*- coding: UTF-8 -*-
import onnx
from onnx import AttributeProto, TensorProto, GraphProto
import os
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from pathlib import Path
import unittest
def unique_element(lst):
assert len(lst) == 1
return lst[0]
class TestSymbolicShapeInference(unittest.TestCase):
def test_symbolic_shape_infer(self):
cwd = os.getcwd()
@ -22,5 +27,89 @@ class TestSymbolicShapeInference(unittest.TestCase):
int_max=100000,
guess_output_rank=True)
class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
_dimstrmap = {dim: f"dim{i}" for i, dim in enumerate(input_dims)}
def dimstrmap(dim):
return _dimstrmap.get(dim, dim)
def get_initializer(name):
valuemap = {"zero": 0, "one": 1, "two": 2, "ten": 10, "intmax": 2**32}
value = -valuemap[name[4:]] if name.startswith("neg_") else valuemap[name]
return onnx.helper.make_tensor(name, TensorProto.INT64, [1], [value])
initializers = [get_initializer(name) for name in ["zero", "one", "two", "ten", "intmax", "neg_intmax", "neg_one", "neg_ten"]]
inputs = []
nodes = []
for i, dim in enumerate(input_dims):
inputs.append(onnx.helper.make_tensor_value_info(f"t{i}", TensorProto.FLOAT, ["B", dim]))
nodes.extend(
[
onnx.helper.make_node("Shape", [f"t{i}"], [f"shape{i}"]),
onnx.helper.make_node(
"Slice",
[f"shape{i}", "one", "two", "zero", "one"],
[f"dim{i}"]
),
onnx.helper.make_node("Neg", [f"dim{i}"], [f"neg_dim{i}"])
]
)
def make_concat_dims(concat_name, dims):
dims = [
f"neg_{dimstrmap(dim[1:])}" if dim.startswith("-") else dimstrmap(dim) for dim in dims
]
return onnx.helper.make_node("Concat", dims, [concat_name], axis=0)
nodes.extend(
[
onnx.helper.make_node("Concat", [inp.name for inp in inputs], ["concat"], axis=1),
make_concat_dims("starts", ["zero", start]),
make_concat_dims("ends", ["intmax", end]),
make_concat_dims("axes", ["zero", "one"]),
make_concat_dims("steps", ["one", step]),
onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"])
]
)
output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d1", "d2"])
graph_def = onnx.helper.make_graph(
nodes,
"graph",
inputs,
[output],
initializer=initializers
)
model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def))
output = unique_element(model.graph.output)
shape = [d.dim_param if d.dim_param else d.dim_value for d in output.type.tensor_type.shape.dim]
self.assertEqual(shape, ["B", expected_output_dim])
def test_numeric_negative_indices_forward(self):
self.check_slice_of_concat(["M"], "-ten", "-one", "one", 9)
def test_numeric_negative_indices_backward(self):
self.check_slice_of_concat(["M"], "-one", "-ten", "-one", 9)
def test_symbolic_end_index(self):
self.check_slice_of_concat(["M", "N"], "zero", "M", "one", "M")
def test_symbolic_negative_start_index(self):
self.check_slice_of_concat(["M", "N"], "-N", "intmax", "one", "N")
def test_non_unit_step(self):
self.check_slice_of_concat(["N", "N"], "zero", "intmax", "two", "N")
def test_symbolic_step(self):
self.check_slice_of_concat(["N", "N"], "zero", "intmax", "N", "floor(-1/N) + 3")
def test_symbolic_negative_step(self):
self.check_slice_of_concat(["N", "N"], "-one", "-intmax", "-N", "floor(-1/N) + 3")
def test_flip(self):
self.check_slice_of_concat(["N"], "-one", "-intmax", "-one", "N")
def test_flip_of_concat(self):
self.check_slice_of_concat(["N", "N", "N"], "-one", "-intmax", "-one", "3*N")
if __name__ == '__main__':
unittest.main()