mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Some bug fixes and support for Gather/ScatterElements (#1940)
Fix a bug in Concat when only part of input has sympy_data Fix a bug in ConstantOfShape when shape is scalar Add support for GatherElements and ScatterElements
This commit is contained in:
parent
ceaaff0f81
commit
1649374a5c
1 changed files with 30 additions and 12 deletions
|
|
@ -72,6 +72,7 @@ class SymbolicShapeInference:
|
|||
'Div' : self._infer_binary_ops,
|
||||
'Expand' : self._infer_Expand,
|
||||
'Gather' : self._infer_Gather,
|
||||
'GatherElements' : self._infer_GatherElements,
|
||||
'Loop' : self._infer_Loop,
|
||||
'MaxPool' : self._infer_Pool,
|
||||
'Max' : self._infer_binary_ops,
|
||||
|
|
@ -87,6 +88,7 @@ class SymbolicShapeInference:
|
|||
'Resize' : self._infer_Resize,
|
||||
'Round' : self._pass_on_shape_and_type,
|
||||
'Scan' : self._infer_Scan,
|
||||
'ScatterElements' : self._infer_ScatterElements,
|
||||
'Shape' : self._infer_Shape,
|
||||
'Size' : self._infer_Size,
|
||||
'Slice' : self._infer_Slice,
|
||||
|
|
@ -451,15 +453,15 @@ class SymbolicShapeInference:
|
|||
def _infer_Concat(self, node):
|
||||
if any([i in self.sympy_data_ for i in node.input]):
|
||||
values = self._get_int_values(node)
|
||||
assert all([v is not None for v in values])
|
||||
assert 0 == get_attribute(node, 'axis')
|
||||
self.sympy_data_[node.output[0]] = []
|
||||
for i in range(len(node.input)):
|
||||
value = values[i]
|
||||
if type(value) == list:
|
||||
self.sympy_data_[node.output[0]].extend(value)
|
||||
else:
|
||||
self.sympy_data_[node.output[0]].append(value)
|
||||
if all([v is not None for v in values]):
|
||||
assert 0 == get_attribute(node, 'axis')
|
||||
self.sympy_data_[node.output[0]] = []
|
||||
for i in range(len(node.input)):
|
||||
value = values[i]
|
||||
if type(value) == list:
|
||||
self.sympy_data_[node.output[0]].extend(value)
|
||||
else:
|
||||
self.sympy_data_[node.output[0]].append(value)
|
||||
|
||||
sympy_shape = self._get_sympy_shape(node, 0)
|
||||
axis = handle_negative_axis(get_attribute(node, 'axis'), len(sympy_shape))
|
||||
|
|
@ -488,12 +490,14 @@ class SymbolicShapeInference:
|
|||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
|
||||
|
||||
def _infer_ConstantOfShape(self, node):
|
||||
sympy_shape = self._get_value(node, 0)
|
||||
if sympy_shape:
|
||||
sympy_shape = self._get_int_values(node)[0]
|
||||
if sympy_shape is not None:
|
||||
if type(sympy_shape) != list:
|
||||
sympy_shape = [sympy_shape]
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
|
||||
vi.type.tensor_type.elem_type,
|
||||
[int(i) if is_literal(i) else str(i) for i in sympy_shape]))
|
||||
get_shape_from_sympy_shape(sympy_shape)))
|
||||
|
||||
def _infer_Expand(self, node):
|
||||
expand_to_shape = self._try_get_value(node, 1)
|
||||
|
|
@ -525,6 +529,13 @@ class SymbolicShapeInference:
|
|||
assert idx == 0
|
||||
self.sympy_data_[node.output[0]] = data
|
||||
|
||||
def _infer_GatherElements(self, node):
|
||||
indices_shape = self._get_shape(node, 1)
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
|
||||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
indices_shape))
|
||||
|
||||
def _infer_Loop(self, node):
|
||||
subgraph = get_attribute(node, 'body')
|
||||
assert len(subgraph.input) == len(node.input)
|
||||
|
|
@ -691,6 +702,13 @@ class SymbolicShapeInference:
|
|||
vi.CopyFrom(subgraph.output[i])
|
||||
vi.name = o
|
||||
|
||||
def _infer_ScatterElements(self, node):
|
||||
data_shape = self._get_shape(node, 0)
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
|
||||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
data_shape))
|
||||
|
||||
def _infer_Shape(self, node):
|
||||
self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue