From 69e2d319edffb64014abbc4288e56b79712d0ab1 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 17 Mar 2022 13:47:02 +0800 Subject: [PATCH] [FIX] symbolic shape infer error with onnx-1.11.0 (#10674) * [FIX] symbolic shape infer error with onnx-1.11.0 * [FIX] consider inputs name contains 'unk__' * [TEST] enable gpt2 test * [FIX] gpt2_megatron_opt.onnx graph --- .../python/tools/symbolic_shape_infer.py | 27 +++++++++++++++--- ...untime_test_python_symbolic_shape_infer.py | 10 +------ .../test_data/models/gpt2_megatron_opt.onnx | Bin 13404 -> 8901 bytes 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 0d2e7feee7..889adf0c45 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1803,6 +1803,21 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[output_index]] vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) + def _is_none_dim(self, dim_value): + if type(dim_value) != str: + return False + if "unk__" not in dim_value: + return False + if dim_value in self.symbolic_dims_.keys(): + return False + return True + + def _is_shape_contains_none_dim(self, out_shape): + for out in out_shape: + if self._is_none_dim(out): + return out + return None + def _infer_impl(self, start_sympy_data=None): self.sympy_data_ = start_sympy_data or {} self.out_mp_.graph.ClearField('value_info') @@ -1956,7 +1971,8 @@ class SymbolicShapeInference: if node.output[i_o] in self.sympy_data_: logger.debug(' Sympy Data: ' + str(self.sympy_data_[node.output[i_o]])) - if (out_shape is not None and None in out_shape) or out_type_undefined: + # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain + if (out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))) or out_type_undefined: if self.auto_merge_: if node.op_type in [ 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat', @@ -1964,8 +1980,11 @@ class SymbolicShapeInference: ]: shapes = [self._get_shape(node, i) for i in range(len(node.input))] if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']: - if None in out_shape: - idx = out_shape.index(None) + if None in out_shape or self._is_shape_contains_none_dim(out_shape): + if None in out_shape: + idx = out_shape.index(None) + else: + idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) dim_idx = [len(s) - len(out_shape) + idx for s in shapes] # only support auto merge for MatMul for dim < rank-2 when rank > 2 assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 @@ -1978,7 +1997,7 @@ class SymbolicShapeInference: if shapes: for idx in range(len(out_shape)): - if out_shape[idx] is not None: + if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): continue # note that the broadcasting rule aligns from right to left # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 263dead104..5abd6fdcff 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -25,21 +25,13 @@ def unique_element(lst): class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): - # skip these tests before this issue is fixed: - # https://github.com/microsoft/onnxruntime/issues/10761 - test_skip_due_to_onnx_1_11_shape_inference_change = ["GPT2", "GPT2_LM_HEAD", "test_GPT2"] - + cwd = os.getcwd() test_model_dir = os.path.join(cwd, '..', 'models') for filename in Path(test_model_dir).rglob('*.onnx'): if filename.name.startswith('.'): continue # skip some bad model files - if len(filename.parts) > 1 and \ - filename.parts[len(filename.parts) - 2] in test_skip_due_to_onnx_1_11_shape_inference_change: - print("Skip symbolic shape inference on : " + str(filename)) - continue - print("Running symbolic shape inference on : " + str(filename)) SymbolicShapeInference.infer_shapes(in_mp=onnx.load(str(filename)), auto_merge=True, diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt.onnx index e9568e381d21bb10c455011ad3596bd480a62d6a..debd5244abce510aceaccde788fd523158ac98b7 100644 GIT binary patch delta 934 zcma)4J&)5s5RGG$$a6=CH;MA06c=;cA)|{NCvZ->fWkEhe!|McdzV$>55;S7M-w5X zL?-M|m&dQ4m~ zK-bf>s-+D7wuTH>^L73;*OW7GkFQD<{<&4;SjqEOioSphZ0U^lM(%UsJh3mx{HZ^8 z?HPS(ueli5kC>2w^mK$1x$ec<9Kq8e=jbpMyTBwKcI*6Q(ZLZTYV&~?>4TnGyjv_<$3i>@p5_h`_3QO+b^O3 delta 5521 zcmb_gO>g7I8P=yI&w4kmlHGL(7x8MlT}0Tp9DdEvC{iWcY*M#wn>Yz_3Zb>xSXf(f zML8Sv!WR7l1qAe*-g^mx-h0WhKcGMl{QS;aMyZLzVh_yc*onFL)oz422ck;(k5RK;Z z+1{R-+i%rswc+RW7TfPNT-&-)EaSET^;!;f4^VfpCe_<o#k( z6tvZ9q@F@FZ4FFYS0R>d?ba2X0VoX>g;P~AjR$2Ng;UY7q#`$s7XO3MqH1&j_--$4 za&8`q2n`+GoW_UUdWlF8mX$QkeitX>m^@0;>}ANb@Kd>DxQY2w!ZfLrW_+x09a#yG z-zq;Oup&ZI3N~q;^eVLSNw;dl?I+|m2SD=!YooA(bMoult#@+@`v|zL{BeXun!$8A zd^HU z`!^b(^N-fIR=diA4=Wb9hJVSNp-0_jg`uq+H@|nY0WtGBTRR^m={&M~4Jq*-THV9* z)A2m=-F{n*1~*u1^w~Js0+o%yE!IwGepnyYKd*P#kXh%`WIP{FCsBO%G(J8aPd-U@ zpU2}*PUdW*m>VA_C~n;9zCD_s#IuO^*Yt>m>`wZ8Iy=5Xe0@AQznDi_)-Gi2LUzEE z)}d{&!(U&-@o(bD_xhb|tTm65AvEtWS{wEnAvf0gdTk29^5QMlaBTc@DD7D7H;#`Z zr{Bbl-NRD{PhD&6@hF)`ZoftM!8&Wqr(58p!G3}zE|z%K#(VMc#Zi0^k0z1VXP4gw zFA=Fw?LYKM=|Xf4USQp4aeRI}K1(zfDdOVa1DFW8o-f^N2hB z77ZB5gqQ?@E(rYb_|u5H{pPC}us)uHHP^E`idgPdnSFDyS>N5p8m#lJwS$Y(h=UH? z;fQ#|UR`|Ay1n>l^QA+q2r{h#tj$H`0ucQ*J!BTGK(4f1M1jgf zfj|_{i0w01VNaTD@yEvP6cB(3vn{0F;s=a`>EaQLZ>?wU9)!8imcmJC4Rt(_=z%fM z$-|F_%i-6c@Y~G8FYONTM}7_O@N8B|puGkP(yjMmoll=!B$QM;*>@$xD?-pf@FgK$ z7IHEnRWda;4pS~~SOJA@?2)7F1R^*mGVu`#GaTFpE2pAbU);enU z(Fn;gNc(GgtRore@!u3G^ies0p z>*upLiP2CYDQ9${mnrNvs|kArc*GP!DGZB<3|uoJ!;nG=**}X&Peo*kkVLrZNs&(q zo`-8b#|uHB1tE7olf%1~PBrq1RHH)KrcmazKh8c6<$yvtr7BjAWhEgb!7sNEYK!U)+XU{=#ZPGqBseMv!q$0TX-jNz-`Xh7JrdV>0qqM36gfje zBF+l=Wxt_)d(D({l(%A0B*aRQJ$4K4sPG8t5za=*GqTvDi;rx~;Wl{tAgv{i@pOC? zN7BYI(ufhKT+%Wk)`9H~edSI{2^JMGm!NJ%Nf1X4VLCpo5V}gptV*a^1)ye;zKKAI z@<=6B&4!RgQNsmQjS7TBDxeCZW=RdyHrS8qri8XY$4N~bR0wyfH7K+9+c-=iIh z-U!k1ZCa4;DmYwHpWMOs(Qn{5IezU>=UL3_=`f%6jA^0)^$8bYb*NQBK^QfMa#0(%B7dRTvh?DXyuePzkA)Ic7i{ z2d>qj8imNv>G#1ynj&5UFQ*?3M8%zczW8TT?9eJ&3bs`LDsS9X><+n%a9Hfjf3_}FKgZDIMuwKMj29W>t+4B znx-_b!X7`pX}|QcCSEB^WT00mD=WXXU!n>FkDMK~U7x9+HS^?`DA2gjyP@(+G*r5u zSDI1+qB`uLNNc~uyQ69g$S?7frts#c{Q>W16~Fx9;!o`dJG6?Hf-U8jewMTM2*hDK zi-Q-p?>-%m5?X{1P@d=O==pMI?Qff}zo>qpqMEYb&EKL2#Y3^5zxmcKAIg3D0Z@Iu K{QAz{Z~PZ;b8oZ&