Symbolic shape inference: fix a bug in auto_merge when broadcasting (#5349)

The bug happens when merging following shapes:

input0: [1, 1, 'Min(1024, input1_dynamic_axes_3)', 'Min(1024, input1_dynamic_axes_3)']
input1: ['input1_dynamic_axes_1*input1_dynamic_axes_2', 12, 'input1_dynamic_axes_3', 'input1_dynamic_axes_3']
input2: []

The fix is to avoid broadcasting merge on input2
This commit is contained in:
KeDengMS 2020-10-01 15:24:00 -07:00 committed by GitHub
parent caed6c264c
commit 7495dc167a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1202,9 +1202,11 @@ class SymbolicShapeInference:
for idx in range(len(out_shape)):
if out_shape[idx] is not None:
continue
dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
assert all([d >= 0 for d in dim_idx])
self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)])
# note that the broadcasting rule aligns from right to left
# if a tensor has a lower rank, it would automatically broadcast and need no merge
dim_idx = [len(s) - len(out_shape) + idx for s in shapes if len(s) >= len(out_shape) - idx]
if len(dim_idx) > 0:
self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)])
self.run_ = True
else:
self.run_ = False