Add more symbolic compute support in symbolic shape inference (#4057)

* Add more symbolic compute support in symbolic shape inference

* Refinements
This commit is contained in:
KeDengMS 2020-05-29 02:00:30 -07:00 committed by GitHub
parent 2a96be83f6
commit 348ed698ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -73,7 +73,7 @@ def sympy_reduce_product(x):
class SymbolicShapeInference:
def __init__(self, int_max, auto_merge, guess_output_rank, verbose):
self.dispatcher_ = {
'Add' : self._infer_binary_ops,
'Add' : self._infer_symbolic_compute_ops,
'ArrayFeatureExtractor' : self._infer_ArrayFeatureExtractor,
'AveragePool' : self._infer_Pool,
'Cast' : self._infer_Cast,
@ -83,8 +83,9 @@ class SymbolicShapeInference:
'ConstantOfShape' : self._infer_ConstantOfShape,
'Conv' : self._infer_Conv,
'CumSum' : self._pass_on_shape_and_type,
'Div' : self._infer_binary_ops,
'Div' : self._infer_symbolic_compute_ops,
'Expand' : self._infer_Expand,
'Equal' : self._infer_symbolic_compute_ops,
'Gather' : self._infer_Gather,
'GatherElements' : self._infer_GatherElements,
'GatherND' : self._infer_GatherND,
@ -93,9 +94,9 @@ class SymbolicShapeInference:
'MatMul' : self._infer_MatMul,
'MatMulInteger16' : self._infer_MatMulInteger,
'MaxPool' : self._infer_Pool,
'Max' : self._infer_binary_ops,
'Min' : self._infer_binary_ops,
'Mul' : self._infer_binary_ops,
'Max' : self._infer_symbolic_compute_ops,
'Min' : self._infer_symbolic_compute_ops,
'Mul' : self._infer_symbolic_compute_ops,
'NonMaxSuppression' : self._infer_NonMaxSuppression,
'NonZero' : self._infer_NonZero,
'OneHot' : self._infer_OneHot,
@ -112,10 +113,11 @@ class SymbolicShapeInference:
'Slice' : self._infer_Slice,
'Split' : self._infer_Split,
'Squeeze' : self._infer_Squeeze,
'Sub' : self._infer_binary_ops,
'Sub' : self._infer_symbolic_compute_ops,
'Tile' : self._infer_Tile,
'TopK' : self._infer_TopK,
'Unsqueeze' : self._infer_Unsqueeze,
'Where' : self._infer_symbolic_compute_ops,
'ZipMap' : self._infer_ZipMap}
self.run_ = True
self.suggested_merge_ = {}
@ -552,13 +554,15 @@ class SymbolicShapeInference:
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
data_shape[:-1] + indices_shape))
def _infer_binary_ops(self, node):
def _infer_symbolic_compute_ops(self, node):
funcs = {'Add' : lambda l: l[0] + l[1],
'Div' : lambda l: l[0] // l[1], # integer div in sympy
'Equal' : lambda l : l[0] == l[1],
'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
'Mul' : lambda l: l[0] * l[1],
'Sub' : lambda l: l[0] - l[1]}
'Sub' : lambda l: l[0] - l[1],
'Where' : lambda l: l[1] if l[0] else l[2]}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
@ -638,6 +642,9 @@ class SymbolicShapeInference:
if type(sympy_shape) != list:
sympy_shape = [sympy_shape]
self._update_computed_dims(sympy_shape)
# update sympy data if output type is int, and shape is known
if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]):
self.sympy_data_[node.output[0]] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0))
else:
# create new dynamic shape
sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)