handle Floor and SplitToSequence (#4384)

* handle Floor and SplitToSequence

added support to Floor and SplitToSequence ops

* Address CR

use sympy.floor for computation on Floor
This commit is contained in:
Yang Chen 2020-07-01 16:09:43 -07:00 committed by GitHub
parent 473cd5545f
commit 010445fc52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -86,6 +86,7 @@ class SymbolicShapeInference:
'Div' : self._infer_symbolic_compute_ops,
'Expand' : self._infer_Expand,
'Equal' : self._infer_symbolic_compute_ops,
'Floor' : self._infer_symbolic_compute_ops,
'Gather' : self._infer_Gather,
'GatherElements' : self._infer_GatherElements,
'GatherND' : self._infer_GatherND,
@ -112,6 +113,7 @@ class SymbolicShapeInference:
'Size' : self._infer_Size,
'Slice' : self._infer_Slice,
'Split' : self._infer_Split,
'SplitToSequence' : self._infer_SplitToSequence,
'Squeeze' : self._infer_Squeeze,
'Sub' : self._infer_symbolic_compute_ops,
'Tile' : self._infer_Tile,
@ -331,10 +333,14 @@ class SymbolicShapeInference:
# run single node inference with self.known_vi_ shapes
# note that inference rely on initializer values is not handled
# as we don't copy initializer weights to tmp_graph for inference speed purpose
if node.op_type == 'SplitToSequence':
make_value_info_func = helper.make_sequence_value_info
else:
make_value_info_func = helper.make_tensor_value_info
tmp_graph = helper.make_graph([node],
'tmp',
[self.known_vi_[i] for i in node.input if i],
[helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED, None) for i in node.output])
[make_value_info_func(i, onnx.TensorProto.UNDEFINED, None) for i in node.output])
self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
for i_o in range(len(node.output)):
@ -558,6 +564,7 @@ class SymbolicShapeInference:
funcs = {'Add' : lambda l: l[0] + l[1],
'Div' : lambda l: l[0] // l[1], # integer div in sympy
'Equal' : lambda l : l[0] == l[1],
'Floor' : lambda l : sympy.floor(l[0]),
'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
'Mul' : lambda l: l[0] * l[1],
@ -1027,7 +1034,7 @@ class SymbolicShapeInference:
assert len(ends) == 1
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][starts[0]:ends[0]]
def _infer_Split(self, node):
def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(get_attribute(node, 'axis', 0), len(input_sympy_shape))
split = get_attribute(node, 'split')
@ -1040,11 +1047,17 @@ class SymbolicShapeInference:
for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:])))
vi.CopyFrom(make_value_info_func(node.output[i_o],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:])))
self.known_vi_[vi.name] = vi
def _infer_Split(self, node):
self._infer_Split_Common(node, helper.make_tensor_value_info)
def _infer_SplitToSequence(self, node):
self._infer_Split_Common(node, helper.make_sequence_value_info)
def _infer_Squeeze(self, node):
self._pass_on_sympy_data(node)