mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Add more symbolic compute support in symbolic shape inference (#4057)
* Add more symbolic compute support in symbolic shape inference * Refinements
This commit is contained in:
parent
2a96be83f6
commit
348ed698ec
1 changed files with 15 additions and 8 deletions
|
|
@ -73,7 +73,7 @@ def sympy_reduce_product(x):
|
|||
class SymbolicShapeInference:
|
||||
def __init__(self, int_max, auto_merge, guess_output_rank, verbose):
|
||||
self.dispatcher_ = {
|
||||
'Add' : self._infer_binary_ops,
|
||||
'Add' : self._infer_symbolic_compute_ops,
|
||||
'ArrayFeatureExtractor' : self._infer_ArrayFeatureExtractor,
|
||||
'AveragePool' : self._infer_Pool,
|
||||
'Cast' : self._infer_Cast,
|
||||
|
|
@ -83,8 +83,9 @@ class SymbolicShapeInference:
|
|||
'ConstantOfShape' : self._infer_ConstantOfShape,
|
||||
'Conv' : self._infer_Conv,
|
||||
'CumSum' : self._pass_on_shape_and_type,
|
||||
'Div' : self._infer_binary_ops,
|
||||
'Div' : self._infer_symbolic_compute_ops,
|
||||
'Expand' : self._infer_Expand,
|
||||
'Equal' : self._infer_symbolic_compute_ops,
|
||||
'Gather' : self._infer_Gather,
|
||||
'GatherElements' : self._infer_GatherElements,
|
||||
'GatherND' : self._infer_GatherND,
|
||||
|
|
@ -93,9 +94,9 @@ class SymbolicShapeInference:
|
|||
'MatMul' : self._infer_MatMul,
|
||||
'MatMulInteger16' : self._infer_MatMulInteger,
|
||||
'MaxPool' : self._infer_Pool,
|
||||
'Max' : self._infer_binary_ops,
|
||||
'Min' : self._infer_binary_ops,
|
||||
'Mul' : self._infer_binary_ops,
|
||||
'Max' : self._infer_symbolic_compute_ops,
|
||||
'Min' : self._infer_symbolic_compute_ops,
|
||||
'Mul' : self._infer_symbolic_compute_ops,
|
||||
'NonMaxSuppression' : self._infer_NonMaxSuppression,
|
||||
'NonZero' : self._infer_NonZero,
|
||||
'OneHot' : self._infer_OneHot,
|
||||
|
|
@ -112,10 +113,11 @@ class SymbolicShapeInference:
|
|||
'Slice' : self._infer_Slice,
|
||||
'Split' : self._infer_Split,
|
||||
'Squeeze' : self._infer_Squeeze,
|
||||
'Sub' : self._infer_binary_ops,
|
||||
'Sub' : self._infer_symbolic_compute_ops,
|
||||
'Tile' : self._infer_Tile,
|
||||
'TopK' : self._infer_TopK,
|
||||
'Unsqueeze' : self._infer_Unsqueeze,
|
||||
'Where' : self._infer_symbolic_compute_ops,
|
||||
'ZipMap' : self._infer_ZipMap}
|
||||
self.run_ = True
|
||||
self.suggested_merge_ = {}
|
||||
|
|
@ -552,13 +554,15 @@ class SymbolicShapeInference:
|
|||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
data_shape[:-1] + indices_shape))
|
||||
|
||||
def _infer_binary_ops(self, node):
|
||||
def _infer_symbolic_compute_ops(self, node):
|
||||
funcs = {'Add' : lambda l: l[0] + l[1],
|
||||
'Div' : lambda l: l[0] // l[1], # integer div in sympy
|
||||
'Equal' : lambda l : l[0] == l[1],
|
||||
'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
|
||||
'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
|
||||
'Mul' : lambda l: l[0] * l[1],
|
||||
'Sub' : lambda l: l[0] - l[1]}
|
||||
'Sub' : lambda l: l[0] - l[1],
|
||||
'Where' : lambda l: l[1] if l[0] else l[2]}
|
||||
assert node.op_type in funcs
|
||||
self._compute_on_sympy_data(node, funcs[node.op_type])
|
||||
|
||||
|
|
@ -638,6 +642,9 @@ class SymbolicShapeInference:
|
|||
if type(sympy_shape) != list:
|
||||
sympy_shape = [sympy_shape]
|
||||
self._update_computed_dims(sympy_shape)
|
||||
# update sympy data if output type is int, and shape is known
|
||||
if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]):
|
||||
self.sympy_data_[node.output[0]] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0))
|
||||
else:
|
||||
# create new dynamic shape
|
||||
sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)
|
||||
|
|
|
|||
Loading…
Reference in a new issue