mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Symbolic shape inference: fix a bug in auto_merge when broadcasting (#5349)
The bug happens when merging following shapes: input0: [1, 1, 'Min(1024, input1_dynamic_axes_3)', 'Min(1024, input1_dynamic_axes_3)'] input1: ['input1_dynamic_axes_1*input1_dynamic_axes_2', 12, 'input1_dynamic_axes_3', 'input1_dynamic_axes_3'] input2: [] The fix is to avoid broadcasting merge on input2
This commit is contained in:
parent
caed6c264c
commit
7495dc167a
1 changed files with 5 additions and 3 deletions
|
|
@ -1202,9 +1202,11 @@ class SymbolicShapeInference:
|
|||
for idx in range(len(out_shape)):
|
||||
if out_shape[idx] is not None:
|
||||
continue
|
||||
dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
|
||||
assert all([d >= 0 for d in dim_idx])
|
||||
self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)])
|
||||
# note that the broadcasting rule aligns from right to left
|
||||
# if a tensor has a lower rank, it would automatically broadcast and need no merge
|
||||
dim_idx = [len(s) - len(out_shape) + idx for s in shapes if len(s) >= len(out_shape) - idx]
|
||||
if len(dim_idx) > 0:
|
||||
self._add_suggested_merge([s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx)])
|
||||
self.run_ = True
|
||||
else:
|
||||
self.run_ = False
|
||||
|
|
|
|||
Loading…
Reference in a new issue