From 7495dc167a359872c92cff99f8ebb516a943f13f Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Thu, 1 Oct 2020 15:24:00 -0700 Subject: [PATCH] Symbolic shape inference: fix a bug in auto_merge when broadcasting (#5349) The bug happens when merging following shapes: input0: [1, 1, 'Min(1024, input1_dynamic_axes_3)', 'Min(1024, input1_dynamic_axes_3)'] input1: ['input1_dynamic_axes_1*input1_dynamic_axes_2', 12, 'input1_dynamic_axes_3', 'input1_dynamic_axes_3'] input2: [] The fix is to avoid broadcasting merge on input2 --- onnxruntime/python/tools/symbolic_shape_infer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 258642509c..9b75d299c0 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1202,9 +1202,11 @@ class SymbolicShapeInference: for idx in range(len(out_shape)): if out_shape[idx] is not None: continue - dim_idx = [len(s) - len(out_shape) + idx for s in shapes] - assert all([d >= 0 for d in dim_idx]) - self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)]) + # note that the broadcasting rule aligns from right to left + # if a tensor has a lower rank, it would automatically broadcast and need no merge + dim_idx = [len(s) - len(out_shape) + idx for s in shapes if len(s) >= len(out_shape) - idx] + if len(dim_idx) > 0: + self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)]) self.run_ = True else: self.run_ = False