mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
* Use positivity everywhere; handle negative index in Slice * limit positivity to inputs * make handle_negative_index private * strengthen sympy comparison * further strengthen comparison and a minor refactoring * Add flip test * Fall through if -int_max in handle_negative_index() * minor fix for infer_Concat to include initializers * Add more tests * use simplify * more tests
115 lines
4.7 KiB
Python
115 lines
4.7 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
import onnx
|
|
from onnx import AttributeProto, TensorProto, GraphProto
|
|
import os
|
|
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
|
from pathlib import Path
|
|
import unittest
|
|
|
|
def unique_element(lst):
    """Return the only element of *lst*.

    Raises AssertionError when *lst* does not hold exactly one element, so a
    unittest caller sees a test failure rather than an error. The check is an
    explicit ``raise`` instead of ``assert`` so that it still runs under
    ``python -O``.
    """
    if len(lst) != 1:
        raise AssertionError(f"expected exactly one element, got {len(lst)}")
    return lst[0]
|
|
|
|
class TestSymbolicShapeInference(unittest.TestCase):
    """Smoke test: run symbolic shape inference over every model under ../models."""

    def test_symbolic_shape_infer(self):
        """Run SymbolicShapeInference.infer_shapes on each *.onnx model found.

        Models are searched recursively under ``<cwd>/../models`` and visited in
        sorted order. Files whose name starts with "." are ignored. When no model
        is found (e.g. the test is run from a different working directory), the
        test is skipped instead of passing without checking anything.
        """
        test_model_dir = Path.cwd() / '..' / 'models'
        model_files = sorted(
            path for path in test_model_dir.rglob('*.onnx')
            if not path.name.startswith('.')  # skip some bad model files
        )
        if not model_files:
            self.skipTest(f"no *.onnx models found under {test_model_dir}")
        for filename in model_files:
            # subTest reports every failing model instead of stopping at the first.
            with self.subTest(model=str(filename)):
                print(f"Running symbolic shape inference on : {filename}")
                SymbolicShapeInference.infer_shapes(
                    in_mp=onnx.load(str(filename)),
                    auto_merge=True,
                    int_max=100000,
                    guess_output_rank=True)
|
|
|
|
class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
    """Shape inference for Slice applied to the output of Concat.

    Every test goes through ``check_slice_of_concat``. It concatenates inputs
    t0, t1, ... (each declared as [B, <dim>]) along axis 1, then slices the
    result with starts/ends/axes/steps tensors that are themselves produced by
    Concat nodes inside the graph.
    """

    def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
        """Build a Concat -> Slice graph, infer its shapes and check the output.

        Args:
            input_dims: symbolic axis-1 sizes of the inputs to concatenate.
            start, end, step: axis-1 slice parameters. Each is either an
                initializer name ("zero", "one", "two", "ten", "intmax") or one
                of ``input_dims``. A leading "-" selects the negated value
                (e.g. "-ten", "-N").
            expected_output_dim: expected axis-1 size of the output, as a
                dim_param string or a dim_value int.

        Axis 0 is sliced over [0, 2**32) with step 1 and is expected to stay "B".
        """
        # Input dim name -> graph tensor carrying that dim's value. For repeated
        # names (e.g. ["N", "N"]) the entry for the last index is kept.
        dim_tensor = {}
        for index, dim_name in enumerate(input_dims):
            dim_tensor[dim_name] = f"dim{index}"

        def resolve(token):
            # Tokens that are not input dims name an initializer directly.
            return dim_tensor.get(token, token)

        constant_values = {"zero": 0, "one": 1, "two": 2, "ten": 10, "intmax": 2**32}

        def make_constant(name):
            # "neg_<x>" holds the negation of constant <x>.
            if name.startswith("neg_"):
                value = -constant_values[name[4:]]
            else:
                value = constant_values[name]
            return onnx.helper.make_tensor(name, TensorProto.INT64, [1], [value])

        initializers = list(map(make_constant, (
            "zero", "one", "two", "ten", "intmax", "neg_intmax", "neg_one", "neg_ten",
        )))

        inputs = [
            onnx.helper.make_tensor_value_info(f"t{index}", TensorProto.FLOAT, ["B", dim_name])
            for index, dim_name in enumerate(input_dims)
        ]

        nodes = []
        for index, graph_input in enumerate(inputs):
            shape_name = f"shape{index}"
            width_name = f"dim{index}"
            nodes += [
                onnx.helper.make_node("Shape", [graph_input.name], [shape_name]),
                # shape[1:2] along axis 0: a 1-element tensor holding the input's axis-1 size.
                onnx.helper.make_node("Slice", [shape_name, "one", "two", "zero", "one"], [width_name]),
                onnx.helper.make_node("Neg", [width_name], [f"neg_{width_name}"]),
            ]

        def make_index_concat(output_name, tokens):
            # Join 1-element tensors into a 1-D Slice parameter; "-x" maps to "neg_<x>".
            sources = []
            for token in tokens:
                if token.startswith("-"):
                    sources.append(f"neg_{resolve(token[1:])}")
                else:
                    sources.append(resolve(token))
            return onnx.helper.make_node("Concat", sources, [output_name], axis=0)

        # Axis 0 keeps everything; axis 1 uses the parameters under test.
        slice_parameters = [
            ("starts", ["zero", start]),
            ("ends", ["intmax", end]),
            ("axes", ["zero", "one"]),
            ("steps", ["one", step]),
        ]
        nodes.append(onnx.helper.make_node("Concat", [vi.name for vi in inputs], ["concat"], axis=1))
        nodes.extend(make_index_concat(param_name, tokens) for param_name, tokens in slice_parameters)
        nodes.append(onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"]))

        declared_output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d1", "d2"])
        graph = onnx.helper.make_graph(nodes, "graph", inputs, [declared_output], initializer=initializers)
        inferred_model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph))

        output_dims = unique_element(inferred_model.graph.output).type.tensor_type.shape.dim
        # Read dim_param when it is set, otherwise fall back to dim_value.
        actual = [d.dim_param or d.dim_value for d in output_dims]
        self.assertEqual(actual, ["B", expected_output_dim])

    def test_numeric_negative_indices_forward(self):
        # [-10:-1:1] over an M-wide input; expected width 9.
        self.check_slice_of_concat(["M"], start="-ten", end="-one", step="one", expected_output_dim=9)

    def test_numeric_negative_indices_backward(self):
        # [-1:-10:-1] walks backwards over an M-wide input; expected width 9.
        self.check_slice_of_concat(["M"], start="-one", end="-ten", step="-one", expected_output_dim=9)

    def test_symbolic_end_index(self):
        # [0:M:1] over concat(M, N); expected width M.
        self.check_slice_of_concat(["M", "N"], start="zero", end="M", step="one", expected_output_dim="M")

    def test_symbolic_negative_start_index(self):
        # [-N:intmax:1] over concat(M, N); expected width N.
        self.check_slice_of_concat(["M", "N"], start="-N", end="intmax", step="one", expected_output_dim="N")

    def test_non_unit_step(self):
        # Step 2 over concat(N, N) (width 2*N); expected width N.
        self.check_slice_of_concat(["N", "N"], start="zero", end="intmax", step="two", expected_output_dim="N")

    def test_symbolic_step(self):
        # Step N over concat(N, N); expected as an unsimplified expression (equals 2 for N >= 1).
        self.check_slice_of_concat(
            ["N", "N"], start="zero", end="intmax", step="N", expected_output_dim="floor(-1/N) + 3")

    def test_symbolic_negative_step(self):
        # Step -N from the last element of concat(N, N); same expected expression.
        self.check_slice_of_concat(
            ["N", "N"], start="-one", end="-intmax", step="-N", expected_output_dim="floor(-1/N) + 3")

    def test_flip(self):
        # Reverse a single N-wide input; expected width N.
        self.check_slice_of_concat(["N"], start="-one", end="-intmax", step="-one", expected_output_dim="N")

    def test_flip_of_concat(self):
        # Reverse concat(N, N, N); expected width 3*N.
        self.check_slice_of_concat(
            ["N", "N", "N"], start="-one", end="-intmax", step="-one", expected_output_dim="3*N")
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest command-line runner.
    unittest.main()
|