From 348ed698ec70c8e02c375afdf4f6c1cffb3ad887 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Fri, 29 May 2020 02:00:30 -0700 Subject: [PATCH] Add more symbolic compute support in symbolic shape inference (#4057) * Add more symbolic compute support in symbolic shape inference * Refinements --- .../nuphar/scripts/symbolic_shape_infer.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py index 9fac020e4a..2a339b9c28 100755 --- a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py +++ b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py @@ -73,7 +73,7 @@ def sympy_reduce_product(x): class SymbolicShapeInference: def __init__(self, int_max, auto_merge, guess_output_rank, verbose): self.dispatcher_ = { - 'Add' : self._infer_binary_ops, + 'Add' : self._infer_symbolic_compute_ops, 'ArrayFeatureExtractor' : self._infer_ArrayFeatureExtractor, 'AveragePool' : self._infer_Pool, 'Cast' : self._infer_Cast, @@ -83,8 +83,9 @@ class SymbolicShapeInference: 'ConstantOfShape' : self._infer_ConstantOfShape, 'Conv' : self._infer_Conv, 'CumSum' : self._pass_on_shape_and_type, - 'Div' : self._infer_binary_ops, + 'Div' : self._infer_symbolic_compute_ops, 'Expand' : self._infer_Expand, + 'Equal' : self._infer_symbolic_compute_ops, 'Gather' : self._infer_Gather, 'GatherElements' : self._infer_GatherElements, 'GatherND' : self._infer_GatherND, @@ -93,9 +94,9 @@ class SymbolicShapeInference: 'MatMul' : self._infer_MatMul, 'MatMulInteger16' : self._infer_MatMulInteger, 'MaxPool' : self._infer_Pool, - 'Max' : self._infer_binary_ops, - 'Min' : self._infer_binary_ops, - 'Mul' : self._infer_binary_ops, + 'Max' : self._infer_symbolic_compute_ops, + 'Min' : self._infer_symbolic_compute_ops, + 'Mul' : self._infer_symbolic_compute_ops, 'NonMaxSuppression' : self._infer_NonMaxSuppression, 'NonZero' : self._infer_NonZero, 'OneHot' : self._infer_OneHot, @@ -112,10 +113,11 @@ class SymbolicShapeInference: 'Slice' : self._infer_Slice, 'Split' : self._infer_Split, 'Squeeze' : self._infer_Squeeze, - 'Sub' : self._infer_binary_ops, + 'Sub' : self._infer_symbolic_compute_ops, 'Tile' : self._infer_Tile, 'TopK' : self._infer_TopK, 'Unsqueeze' : self._infer_Unsqueeze, + 'Where' : self._infer_symbolic_compute_ops, 'ZipMap' : self._infer_ZipMap} self.run_ = True self.suggested_merge_ = {} @@ -552,13 +554,15 @@ class SymbolicShapeInference: self.known_vi_[node.input[0]].type.tensor_type.elem_type, data_shape[:-1] + indices_shape)) - def _infer_binary_ops(self, node): + def _infer_symbolic_compute_ops(self, node): funcs = {'Add' : lambda l: l[0] + l[1], 'Div' : lambda l: l[0] // l[1], # integer div in sympy + 'Equal' : lambda l : l[0] == l[1], 'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), 'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), 'Mul' : lambda l: l[0] * l[1], - 'Sub' : lambda l: l[0] - l[1]} + 'Sub' : lambda l: l[0] - l[1], + 'Where' : lambda l: l[1] if l[0] else l[2]} assert node.op_type in funcs self._compute_on_sympy_data(node, funcs[node.op_type]) @@ -638,6 +642,9 @@ class SymbolicShapeInference: if type(sympy_shape) != list: sympy_shape = [sympy_shape] self._update_computed_dims(sympy_shape) + # update sympy data if output type is int, and shape is known + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): + self.sympy_data_[node.output[0]] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0)) else: # create new dynamic shape sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)