From 1649374a5c551d3ba873fb309b76dfdce9e3c10c Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Thu, 26 Sep 2019 21:30:29 -0700 Subject: [PATCH] Some bug fixes and support for Gather/ScatterElements (#1940) Fix a bug in Concat when only part of input has sympy_data Fix a bug in ConstantOfShape when shape is scalar Add support for GatherElements and ScatterElements --- .../nuphar/scripts/symbolic_shape_infer.py | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py index 6c181ce161..41473272ae 100644 --- a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py +++ b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py @@ -72,6 +72,7 @@ class SymbolicShapeInference: 'Div' : self._infer_binary_ops, 'Expand' : self._infer_Expand, 'Gather' : self._infer_Gather, + 'GatherElements' : self._infer_GatherElements, 'Loop' : self._infer_Loop, 'MaxPool' : self._infer_Pool, 'Max' : self._infer_binary_ops, @@ -87,6 +88,7 @@ class SymbolicShapeInference: 'Resize' : self._infer_Resize, 'Round' : self._pass_on_shape_and_type, 'Scan' : self._infer_Scan, + 'ScatterElements' : self._infer_ScatterElements, 'Shape' : self._infer_Shape, 'Size' : self._infer_Size, 'Slice' : self._infer_Slice, @@ -451,15 +453,15 @@ class SymbolicShapeInference: def _infer_Concat(self, node): if any([i in self.sympy_data_ for i in node.input]): values = self._get_int_values(node) - assert all([v is not None for v in values]) - assert 0 == get_attribute(node, 'axis') - self.sympy_data_[node.output[0]] = [] - for i in range(len(node.input)): - value = values[i] - if type(value) == list: - self.sympy_data_[node.output[0]].extend(value) - else: - self.sympy_data_[node.output[0]].append(value) + if all([v is not None for v in values]): + assert 0 == get_attribute(node, 'axis') + self.sympy_data_[node.output[0]] = [] + for i in range(len(node.input)): + value = values[i] + if type(value) == list: + self.sympy_data_[node.output[0]].extend(value) + else: + self.sympy_data_[node.output[0]].append(value) sympy_shape = self._get_sympy_shape(node, 0) axis = handle_negative_axis(get_attribute(node, 'axis'), len(sympy_shape)) @@ -488,12 +490,14 @@ class SymbolicShapeInference: vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape))) def _infer_ConstantOfShape(self, node): - sympy_shape = self._get_value(node, 0) - if sympy_shape: + sympy_shape = self._get_int_values(node)[0] + if sympy_shape is not None: + if type(sympy_shape) != list: + sympy_shape = [sympy_shape] vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - [int(i) if is_literal(i) else str(i) for i in sympy_shape])) + get_shape_from_sympy_shape(sympy_shape))) def _infer_Expand(self, node): expand_to_shape = self._try_get_value(node, 1) @@ -525,6 +529,13 @@ class SymbolicShapeInference: assert idx == 0 self.sympy_data_[node.output[0]] = data + def _infer_GatherElements(self, node): + indices_shape = self._get_shape(node, 1) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + indices_shape)) + def _infer_Loop(self, node): subgraph = get_attribute(node, 'body') assert len(subgraph.input) == len(node.input) @@ -691,6 +702,13 @@ class SymbolicShapeInference: vi.CopyFrom(subgraph.output[i]) vi.name = o + def _infer_ScatterElements(self, node): + data_shape = self._get_shape(node, 0) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape)) + def _infer_Shape(self, node): self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)