[Symbolic shape infer] fix scalar shape in Expand (#7285)

This commit is contained in:
KeDengMS 2021-04-08 10:26:28 -07:00 committed by GitHub
parent bc6ef809bb
commit 0d49e53985
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -689,7 +689,7 @@ class SymbolicShapeInference:
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Expand(self, node):
expand_to_shape = self._try_get_value(node, 1)
expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
if expand_to_shape is not None:
# new_shape's dim can come from shape value
self._update_computed_dims(expand_to_shape)