diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 11e87edec2..ab513a3050 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -138,6 +138,7 @@ class SymbolicShapeInference: 'Unsqueeze': self._infer_Unsqueeze, 'Where': self._infer_symbolic_compute_ops, 'ZipMap': self._infer_ZipMap, + 'Neg': self._infer_symbolic_compute_ops, # contrib ops: 'Attention': self._infer_Attention, 'BiasGelu': self._infer_BiasGelu, @@ -292,7 +293,7 @@ class SymbolicShapeInference: for d in self._get_shape(node, idx): if type(d) == str: sympy_shape.append(self.symbolic_dims_[d] if d in - self.symbolic_dims_ else sympy.Symbol(d, integer=True)) + self.symbolic_dims_ else sympy.Symbol(d, integer=True, nonnegative=True)) else: assert None != d sympy_shape.append(d) @@ -451,7 +452,7 @@ class SymbolicShapeInference: v = self.suggested_merge_[new_dim] new_dim = sympy.Integer(int(v)) if is_literal(v) else v else: - self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True) + self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True, nonnegative=True) return new_dim def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): @@ -587,7 +588,9 @@ class SymbolicShapeInference: 'Sub': lambda l: l[0] - l[1], 'Where': - lambda l: l[1] if l[0] else l[2] + lambda l: l[1] if l[0] else l[2], + 'Neg': + lambda l: -l[0] } assert node.op_type in funcs self._compute_on_sympy_data(node, funcs[node.op_type]) @@ -621,7 +624,7 @@ class SymbolicShapeInference: output_shape)) def _infer_Concat(self, node): - if any([i in self.sympy_data_ for i in node.input]): + if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]): values = self._get_int_values(node) if all([v is not None for v in values]): assert 0 == get_attribute(node, 'axis') @@ -1029,6 +1032,37 @@ class SymbolicShapeInference: helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) def _infer_Slice(self, node): + def less_equal(x, y): + try: + return bool(x <= y) + except TypeError: + pass + try: + return bool(y >= x) + except TypeError: + pass + try: + return bool(-x >= -y) + except TypeError: + pass + try: + return bool(-y <= -x) + except TypeError: + # the last attempt; this may raise TypeError + return bool(y - x >= 0) + + def handle_negative_index(index, bound): + """ normalizes a negative index to be in [0, bound) """ + try: + if not less_equal(0, index): + if is_literal(index) and index <= -self.int_max_: + # this case is handled separately + return index + return bound + index + except TypeError: + print("Cannot determine if {} < 0".format(index)) + return index + if get_opset(self.out_mp_) <= 9: axes = get_attribute(node, 'axes') starts = get_attribute(node, 'starts') @@ -1059,6 +1093,7 @@ class SymbolicShapeInference: new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) else: for i, s, e, t in zip(axes, starts, ends, steps): + e = handle_negative_index(e, new_sympy_shape[i]) if is_literal(e): if e >= self.int_max_: e = new_sympy_shape[i] @@ -1072,25 +1107,22 @@ class SymbolicShapeInference: if e > 0: e = sympy.Min(e, new_sympy_shape[i] ) if e > 1 else e #special case for slicing first to make computation easier - else: - e = new_sympy_shape[i] + e else: if is_literal(new_sympy_shape[i]): e = sympy.Min(e, new_sympy_shape[i]) else: try: - if (e - new_sympy_shape[i]) >= 0: + if not less_equal(e, new_sympy_shape[i]): e = new_sympy_shape[i] except Exception: print('Unable to determine if {} <= {}, treat as equal'.format(e, new_sympy_shape[i])) e = new_sympy_shape[i] - if is_literal(s) and int(s) < 0: - s = new_sympy_shape[i] + s + s = handle_negative_index(s, new_sympy_shape[i]) if is_literal(new_sympy_shape[i]) and is_literal(s): s = max(0, min(s, new_sympy_shape[i])) - new_sympy_shape[i] = (e - s + t + (-1 if t > 0 else 1)) // t + new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) self._update_computed_dims(new_sympy_shape) @@ -1302,8 +1334,8 @@ class SymbolicShapeInference: assert s_merge in self.symbolic_dims_ self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] else: - self.symbolic_dims_[s] = sympy.Symbol(s, integer=True) - + # Since inputs are not produced by other ops, we can assume positivity + self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) # create a temporary ModelProto for single node inference # note that we remove initializer to have faster inference # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index beee92474a..1806f7c9db 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -3,11 +3,16 @@ # -*- coding: UTF-8 -*- import onnx +from onnx import AttributeProto, TensorProto, GraphProto import os from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from pathlib import Path import unittest +def unique_element(lst): + assert len(lst) == 1 + return lst[0] + class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): cwd = os.getcwd() @@ -22,5 +27,89 @@ class TestSymbolicShapeInference(unittest.TestCase): int_max=100000, guess_output_rank=True) +class TestSymbolicShapeInferenceForSlice(unittest.TestCase): + def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): + _dimstrmap = {dim: f"dim{i}" for i, dim in enumerate(input_dims)} + def dimstrmap(dim): + return _dimstrmap.get(dim, dim) + def get_initializer(name): + valuemap = {"zero": 0, "one": 1, "two": 2, "ten": 10, "intmax": 2**32} + value = -valuemap[name[4:]] if name.startswith("neg_") else valuemap[name] + return onnx.helper.make_tensor(name, TensorProto.INT64, [1], [value]) + + initializers = [get_initializer(name) for name in ["zero", "one", "two", "ten", "intmax", "neg_intmax", "neg_one", "neg_ten"]] + inputs = [] + nodes = [] + for i, dim in enumerate(input_dims): + inputs.append(onnx.helper.make_tensor_value_info(f"t{i}", TensorProto.FLOAT, ["B", dim])) + nodes.extend( + [ + onnx.helper.make_node("Shape", [f"t{i}"], [f"shape{i}"]), + onnx.helper.make_node( + "Slice", + [f"shape{i}", "one", "two", "zero", "one"], + [f"dim{i}"] + ), + onnx.helper.make_node("Neg", [f"dim{i}"], [f"neg_dim{i}"]) + ] + ) + + def make_concat_dims(concat_name, dims): + dims = [ + f"neg_{dimstrmap(dim[1:])}" if dim.startswith("-") else dimstrmap(dim) for dim in dims + ] + return onnx.helper.make_node("Concat", dims, [concat_name], axis=0) + + nodes.extend( + [ + onnx.helper.make_node("Concat", [inp.name for inp in inputs], ["concat"], axis=1), + make_concat_dims("starts", ["zero", start]), + make_concat_dims("ends", ["intmax", end]), + make_concat_dims("axes", ["zero", "one"]), + make_concat_dims("steps", ["one", step]), + onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"]) + ] + ) + output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d1", "d2"]) + graph_def = onnx.helper.make_graph( + nodes, + "graph", + inputs, + [output], + initializer=initializers + ) + model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def)) + output = unique_element(model.graph.output) + shape = [d.dim_param if d.dim_param else d.dim_value for d in output.type.tensor_type.shape.dim] + self.assertEqual(shape, ["B", expected_output_dim]) + + def test_numeric_negative_indices_forward(self): + self.check_slice_of_concat(["M"], "-ten", "-one", "one", 9) + + def test_numeric_negative_indices_backward(self): + self.check_slice_of_concat(["M"], "-one", "-ten", "-one", 9) + + def test_symbolic_end_index(self): + self.check_slice_of_concat(["M", "N"], "zero", "M", "one", "M") + + def test_symbolic_negative_start_index(self): + self.check_slice_of_concat(["M", "N"], "-N", "intmax", "one", "N") + + def test_non_unit_step(self): + self.check_slice_of_concat(["N", "N"], "zero", "intmax", "two", "N") + + def test_symbolic_step(self): + self.check_slice_of_concat(["N", "N"], "zero", "intmax", "N", "floor(-1/N) + 3") + + def test_symbolic_negative_step(self): + self.check_slice_of_concat(["N", "N"], "-one", "-intmax", "-N", "floor(-1/N) + 3") + + def test_flip(self): + self.check_slice_of_concat(["N"], "-one", "-intmax", "-one", "N") + + def test_flip_of_concat(self): + self.check_slice_of_concat(["N", "N", "N"], "-one", "-intmax", "-one", "3*N") + + if __name__ == '__main__': unittest.main()