mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
handle Floor and SplitToSequence (#4384)
* handle Floor and SplitToSequence added support to Floor and SplitToSequence ops * Address CR use sympy.floor for computation on Floor
This commit is contained in:
parent
473cd5545f
commit
010445fc52
1 changed files with 18 additions and 5 deletions
|
|
@ -86,6 +86,7 @@ class SymbolicShapeInference:
|
|||
'Div' : self._infer_symbolic_compute_ops,
|
||||
'Expand' : self._infer_Expand,
|
||||
'Equal' : self._infer_symbolic_compute_ops,
|
||||
'Floor' : self._infer_symbolic_compute_ops,
|
||||
'Gather' : self._infer_Gather,
|
||||
'GatherElements' : self._infer_GatherElements,
|
||||
'GatherND' : self._infer_GatherND,
|
||||
|
|
@ -112,6 +113,7 @@ class SymbolicShapeInference:
|
|||
'Size' : self._infer_Size,
|
||||
'Slice' : self._infer_Slice,
|
||||
'Split' : self._infer_Split,
|
||||
'SplitToSequence' : self._infer_SplitToSequence,
|
||||
'Squeeze' : self._infer_Squeeze,
|
||||
'Sub' : self._infer_symbolic_compute_ops,
|
||||
'Tile' : self._infer_Tile,
|
||||
|
|
@ -331,10 +333,14 @@ class SymbolicShapeInference:
|
|||
# run single node inference with self.known_vi_ shapes
|
||||
# note that inference rely on initializer values is not handled
|
||||
# as we don't copy initializer weights to tmp_graph for inference speed purpose
|
||||
if node.op_type == 'SplitToSequence':
|
||||
make_value_info_func = helper.make_sequence_value_info
|
||||
else:
|
||||
make_value_info_func = helper.make_tensor_value_info
|
||||
tmp_graph = helper.make_graph([node],
|
||||
'tmp',
|
||||
[self.known_vi_[i] for i in node.input if i],
|
||||
[helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED, None) for i in node.output])
|
||||
[make_value_info_func(i, onnx.TensorProto.UNDEFINED, None) for i in node.output])
|
||||
self.tmp_mp_.graph.CopyFrom(tmp_graph)
|
||||
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
|
||||
for i_o in range(len(node.output)):
|
||||
|
|
@ -558,6 +564,7 @@ class SymbolicShapeInference:
|
|||
funcs = {'Add' : lambda l: l[0] + l[1],
|
||||
'Div' : lambda l: l[0] // l[1], # integer div in sympy
|
||||
'Equal' : lambda l : l[0] == l[1],
|
||||
'Floor' : lambda l : sympy.floor(l[0]),
|
||||
'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
|
||||
'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
|
||||
'Mul' : lambda l: l[0] * l[1],
|
||||
|
|
@ -1027,7 +1034,7 @@ class SymbolicShapeInference:
|
|||
assert len(ends) == 1
|
||||
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][starts[0]:ends[0]]
|
||||
|
||||
def _infer_Split(self, node):
|
||||
def _infer_Split_Common(self, node, make_value_info_func):
|
||||
input_sympy_shape = self._get_sympy_shape(node, 0)
|
||||
axis = handle_negative_axis(get_attribute(node, 'axis', 0), len(input_sympy_shape))
|
||||
split = get_attribute(node, 'split')
|
||||
|
|
@ -1040,11 +1047,17 @@ class SymbolicShapeInference:
|
|||
|
||||
for i_o in range(len(split)):
|
||||
vi = self.known_vi_[node.output[i_o]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o],
|
||||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:])))
|
||||
vi.CopyFrom(make_value_info_func(node.output[i_o],
|
||||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:])))
|
||||
self.known_vi_[vi.name] = vi
|
||||
|
||||
def _infer_Split(self, node):
|
||||
self._infer_Split_Common(node, helper.make_tensor_value_info)
|
||||
|
||||
def _infer_SplitToSequence(self, node):
|
||||
self._infer_Split_Common(node, helper.make_sequence_value_info)
|
||||
|
||||
def _infer_Squeeze(self, node):
|
||||
self._pass_on_sympy_data(node)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue