diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 79be553f74..148617e740 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -689,7 +689,7 @@ class SymbolicShapeInference: get_shape_from_sympy_shape(sympy_shape))) def _infer_Expand(self, node): - expand_to_shape = self._try_get_value(node, 1) + expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) if expand_to_shape is not None: # new_shape's dim can come from shape value self._update_computed_dims(expand_to_shape)