mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
[FIX] symbolic shape infer error with onnx-1.11.0 (#10674)
* [FIX] symbolic shape infer error with onnx-1.11.0 * [FIX] consider inputs name contains 'unk__' * [TEST] enable gpt2 test * [FIX] gpt2_megatron_opt.onnx graph
This commit is contained in:
parent
b713855a98
commit
69e2d319ed
3 changed files with 24 additions and 13 deletions
|
|
@ -1803,6 +1803,21 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[output_index]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
|
||||
|
||||
def _is_none_dim(self, dim_value):
|
||||
if type(dim_value) != str:
|
||||
return False
|
||||
if "unk__" not in dim_value:
|
||||
return False
|
||||
if dim_value in self.symbolic_dims_.keys():
|
||||
return False
|
||||
return True
|
||||
|
||||
def _is_shape_contains_none_dim(self, out_shape):
|
||||
for out in out_shape:
|
||||
if self._is_none_dim(out):
|
||||
return out
|
||||
return None
|
||||
|
||||
def _infer_impl(self, start_sympy_data=None):
|
||||
self.sympy_data_ = start_sympy_data or {}
|
||||
self.out_mp_.graph.ClearField('value_info')
|
||||
|
|
@ -1956,7 +1971,8 @@ class SymbolicShapeInference:
|
|||
if node.output[i_o] in self.sympy_data_:
|
||||
logger.debug(' Sympy Data: ' + str(self.sympy_data_[node.output[i_o]]))
|
||||
|
||||
if (out_shape is not None and None in out_shape) or out_type_undefined:
|
||||
# onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
|
||||
if (out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))) or out_type_undefined:
|
||||
if self.auto_merge_:
|
||||
if node.op_type in [
|
||||
'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat',
|
||||
|
|
@ -1964,8 +1980,11 @@ class SymbolicShapeInference:
|
|||
]:
|
||||
shapes = [self._get_shape(node, i) for i in range(len(node.input))]
|
||||
if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']:
|
||||
if None in out_shape:
|
||||
idx = out_shape.index(None)
|
||||
if None in out_shape or self._is_shape_contains_none_dim(out_shape):
|
||||
if None in out_shape:
|
||||
idx = out_shape.index(None)
|
||||
else:
|
||||
idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
|
||||
dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
|
||||
# only support auto merge for MatMul for dim < rank-2 when rank > 2
|
||||
assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
|
||||
|
|
@ -1978,7 +1997,7 @@ class SymbolicShapeInference:
|
|||
|
||||
if shapes:
|
||||
for idx in range(len(out_shape)):
|
||||
if out_shape[idx] is not None:
|
||||
if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
|
||||
continue
|
||||
# note that the broadcasting rule aligns from right to left
|
||||
# if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
|
||||
|
|
|
|||
|
|
@ -25,21 +25,13 @@ def unique_element(lst):
|
|||
|
||||
class TestSymbolicShapeInference(unittest.TestCase):
|
||||
def test_symbolic_shape_infer(self):
|
||||
# skip these tests before this issue is fixed:
|
||||
# https://github.com/microsoft/onnxruntime/issues/10761
|
||||
test_skip_due_to_onnx_1_11_shape_inference_change = ["GPT2", "GPT2_LM_HEAD", "test_GPT2"]
|
||||
|
||||
|
||||
cwd = os.getcwd()
|
||||
test_model_dir = os.path.join(cwd, '..', 'models')
|
||||
for filename in Path(test_model_dir).rglob('*.onnx'):
|
||||
if filename.name.startswith('.'):
|
||||
continue # skip some bad model files
|
||||
|
||||
if len(filename.parts) > 1 and \
|
||||
filename.parts[len(filename.parts) - 2] in test_skip_due_to_onnx_1_11_shape_inference_change:
|
||||
print("Skip symbolic shape inference on : " + str(filename))
|
||||
continue
|
||||
|
||||
print("Running symbolic shape inference on : " + str(filename))
|
||||
SymbolicShapeInference.infer_shapes(in_mp=onnx.load(str(filename)),
|
||||
auto_merge=True,
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue