diff --git a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py index 12263161a6..0f8e1d2a11 100755 --- a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py +++ b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py @@ -86,6 +86,7 @@ class SymbolicShapeInference: 'Div' : self._infer_symbolic_compute_ops, 'Expand' : self._infer_Expand, 'Equal' : self._infer_symbolic_compute_ops, + 'Floor' : self._infer_symbolic_compute_ops, 'Gather' : self._infer_Gather, 'GatherElements' : self._infer_GatherElements, 'GatherND' : self._infer_GatherND, @@ -112,6 +113,7 @@ class SymbolicShapeInference: 'Size' : self._infer_Size, 'Slice' : self._infer_Slice, 'Split' : self._infer_Split, + 'SplitToSequence' : self._infer_SplitToSequence, 'Squeeze' : self._infer_Squeeze, 'Sub' : self._infer_symbolic_compute_ops, 'Tile' : self._infer_Tile, @@ -331,10 +333,14 @@ class SymbolicShapeInference: # run single node inference with self.known_vi_ shapes # note that inference rely on initializer values is not handled # as we don't copy initializer weights to tmp_graph for inference speed purpose + if node.op_type == 'SplitToSequence': + make_value_info_func = helper.make_sequence_value_info + else: + make_value_info_func = helper.make_tensor_value_info tmp_graph = helper.make_graph([node], 'tmp', [self.known_vi_[i] for i in node.input if i], - [helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED, None) for i in node.output]) + [make_value_info_func(i, onnx.TensorProto.UNDEFINED, None) for i in node.output]) self.tmp_mp_.graph.CopyFrom(tmp_graph) self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) for i_o in range(len(node.output)): @@ -558,6 +564,7 @@ class SymbolicShapeInference: funcs = {'Add' : lambda l: l[0] + l[1], 'Div' : lambda l: l[0] // l[1], # integer div in sympy 'Equal' : lambda l : l[0] == l[1], + 'Floor' : lambda l : sympy.floor(l[0]), 'Max' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), 'Min' : lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), 'Mul' : lambda l: l[0] * l[1], @@ -1027,7 +1034,7 @@ class SymbolicShapeInference: assert len(ends) == 1 self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][starts[0]:ends[0]] - def _infer_Split(self, node): + def _infer_Split_Common(self, node, make_value_info_func): input_sympy_shape = self._get_sympy_shape(node, 0) axis = handle_negative_axis(get_attribute(node, 'axis', 0), len(input_sympy_shape)) split = get_attribute(node, 'split') @@ -1040,11 +1047,17 @@ class SymbolicShapeInference: for i_o in range(len(split)): vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:]))) + vi.CopyFrom(make_value_info_func(node.output[i_o], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis+1:]))) self.known_vi_[vi.name] = vi + def _infer_Split(self, node): + self._infer_Split_Common(node, helper.make_tensor_value_info) + + def _infer_SplitToSequence(self, node): + self._infer_Split_Common(node, helper.make_sequence_value_info) + def _infer_Squeeze(self, node): self._pass_on_sympy_data(node)