Some bug fixes and support for Gather/ScatterElements (#1940)

Fix a bug in Concat when only part of input has sympy_data
Fix a bug in ConstantOfShape when shape is scalar
Add support for GatherElements and ScatterElements
This commit is contained in:
KeDengMS 2019-09-26 21:30:29 -07:00 committed by GitHub
parent ceaaff0f81
commit 1649374a5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -72,6 +72,7 @@ class SymbolicShapeInference:
'Div' : self._infer_binary_ops,
'Expand' : self._infer_Expand,
'Gather' : self._infer_Gather,
'GatherElements' : self._infer_GatherElements,
'Loop' : self._infer_Loop,
'MaxPool' : self._infer_Pool,
'Max' : self._infer_binary_ops,
@ -87,6 +88,7 @@ class SymbolicShapeInference:
'Resize' : self._infer_Resize,
'Round' : self._pass_on_shape_and_type,
'Scan' : self._infer_Scan,
'ScatterElements' : self._infer_ScatterElements,
'Shape' : self._infer_Shape,
'Size' : self._infer_Size,
'Slice' : self._infer_Slice,
@ -451,15 +453,15 @@ class SymbolicShapeInference:
def _infer_Concat(self, node):
if any([i in self.sympy_data_ for i in node.input]):
values = self._get_int_values(node)
assert all([v is not None for v in values])
assert 0 == get_attribute(node, 'axis')
self.sympy_data_[node.output[0]] = []
for i in range(len(node.input)):
value = values[i]
if type(value) == list:
self.sympy_data_[node.output[0]].extend(value)
else:
self.sympy_data_[node.output[0]].append(value)
if all([v is not None for v in values]):
assert 0 == get_attribute(node, 'axis')
self.sympy_data_[node.output[0]] = []
for i in range(len(node.input)):
value = values[i]
if type(value) == list:
self.sympy_data_[node.output[0]].extend(value)
else:
self.sympy_data_[node.output[0]].append(value)
sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(get_attribute(node, 'axis'), len(sympy_shape))
@ -488,12 +490,14 @@ class SymbolicShapeInference:
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_ConstantOfShape(self, node):
sympy_shape = self._get_value(node, 0)
if sympy_shape:
sympy_shape = self._get_int_values(node)[0]
if sympy_shape is not None:
if type(sympy_shape) != list:
sympy_shape = [sympy_shape]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
vi.type.tensor_type.elem_type,
[int(i) if is_literal(i) else str(i) for i in sympy_shape]))
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Expand(self, node):
expand_to_shape = self._try_get_value(node, 1)
@ -525,6 +529,13 @@ class SymbolicShapeInference:
assert idx == 0
self.sympy_data_[node.output[0]] = data
def _infer_GatherElements(self, node):
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
indices_shape))
def _infer_Loop(self, node):
subgraph = get_attribute(node, 'body')
assert len(subgraph.input) == len(node.input)
@ -691,6 +702,13 @@ class SymbolicShapeInference:
vi.CopyFrom(subgraph.output[i])
vi.name = o
def _infer_ScatterElements(self, node):
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
data_shape))
def _infer_Shape(self, node):
self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)