diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 0d2e7feee7..889adf0c45 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1803,6 +1803,21 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[output_index]] vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) + def _is_none_dim(self, dim_value): + if type(dim_value) != str: + return False + if "unk__" not in dim_value: + return False + if dim_value in self.symbolic_dims_.keys(): + return False + return True + + def _is_shape_contains_none_dim(self, out_shape): + for out in out_shape: + if self._is_none_dim(out): + return out + return None + def _infer_impl(self, start_sympy_data=None): self.sympy_data_ = start_sympy_data or {} self.out_mp_.graph.ClearField('value_info') @@ -1956,7 +1971,8 @@ class SymbolicShapeInference: if node.output[i_o] in self.sympy_data_: logger.debug(' Sympy Data: ' + str(self.sympy_data_[node.output[i_o]])) - if (out_shape is not None and None in out_shape) or out_type_undefined: + # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain + if (out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))) or out_type_undefined: if self.auto_merge_: if node.op_type in [ 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat', @@ -1964,8 +1980,11 @@ class SymbolicShapeInference: ]: shapes = [self._get_shape(node, i) for i in range(len(node.input))] if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']: - if None in out_shape: - idx = out_shape.index(None) + if None in out_shape or self._is_shape_contains_none_dim(out_shape): + if None in out_shape: + idx = out_shape.index(None) + else: + idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) dim_idx = [len(s) - len(out_shape) + idx for s in shapes] # only support auto merge for MatMul for dim < rank-2 when rank > 2 assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 @@ -1978,7 +1997,7 @@ class SymbolicShapeInference: if shapes: for idx in range(len(out_shape)): - if out_shape[idx] is not None: + if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): continue # note that the broadcasting rule aligns from right to left # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 263dead104..5abd6fdcff 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -25,21 +25,13 @@ def unique_element(lst): class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): - # skip these tests before this issue is fixed: - # https://github.com/microsoft/onnxruntime/issues/10761 - test_skip_due_to_onnx_1_11_shape_inference_change = ["GPT2", "GPT2_LM_HEAD", "test_GPT2"] - + cwd = os.getcwd() test_model_dir = os.path.join(cwd, '..', 'models') for filename in Path(test_model_dir).rglob('*.onnx'): if filename.name.startswith('.'): continue # skip some bad model files - if len(filename.parts) > 1 and \ - filename.parts[len(filename.parts) - 2] in test_skip_due_to_onnx_1_11_shape_inference_change: - print("Skip symbolic shape inference on : " + str(filename)) - continue - print("Running symbolic shape inference on : " + str(filename)) SymbolicShapeInference.infer_shapes(in_mp=onnx.load(str(filename)), auto_merge=True, diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx index e9568e381d..debd5244ab 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx differ