From fc2f36c6083238e2a4b43f0f25414b5abb079f0c Mon Sep 17 00:00:00 2001 From: ashbhandare Date: Thu, 6 Aug 2020 14:39:33 -0700 Subject: [PATCH] Shape independent gradient builder for Concat (#4675) * Add gradient for ConcatTraining * Graph rewriter changes for concat * Add generated onnx graph, minor fixes * Revert unintended change * Fix for MaxPoolGradTest * Fix UT * Review comments, windows tests * Review comments --- .../testdata/transform/concat_graph_gen.py | 55 +++++++ .../testdata/transform/concat_trainable.onnx | Bin 0 -> 32734 bytes .../core/graph/gradient_builder.cc | 45 ++++++ .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/optimizer/concat_replacement.cc | 45 ++++++ .../core/optimizer/concat_replacement.h | 32 ++++ .../core/optimizer/graph_transformer_utils.cc | 5 + .../test/gradient/gradient_checker.cc | 14 +- .../test/gradient/gradient_ops_test.cc | 150 +++++++++++------- .../test/graph/gradient_graph_builder_test.cc | 23 +++ .../test/optimizer/graph_transform_test.cc | 25 ++- 12 files changed, 331 insertions(+), 65 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/concat_graph_gen.py create mode 100644 onnxruntime/test/testdata/transform/concat_trainable.onnx create mode 100644 orttraining/orttraining/core/optimizer/concat_replacement.cc create mode 100644 orttraining/orttraining/core/optimizer/concat_replacement.h diff --git a/onnxruntime/test/testdata/transform/concat_graph_gen.py b/onnxruntime/test/testdata/transform/concat_graph_gen.py new file mode 100644 index 0000000000..599eeec221 --- /dev/null +++ b/onnxruntime/test/testdata/transform/concat_graph_gen.py @@ -0,0 +1,55 @@ +import onnx +from onnx import helper +from onnx import TensorProto +import numpy as np + +def GenerateModel(model_name): + nodes = [ + helper.make_node("Gather", ["embed_weights","input_1"], ["gather_out"], "gather"), + + helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"), + helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"), + helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"), + + helper.make_node("Concat", ["add_q_out", "add_k_out", "add_v_out"], + ["concat_out"], "concat", axis=0), + + helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"), + helper.make_node("ReduceSum",["add_out"],["predictions"],"reduce_sum_1", axes=[0], keepdims=1), + ] + + embed_weights = np.random.uniform(-1,1,8000).tolist() + + add_q_weight = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875, + -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375] + + add_k_weight = [0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625, + 0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875] + + add_v_weight = [-0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125, + -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125] + + initializers = [ # initializers + helper.make_tensor('embed_weights', TensorProto.FLOAT, [1000, 8], embed_weights), + helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight), + helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight), + helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight), + helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [1], [1.0]), + ] + + graph = helper.make_graph( + nodes, + "ConcatThreeInputs", #name + [ # inputs + helper.make_tensor_value_info('input_1', TensorProto.INT64, ['batch', 'seq_len']) + ], + [ # outputs + helper.make_tensor_value_info('predictions', TensorProto.FLOAT, [1,1,8]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) + +GenerateModel('concat_trainable.onnx') + diff --git a/onnxruntime/test/testdata/transform/concat_trainable.onnx b/onnxruntime/test/testdata/transform/concat_trainable.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1f63c242be566ce690d1f58096dffdbc741a8902 GIT binary patch literal 32734 zcmZ^~cU+J0_cz|&rKu89G$bh!>V2IfX;50)QyNM%w3U#(vsW@AdxQvG=VTS&zFj_bJIUrHAgkg#L43m6JqlcQzj&5 zXQ?Ttq-Eq}$2#b$O^C}*PMi{(o|D~8>A$wQlIMT+sX3_qA4kKGymT!=&JsA z>wnx7|6`iVyCx+3f4QXoUoNTtiS!?r)c;?X-2clZ_kX+O{?}!++W!v!U-Lhg{jd3- z({i-I=_J6qQD*w;1 z|9MDxsqx?7W7PiFPF?kkDTxUw@!2WqX<6Mgq;^hxVr*8<SrZMvwbq$Vb2B&1BvGLn_X*VXy&O$8-SNlYB|KOZH_X2UO8m6ysYD%!H}HLUw0cr_k}!2zS; zRdOZhrp%|iC!dnuv-vdn`AS-3rb*8e@55uKP;swA4;rE1jp|(s_?L|*`W_#NZ`b={ zS?mPf@SqrHp2(ongNrz~MFo>TdSQ$tmxr32glzjb9N)Z-ZbbcrGi3^BZ0dsB9R~Bm z7vCUoeIoC8(?MT0Rn>kqdoIX)`z#hOSK#}bhjV1ADMppm(UOdz7_@aPkDdJv;)J_& z$SxmKtrWPl-hkttW#PE73h3nXk1`g_r3dkI!D-rVT3ytOkFzXZ6NX~rK3V8|HWI>8 z9!V}lN3p`bV0>G;8X6khB=4G6(Wc`meEa=7Xc|wjV#iLpJ=LB{E;PdJIdj46OFlOQ zGPFW}?pHVsGq#q{gscfHv#&+?v3U$Hn6w7chfXHTDU?HYf!JR`#O(5+>qd>Mz54)XG z$HC!+XqKJN*4OiJnoK>c`x1!#Q%Az}&du~~_+JWoSObRMI_PP)1xC8Kv8B&*Nuq&_ zWVhWIeluac_;}U;EZ?PymM1?5*ZYm(GS#j0C~PiNKVJw6FCy8kpoZjBCg6QnW!9f$ zgSJ0zz?A{X9I<>;Z9umiE?1f^6!-s03wBIk|IjFCC>$dBI4~RTT$>3u)|=u^%|>cj zA;(&A#(Zn6B`f>2!^z=ks6L{JEXv$**tOpL%B8S&#sGC%^LYfXOK5}Kj4B#^FqyAZ zc%#+QaDMf~A7>_abqBiH}^H;Z0wEuFNS1!-<+SeZD&<=l`vh`;^Wap*GZa z{ZVi#>J(n5e}m6&v2cXLZ+hXH^Sh`_Q;k!9 z_T(wa&tcStA7paD5_7WC@b+DQF-&taZO$@7qd#tZcGqyoeeI9G#%u;-hfM6Jv{jrM zkUN`Id3}lYB?pminkX-=Xc?`=XXPo51uHNsG*ae>k? zJ((MlllZ>gWU#)L#$TQ)a$$)T)|~Xg5&P2c)WcF5@@k`~>}<->H-ACQTum$vu*1cn zU4+#?AJY1qPmq**x>h@V6D^tOjFTf=)qUwv!BZDZ!j5OIF%C z5f4w2a6gwo?3|ay4fPG8|NdYaIJ75To@oKw6emHzvuw;gZHpSmqAAr%omaPXrfonsw&+aCouy{Z^zD}>{gnfCm5|6W+tEtO+V5v|zh#J#ih;Zk)5T7Ng- zo2L&`!MbJ|P}=}c&y9rJjf-IJ8(pZ;*h%IK`thzpQ~teaDEIZqW^a#Mu%>AR*ng-1 z^CPS2{^*NvyU%ipK6MAeI^yY9LL@f7QWnFaM4036%KwUkBy#s6$R|x6mxXqKcH(-t znQ{Uws}#BVX%Q>U`2?$aMbOaNSiET3m7h6A;m!x4xS>xtUp!s{ch=5;Lo3tZQC}mp zJ2eWIz*tJEw&zrt!T9{+SK-2w@#t80Mu>=dA?aLJgj0<@aH{-!vCm~0KL0cn`{)g0 z#akNqeT^F`rlexgLJc9!^$N|o5rkEPT&TFDjOP3X-m`BaYMuGzvTodJ2+3E+l{Gs^ z@zq{fFgO^0uUktN`FB8fx&{@@G~lj#7DLL)IvQQ_gydA?@wty8z7LG%$z8t)#n!gG zZm~B0E8R*>#(qO?G>zcD0bbnCZY}L|h~!hJ`}2xEZ8TEXl@%8D5_>IQOGR=?Sl^IL zfn(fw;Wgh{lOak%MUP`NJL!QC=#`EO`s=gpUck(1e>!9o!jq3)qVBO<;IX1QoAiA{ zi3(}F%(NGFAG=u;qSUe9&RekEJPW6;Q^48XFV#Yt0>}6TKzl_1`41S0K~}SA)>7lz z`nq29<@tM3EzIPpH6bu_`#tJ*bSJcV?+_v~!H`2P z%q(%ifrD@=IFm;>_vYd!)6iv=Ip-{?6rFXgu#dgG_%v4$%i~}P#!*oq!ShIX3 zU0L{&$_DmE<3hls@>1bqb1&Rv^bNLDN1{?&DGhuwTj+g30}Fe}HxC;{ zGJON#k5(~W z5d8N1rZCKT0tTA*;P94OST*JfUH`cV${su7!dgRkyWNH3+UJ8w+&)OEw8q~LMqn=N zqwd#K(YyZOsg&&rd`NdFXOz~!t;5aYyhH7fkwC)tmS@yU-y4@NE#$N(a{SVK9Ht~5 zfVMSToJ-~*Bo23@@3XE7J&m0u!QoCQ;}%V(dxCk{@OIeyWf|R^tjsnMTZHvL^f0vX z0le>f4}R$Xgz%8}a8}<6kG@Oi0xAAjv@98}@3})8bSDZ^Cf*13FXksBefZ+WT)q^Z z4pmJDVN}Lw9MY=`JgS_C14my5mx4IFF)5jU81~}#&*y`jPc`_N>G0fZ7WiqA4V>C+ zgqz3Yuw9uV|DDl`U;DKS#&e5!PP-QB?A%S-&q_(&+Y6nm9Z^;lv2t8bjJ??jzNr=P zZC4ThS(=7lI`a78x7FeS`!wFv7J|LDWMZdX5Qp8kO!MEXajA*)d{bRqbvYG}Sciak z@t%-vyi-sfkc#&TT=CEZZJbf-iFBnqp3_|f%L5(RMzG`BiZS@yybZp6GK2NKbU1Iu z6ZN1QcnEcP;91=_OHu&qs_c5$sIcmI}*1FtI)9P3VEM>4m{XyOgeCb3#?D)R*= zF0u8&>8r-F-I>X#W2=KZf9w`74W7tb1CuziIg0hCmO=DHJM8_{g8vu@95OHpHgEN# zTnkkiZW4m}d%dJfLO<+QvqS28iO$X&Bxtv$!ijMq9B?@tpB#4q{l6<H5S-h!N4l?JH@!CB#-moP>*cN&bltR`+fW#W_ zZ*xbBX;ECdXC*Z&4*>fagE0Lx(zIpog*@3Nsx{mJVZSfKT*KjbZ{h^uL5Me>(;Fk5 zLUZ|}#10oZTd>cC;rQs&Abhf28!wI3g|k1d)7QGwpqeuZ#ZBY*yGIV6{_BoQuj%u~ zU6(*?9mox4^QgtB8QQ!AA!KeJ_-Qg3-~Kbl1BYC(Oi#cA`wo+uNl$DTn1oXvo`kBs z2gu277`ChQWQ&iAsJ5VkHr!q=9;->^itG&h5ikg^K3Wa?ifUkl%}01*p3I&_aljVS zdCKGzIOyhyzn^OIx%<_^C6@p}Z&ob-KJE*HUysD9pY~|^=YT}wunTsb??SSgZdhI< z$Lb5Db=P`J4t;CD2J%C3*{-oTw6zF6MH%srd%<{2`6{hX%Hu(gKhv7UmK;bY>3Cs( zG5nJunvA|!JAP!RWb~2)FnQ)Y;kTADx}CiRP1_zphu>0A&auVs<6Tj<`z=~EcQ5UK zVamhZ_lOraY!xj|OvhX8C!jTGFgx5aXU~=x{O$nwyK*{&clk#fPcNgLs|+x3X%4T- z-AIw#3rlL1uz0RJRW5i1gV6~c{{psX@P+s9MSc$8SV&*KvxeHzI$B`cAQVgVn>00)}{0NCH=)~ z&W?Csle|mf?g?mr@*i}Y_7`0171-W-zhG-0EdJZqk6VOL+}wv* zbJBx&k38;GtVVV%A*?CJp|)=Tr|ek*We@j~wsJa0_vy~|J?=qCj52LlriQ&uyqR40 z(6C**DErnBd^;gfqW#v1(xSY;tThjQmEGi(&L<&f~Q@;Y27Jw(x2|H#F&6~gO8gKSI@Di%e=uv>&W%7{da7LGQa1VK0Yc?sBn_T>bcQPs1NY^ z<02doJxh$)eUG+Eak7w!;OC9}GAICx&t%}7T0D4b@emQk`(78aNnaD7of%5|~9 zQ%ij5hiyL|I-)o4IdMdgX|lnIR^~L%s#Pf59*wQzzSUZ12V&+lkm9}?S~^S*d)@Kl z8KZk~~)M@Ox@H(+HlpCSb|qf5?jb0?>45xa0jxJI8rw<&`Ta+KmJA<)E_R9N>gvc5 z^9B5}Ar_zFbJ9OM7IU=#cmDLHxOzi&KIKBPUE{E9|3Nx>LJiK3TMp?-rSz>*pPhE} zgh4Kr)Kcb&=@D)`Ge&||XBN?VQ&*M=)Zrxu8|i(e8>Vm5;g?u8 zO7c291DbD-V8fP4+@Ybv+suYgeMMF6k)wV*`K&F8GL}4G%VFX1*lzqld9K89cqA*# z(m};9opf?-Ay?Ze;q2xXcsvwwy>3@rWUhseBZuJ0Lv6z1@qjOmjN&KT*V75lb+pN9 z0Lt%p20e|u@#W>$P`lKNKb}_R)}jP#rS0VS*p(CPcF~2-c=Y&w3x1ybNNb+}2TW0YegQ?#oeZzH-k02%vWI;v=W- z)32es$TDRIj9OmAttt*^-@iS(s$a@+gI#-2GuV&#KkHmlsZu?y?Wm`^6+uE$@Z~KH9>Y-x_pTp#;>2?h{9p z>tn8}25V-i)Ajqd7}{qPX8V^>h`|ra(6-_CUI)RlXbx@rWsU29qzDU^d85)p9aQ#d zCWT6WJSNkhpD(k=J0p5X4r$2Y`Z?p+E&c>~Zh8s}I%dM%Y(2JU(?%o5Vc>V~FUelk z<-O@6IB{ztJM@mkVcKD+_Em#7eolwRkTfj6)<*xfIAV99CnqNEf=hu7U}k5>Z~Uj@ zr|;v$a}8;HdFn&(+<%r1xyAGAUAnYRQHO(Sv$1aJE65+B%c`HvB#)0Cq@hcSxkuGn zF?WX>n;x`4jl4Kq-)M!O+wM`ZK|ECXkH^Tm0-W^y94J@Dp>0Dvo^$PicE2_Ych+TN z*WYtt``RuvB6M+WaNHPrw8;-|d%q&>9}}?Cv>q(OzYA?cN@(OVTekA(j;`}dLH6}4 z;q$BoWcE*!FOJy(BT=32R>;t@Z7ICt+5(UZox=UH;;>@Jbe^d@o<{|Hu#&S3Zm!W5 zHqO?;oOwr~zTB82rB&1WuVPE~lai}7I!vRLS+?*rZS6Xa4Lat5 z?~KtrQocZ}f0@NwdfgH=>0q zP@^N9d*%P2yL~sqN58``?!{qR**TEKy44U(+Hn5FDd>?i9Y3sZCF`r{?4=Tk*YgBc zSW+phHZtO0@72j8N0ncUQ^a{)v}(8gGlk7#Q=y|FlMi_h!f_@3cw=oKrPyvK`{)&P zV5cYi-Dv^0B${adQHRZ!8L`*!FqB+35YE@c^Y7xd)O_4jRBz7b-HU(I{vKL9wtAuz z(<_L6pNGTSTXH5py2DkpOatH6epoS?Z`U`%^ph91&Ta+!EzPj$(HJgyasz%;rSgS7jy!u_ zC?1qwN8XpGazjljDrO!MRJRED+0YvG>zuK?(^h!7c?hR&Qb9Z81lZ=VPk25rfCFn9 zNdDqvzA>f@F8BuH$q5DAvbdhcZ2be3Cduf2BNnrd{SiO?87Er0#qzcHS3pKX1Jw^@ zFjb7>f7$M$%QgusjvPvBN-{CCcQEFjh{bZ9?yMl4U#!eG&?B`}KC5ycGJbvp7^ozk z2-io$;nO&8eHbit{zI2f8Dn|xWnd;V9J|)v1hthRtlAVsar;K_{I9WeeVQ-EHwNOy zIrcoMsSR>uBXCw_m899!kH6LNXB07F%^{d9B8CI3s@yHFmf#0}eq_2U^)#(eeWLV{B*aJ*kDbi8anHT-KQnY7HN zD}y3aq`g2hoKk+oy zpI6LkJxk%!;cQF!HTh?kByp~a?|FxDWR8%@r_ z&3?*gtdI{uy=H;M90NS`e5xo+)5hTA^WmP&NWOJR1=sEmhMjjRXiMZfaCve~jP{M- zlF+Hp@F@eQK@!}T1VU^1VYu~r8fdL7!p8&S*i5sDI(eA^C&!xi3P&+!g z)rCB|#`41m85a&Hhh28L9NBje4mY31!8r>-?eck=>eWtjF5U+fZ4<@1EIpk$`l6Yoz0_TA>dP)0B8RMlr-SR`C zdAm2(Ik&*iKv#S=-VfJ!ZlZNvHSlZuX#6tE6m<#|QANd<74DAWzCYH&-SSABAJP*~ z$Q%$;9uMH9)R+2g{Xtreo;bTq8T*>LaZ9TXUOr`nmsj-X9S@K<)V!j5uElJBgR`IR)sv?|ax=!UW^(2<*MdP)i zLR6CZ48KqGggw)W*h04lYlUs1J28f=_Tw(beOh_4#$qgH|39ATJNUdnpKEvJ9KLwEMEtYs*3RZ))F!E@nZ5X z=#LqmQm%HWBip9+#D@o_p=IU{*m5QogQW!X+tR9q7`Z3#;7I}gt#20U4lA&AQ8JZ{>xx^#f094jI2An+h~_&=mw9)UbBtQl4e9DyB%C6GPE1^?X+!rK0dywb~? zjWYYvkbEV!cKeVi0m=}2gM$xtuKbNN?{`YF`9_h!wYcHG=dni z!a;lffYTuhmfe_(_Sej~?b20Y`$l~zI$@9VM%tq6+2uly3q81Kt`q|sCU9=_IS3E! z6uRCngZhG#u+g$eNNbP6H#ZDWspcglD|0 zNJHX;BgGvQx_K|S99F?|HBZPcGmnf{`eVL#7>#VMrh&($xh`CflN4OA%Rwbt+HS?g zhwc(qMxuiAXmm-I=iJi2Vt(aXQj{rV=LKebI!TdrpQvHx7+Ey_-U830JV$Vo31_8d z@&_v`o;L6Uo%yTA1+$Xrl8YJcH~K}YLnFnFCtJiN(=_qtrhByVpczKWtO14A8^Vmn zVQ?|4kOm)%1FK_|aKy1|?S~5OVZFSb35Vmd#jO|PaDkUJw=A%wJHb}5z?NyFapyp3J*O6d zIMqJ2pM%k=mIYfL}W1?$Z; z*LVTBHo5WQ5_cXuZWztZc|?W3CV^qZ0kWK~N>}@Z;JmsYq<&qC9yBLmT(&A|WKW^2 z;52MCIwE-1YonfLB^XS2A1Hu-r7UKR^3yuC$ z+56lePT`kgOUOiQzp0IxQFT-iHA`%cGQ#jBSh+gQDy)zEKD(K_5^`|69uXx}rX$`i*#E~y~Dk6RGhv+dAF>lN$&M&a%4KC*R zb437ty891m+`Qkl|PxSRi3a55XLiHWZ^rw40nED0r$Q3 zyra%T(x&6|jkCoC`&QER(OqcK>{zH6^bg){&co%Jf!OuP$5Z~FocP!uH+W*XjGX)R z;s<+mFi{}}3kT)0yUQS0wn&PtcE$61WfxqxbSh7>$mA7s3h;ba3%vWhKhBhO1<%RA zH*ci#)boA0d!!l4uiHY$Mewn2DgjTY`cpd;iTnn<-brIz^}U}7aTZ_Ujhivg4mIHCcf0VS7bhj7-X>zCw=!>bdkA}6 zZ;4yiW%ISoXW+g12sGN+3~3`K;{h#o{E@wjyt=8P(ezkO5gqW>*3oF}vq(&Ml7S0% z>hil^)7fBcH7pFjLXpy*oOy64n5>k?tjE1^$xI(UvdjjZdo_^EvTEw<&;m^-4pWx$ z9VT9Uk&^4Fhzb)}-lM&5_rz|Af+uhmh(^7Hn{O@#Xp9_xKD$Bdx+r#z? zWBEWYq>6tsEL8NNthsOC@NokiU}}qd|BU12^;wYRa*`$&7@@+0V>IObR$8{c9(q?? z5eNAB;mxIys9iZ2b#4M0)i%(~&gbM;>5M-X{-u!X9<2YOFLqI{5tA0XLR!KnP#XD~ zUjH;jNreh0g#DHLz7)iki5Yz0SQbByA4SQBitxccZ*0pEp<~rZSU+GzUCbL+Wl?xbdG88lI8C!=E2hxNT3aSyCrh zypnPZZgSjy*#$o=?TQgjU&+-+oq0tzpI%!jC>f}ro6j{ekbxi-)_(5`GxSR3c#hvu~Pz7DTtH@VRYvcL#rsQU* zg!O7+@Zw_*PXGCu3`abK^;cp^b-NMQ*txK_%3i_g^Dj8P;sC{en}}iCytwU<60e{_ z-rfC+P(4eFXPoW9!~M6A|G5+7YbM2t3d?F=J^4;W8?J$te;;nL%fXRzeQGnqE{WZH z+hWIzDfm3#s1Wz~G==QiKzBBqvqoJsJs+V1(L+snd4Nb4#wP*y*a1VoJs^c`X)tuL z3GY(whUYGsf}Un?+PpQ2=b!`Y2W}MCejkHQ)!#vixao)SHClfojvv`ZO4U6!+~e(a zan$f{R5D%;H5azf%y@S^=_w0k=dvNjt{5+?uMz)d8)96J4Vy2B1&d>`Xtg;OH%wm* zmKO~vUB8kFtep9yiZ%vK)|YZ&!Q9W;4Oi$IP+egusPdACXzUEug`Q<2!0_5_`kole^5HH#wExlCy<02D{)ic)juuuc_aUD> zc`zwsBgxyQiCR&a?3iU5Ve~>vR%Ca0?mUG8vO3IixtRET`Z-NT78>0H& z{^Ivd(@|~nM(W|M#{OPe+#$uCo$J#0s#LGHeXBn`{dNK#d|3o(6|2aqRhLdXe4~(U z1#C2X03Z4}4*Tfa@a(fCw90BN#C11>*l+rjyY&}z>ln{B8#YtHDt1g%=FIOR4Z1GNQ%hZNYNrBshV^FsJHA{qdm^8CU<{EO3fT9y1J64Y%Yz)g z*S0oIrMp9{Sj91lOZq5s^BHs2JDSX|Q`33ungy_b?@X{6E7dH-OY_}lKklnz&i8ed zICPvS&br}+p?~G@^0&V<#$HSGclabcl?>(0#@)Hd1i5{k8}AtG!v|V*p`S%64yfBH zj=8e4R#Q`li`!NTo~C(dM;oZk{~Q?VZ=?1ZvUsKRy>M-CHdY%tKw~b$2opOoC8#$p zs|u$(-GVu<{v0if?~1V-4@w5Ee+6Tt_3NCBU^o(=#}3X;G;M?&6+IkA_g72b|CNbo z5>rDtO?x1A-F$eWKN>oMJH;v9A+Xa`hgWu13(;Q<@pHsydNDT+N1wa|e>}%=Xz&tx zemjZ}{msN@Ta!`qTrYI-+D`#XYG7xi3|`hNMyt_T!o9P4IQ5}EjB?46>I~b2VF$ZF zd%IrkK=a?U@VFaSjq-(Wp#q;-r^vPkrJ5oIC0sh8kQbPj!7PUta5}p?Z_h|U#h5lw zURn+OS`H2ShjLUO6KO8B!l$L?c=x!Ji-?q^`E3&=Pe%WN`RC_TLdy~|Tl|r1S%(`F z=fIVk!<24y6qLqEdE*DkJTKcFyJclzXBN>xyO}PDZEqpIpFjVN3KT96F~P|imAxP2QT(X!T5KD*!7e-e>z@(4NCQ} zw{0Si{##A6zJy@!%0|i@8V1zTl{XxZmF)OGn1dp$+P`5B&v@cGAp=B*l|3(j=XNBWcqbf1A>I)R~jHj0mhx71t zzlGYb!96uuuCuq|0{v;26+4NaYY*mKQr+MmsfKyZyI`JMZG+!xd$Rs93tWDrhKw${ zlX6HqO;lb5ezJ!^X^;$7MJfAW%-5>k3A(N>duY z{Sd&6Dvu%vFCto$pPs;PM2#xoL)T_#rdic}f1k%L1isi!3K!WdnYEA-*DY0kK#>@@6fu;p1RopJLdCv@w2 zS?GFeI&X_e;;N`E)a$7YZ+o|&t}hMYp=GK3Y(fJFcdiK6&)kImiP7-wu?BxJn+nR+ zW`a(e395Zc!kVOQuys*7FS@^72soA<=BGG?3f00k<_9 z!HF5Gph){46-D;py^iLvqgy6^?z@HdD8^v^VLyfkdVG!gV0x1eZ;$AQr>Z5WRcI*$ zjva@o?@ai`yEnA@ekeAaJt-)=8}YjRA0^YY2JqfS7aDYVkr*>JioIT_<019SFj4k5 ztZUCj-^(g|*60r`>-LDehu)y-H?N>Q$%)q%8?f20tu$a)1P<)3&OQZaDbP0y2aR?I z%erSI>+g;gWyd9DZv<8@?jzN+DByt;`FwlyB49IfoECqdy4A~LOu01=&dtJyN@)`L zRcib*a522HSRh&f%X2z;$G)uF1M|c@@`>xf1)#H2gvcAs4HakZ?Rx@H6LERS|Btg zdvIHmJKyfn0%sO~p&strsbSC^DvPwhkhv0^G){v%*ZihGzpqh`3G~+ zO_@z;wVr>oVBBeIe)i)mh4T?uxzvL)`&=Z`M+`pFT@|m@50iw z`E+Zo1Nz$E1D7^S+&w;ow2ez(N>vv2x}CzC3ig7}yl_5v(tuU{#^TnHT#1pS6I9~P zeTwkg`V)|PR0mf*ir^jTpW&_VTFP@W=gnqQg}hA;T%{F+V=`yeKHk<%{C*>eeN`;@ zO8o<2?fQ-2r}MG4NAA$tV+%UOCq)}T)%q^@ZaO6FxElltUxV?@OGjAa*9F}!?2+m` z2H?Ieg0$!Ho?iZHg)NhZ@ok+ar0MvKHa6D5&KqG!C#CgvsxR*{PQi`|dTi`G9jvrxtXfhuVEK@puxi8|Njh-ooxf_9G>>OcD{ z$c&jlJJ(2S(s>RPQ(Q)i1SRpq)ip4pIRW=|{D2cxV{z%^aC~$03mHUhfOF1Iz%S`G z%y>SC8)IhC3m;2o63b}e-SK~a&K=}k>jYNlz+;CZ?9MeA-V!z)O6&Gj1X`bJq1_R zmB8joBRnD1U#o{t<#~T~P)^;HV+-cO#y7GYbtbX)Xh^3p&{LjOoO*EiwhnNbTmWtx zL-^baKMMSk4&IWxDC&pliRLkMfGXY~?58@<;5YkKuz^O_3ICSlGxE23Xl6}3HAL<3M zNxvJl2FGqd*}6iy+8&P!_E*BDr@GkYDZyI3(e$g!0CYFc!FwSRR_~|6x;}$(-7XL= z*Z0EvXD$dWpTpp=v?kD5Y>PT6eRxi>H}!1ogYU{eQ|VhbIC8R$*1T30L(6l(YiBYJ zE6f9pohfvA38F!!C99@{;ET>tSmm=2(xr3NN_%&#oH_s=YdWIRIcYu(^ON>!9AKTx z2uvJR#6jUI)bmIV57tJC4^f2EK29|5x-3pQ`w|+O(pj;&4LI@&`T8mG^5CbUA2bLW zPb|2{!$RJlC69ZfT(ITB6S`;;2v-ubrFdJFZ-<|toNXUzLw_5V*=`0)wIcc8^;jN! zr9aMC9Ki}3t7-A(Js|Til0zRz<3AmM>|}yj$#(qo;dqzwhi>A(<36~!B?_NBGT{`} z4j3{sPZ);sQu_B5^*rVQu9d6d`f(lXy)sA0@7f7Rj{c^&4pSc3;*0VD9xz3!+w$I^ z$6=?E@y^GMWPK+b12-38@?crMbl8^+>ca6##T^ii&j7FHDLA4qn42tRQGJsS4!Pzh z>`2{B<`1oSUzQ~wk@oa<`=oGGRVjQ=o+E16b<)ehnRq2D7xyPmhHm3-&_?QlDe+;v zVbnA(UzrbGn@5ddjj_Uqv2MJB2!`sz99K$;0A4Qq9@!oNzD^?|XcLwzw={bn8&nl4J`<8!$@Za1LbbSC49 zHKN{@LY%hG2p3fj<0A$cykPh{a;{RK$YYlve&{JEHynfWcgA31z9Hq`dj&CHrt=>^ zBX}pTQ(G{?5<$8qX!&3}gwh9aJI9B&%$b0>J?D^?-%W}rZw6Uh12#tv!h*NHT-x{z zH1vzOq@@sXh%X1s_M~naPPli;LK0~Le|oWqKJ1U?sc)8u)p5&6Sz4>5_pK1uo;X4E zb|Luqx)DCQuu;4=P9GmHISp&V#tE;VG{bD`N}8Xd%GW0T5g)t6VWgfiHdKz~cs)&) zZ8xQEv*YP=w{|$|n#Ot78k~1fia~;xfZ_hz(7S&m85y60o(ezc#&&@%&S?ZV8Y5bus+m{-ETC}z(zH0Si4;8`X_)VoVY^HQcX$9XvyTp>`_lRCq=_6KNX$Ta`~j{@c0+Z*KzduJhvW0!QN<)(p7irH zMSS&QrK3t1;+4y_`+Eq})iXHLd^l|15{8{@g$m_aR6J*~5Zk_mG%u<0N5y2=tlEz= z<2CSA|A82-JBcm&X7JN@v69R2Q+fCWeO~>#1x$9OVcxL`9JhTpRR(0@`DN+2@K_sl zyFVU3==!jB@NBp}rWvNB4TfWi);#>cJR!PmH2>4;Pn~_j`1mx2@%62=e{K%naL?r) z*<&z!!5%m|zYYwY`k}S!S&%I2$6HRPVb#6iESN>nln;8m^Ybza=_l11g)A0!4CsOU zLKD_bOy(i?BKY3PJHq3WE_{5eEj~LT;KuK_XqU1gCPt^@w6RSTkTOJkz2P2JOlze@ z%f0Yf@H8I2aU4gk`ayd}EFf>|E?Bc}BI|ZXzx^v6(_VOaW9 zmo96414E5C?hx{MUS%ydEzn0cFOG9o4vMm0pmfg=xF>@T$i*S zc1wVbVs1dxYkPKmmXFG(9dIHkan97&g8G5udAcYd7ns4y1`In zwwio&BeCuGAL_BXnYxeZA{1Re0TZ?ZPuZr=r5#Viffo)^&nOQZC!Yt)H>cqFW8W#v zzb{{W63tycSnw9Be{@5~3K#Y3fg^3DJsELPt->z}M?6%+JN4g1>+ct89qRqL-q4*Z z13tomYp%RBwL6Apwm|B=i7uaY26D}2d#V`u1;pMNZ1_-uSJs&Fq4()L_<0@W*)AcE zW@}CzvIX`ytMJD#7p!>tT*{aDu|(6Hl9ubC;_gJuIo}3n#sfJo7EXB=XG| zi+#cjFyUzeCtYx0BjeNLvNnRR^lBA<51Ygu??Z6sMp-VrFiZMRT|W8PaV%2LK@eXZ z!R-|fA$)E)>ZiKE!>)aJ;E8C-m63XQJClJ$4NLVe-%}{J_@w10X~$E7A6&ja6(*!q^d$6XT&qO zC$i+!{`WNbcOxwQ_7?o^dhohRL!4?p92acw$)iGixf22@r>hzd*cwmQw#^bPt7Gt4 z{{pJ?Tf!<(2 zIH5O!x5~J3!`7p4b&m9VE*egK`_@a!ZKhy_bAr%xWWLx#))ZZr|AUXCrsB*N73!m; z3eRQ)@+lK%>?Qk!aP1c`x@&_^dY8b_)Ef}#Z^+uw3>(5;yS&k`!uluA=}Ut=`$b3a zx=(TB;An#p4+yrKb>WM}gZZ?(GdlQ6*CYE%`=XYHcyW~vhIN*}9!IIBwpIn|Dh#ly zCP^&l+$vtkdO&}kW=Z*Nk(R#Ds8w9!4Y8}DdBk%i&^xyq>a`>JV}dCSojj7_v2Gxj$SS7S#lh{sWnsJFzFiJN7H%T zn4`juEhosiI+L2N^e5dI=~}n%1?YEcmSFbLnInqid9mai35hLY{n-`p=*0nH_jYra zx#M5K&G!>9<=qgrkn%DYl_ud4pL)rZ?mMX6IR$g>w!kEPEppNeC3%&)P&!MVW`118)FKjC z)33*P+*lp<+_mHu8B8M!tsxMc@f7z7X^>j|NIb7{tr7|0phgHGQ~S^L@Pu7D%PK3&V+>zj zT9D2#fupaqf*pcBGq@re%v>zs?uUVJbjo*>4#|X>XCJnNyjP?r4l{={X8Sq6d4W!W%NjK@NXO*r9{ilUt@7g zWJzQ7K0GI-Md@cfSpV=?>g(BW)`sBDGa&}3K~>mby)=j20xrK25p&mOKWeN{}I>OE;aWwmdGX@zMQ%$@nb-X@> zec9H+eeMa0=ZEc#$iKYNv+A*2QMh5fWBp%FJ}`_YL3Uf}+?lRcOm!+f(k zaF$mjg>}V)wxS#Sx+#JW;U%zlfG;_G4B;34nnC-P^fFJ^a;W(fLw{o3{IQ&|hfaY(&wWFEG_Goyqwa!9sTptlw2gTM`eUW6u`6{2~lE;a<60 z7fX`uYHY9n67Fc~Q50Wy#P&ah^rz)HdU{7d-b8a+Z#RILwtma3bA|AG%oVn~BMchW z4*>BYb$AzJ!Ai!SU}sz%sH@BhKDkV0du#PcyULMHx~Gs?z$o^IH-^D7CUo@08-Bw_ z31*?bm5Z0tgiyy~EWtdmt&4))yXT|nO$!P$AzXB(7H>W}fbE|p$oI@5Y-l*amJ|g+ z`sLTS;@$zad6+N3X;nVePzNSe74j?2TVkDwEG2v~Bb7o`8ZkVTeKj0GyO$5e@bwqb z?Avq-Qgs3S=cVvRq)zfq@xb4`%8t%Tq{J)3A^$hWy~>6Nx8ZI!pBpvC*xv+9x5UaJR& zx^bZV+!+!A2hp^#|FH=9VvwJlg=Kx2wAg19N$+!q16ya&4>w0h${r7MM&D!GlV;J? z`TKCRrXDz)s$;!sMfB^BGp+YpgTuF*(SH|Dq2jR(yk=w~RpM{<=yL_$ZX7_4ItQ77 zUNNbRx2K!mhp`(YlBmOEG)O#AhS~eyvK;3Sru0bAHC4Qz)6FDMc}w}$za4u;SB zrb5H6;c%&ND{^+Oq-T4d+gCP#zGz%yXYb}R>l6F&@s21mSg;#a9{Iwq_zV2A-Gy}U z;3@XUasVttA80mv#SH7ysmHhg7TvUBox*$Tqg5XL9h=XM%$y2sho%borYRz-DW$d( z=lI*gUS00b#z(#R-s?>`i+tH+SJ?nJKbEYsQq z?p43yD+X`EMMc-S<$na5TI^1LA1cu;$q2AodkYn>RAJ+GKiakO7oL+brAY=usLrra z{OE}aeD3)O`z{Ev5QR=vNxGq1EGyXV` zVyd;d4fk8oan?e9-qF>#uf~u>@^Mt!`K#HaGL;^`^P<@!(&50r?aWd75X-+A1GQ1v z^vYlnKC*OWg-aLl-i~A7YkwsR9lC=1ck?%|koy%;cP!jpP>0#Q-Apr9SX->GF~@(x z9(zeSY|xISw}*mRR>(^Bu27Adc8_Fbaqq=aN#9$}-y1=LoBNq#lR3?68VPnUJ}{B_ z8NNgEDStHb4Gzr+#zjkx@T)^AV8(`}XgTB)J~(znR8Z5%iq=$6w&gThuxJR3{`-Sn zj^v;;_Xgf?*o&3#@3K?-!+GQV@!V3MjqJp7f$b{2it!m$T#dQ#V<DUv2+-p3o4=dgc5Hu$`=7R?Rc$&X3WgE8GH+%Hk7QdMr(O*I-9hdy*6!P zdy~@GrBnsD^|P4942h$~UVm{}#Vt1SuQZ&CO`%H{2UFCZ9FR6Bhfm4Fas0soIKD^& z1FDvx)xM)xV0|4I&hTzYx|IXbLMK6p`vm$|DD;qA2FKH6dCy64Y&9@+&Ip`L}fySgEH9 zF&}5a>^^gRVPFP_-bk_QV~?_gQVq!dJB$X+%z^G>ar86A60K`nQ153cWL9k8wMS2< z=us8)cYG^83I5FtyakrTPS6u_YI$$}O0j9>B&hs00Dhp*dD9z4M~81j*9$xF^ezoZ zzc`fnW>0`+NjtD;lq)=oKEP@z1L|c}pr!kXxZ$}GdCX6s2TPuaZ+0G#eJ?LEd z7~;P;(W*LUwyx0{E}YYW)2$bo<`cx5saCXQ>sfxpr_KD=B3~@YlcKl=AphK4N^P^} z7MWXvrRQ_3S_1SiVUCbXP-6$m%c1G+J1p!T247u=@hu5OVCudIVuzjzUIn zPXSyD(14UDP0aIJC!4iJMJ##vJ;J{PPPgqT`p*~!ACEm^QWB3@^&cndso2MUv}Qx- zP7$PnxC?3*9waSh4{FNc%a$}YyGKXeR;Nt)mRp>$+>~c{$ONUx+WkLGFcBVS6nAF5#v>zEq z_bwQaL60#8r^b=zg?=ufG@QcL^sz6Sv&nMkc#<(Lr2J(z{ExdyRMLHq-;veLT|9D` zpS`&W>s4m*8w{H{qnGVydAk5cP5Fn?Z#CGVj5JtUlf_BJ$ikBkrZDGsENLxS%Z#2> z;@I17L>lA!xyP?P;K=cOeusUKn2nI7aGO%Hx804$;2GyfA6feGHC)VFhnDd>>5X#1;N62*>Y>!!rh zeG?fn7dUI<R0`JKdBxeJW>Tx>YR+`XF}xC$ zNqy}SB-^>1?<{l09CKq>d^(KuwG<)RW+$u95}2D8L*S^@59X1UN7y$Vk{oBSeRtjY z%{LE=4>e0cQRqT^cXlK#S51I-8=Cpy9;I+X-x*%Uc)%CiCRXSE6i3fL!R8Ehq0QB@ zOsjejI9!xun`JX8wQ(|~91wK>6cw2JHwP4YuJRGbv`D>j97LSIgQ~WTyu{~t5D%85 z8QW(wzlm*VUM|CGiqoj(>t(jPP>PlEs^GP0CdqgH!duPiOwzX$IzHR8u^Sijo-g}x z*8Fj>LhzmgwrNtk%}=qc7SB#B7(}aw+R_C72>6rt3ajKJz*6!Zw>=@A{yf*E6)|cw zD%cT!gw3GSDgh+k|5v=Ng2P~Djc#_n)TXXT3h^R(B3a4G-{|xEwzJ>b0vq#fF8l9eEK6_gMYk|@&{{qlCFY;PTe1VWYr|jRzY1xl)AI{&taaopXUa0s z_hFQF{{udE_{e3QkB1@i)M0DBI!wcX#cpJ8@%6Bg<2Vj5FS;AQGl?o8fvSkyh0zC{0Gp$&_< zvuQf~!T5ooqFN3&29>f{4__9z-jo#Q37vFGveeR(4NGE-s7dgg<{ng|!Su@kwvD=mqa+6WPxy8JRkRM!>Gmzbq z>BPK58M?XX8k^nd3@U@o=%=4A41W>?wcS}{)MCi>I-bFKzN%1vDu6e!jDo_7$L%p83m)wyb_oi zwHD7yD$w{?UoyGvNL$Z6VrQ2GQA(X67#aq`9Yq<@-=>qy+~zh0Z`Kg8S&Pvak1r4ZlENon$mNbhIwLIzbaWz6nw@%d+|_}CT&wR zTc%m?+l^& z*5rSv2%?6jk?2A=oVcq(O`WfC$NV~7gqTPVwQL#)?Q46>Y;n^SgATFTg1|$L1V!#YAE%(1e4K(p|q%aAZ}V| zO+V*+W)AO4Vbow_y1`r1>8`Qx^hP#Y|7r>a>91gi4oJa=_IbR|@sSj1KY`Ze#)Eg8 zBpWhhGm@7NJ((t~9SPf7pRyhM{OPgRMGPveVw&&E!Np?~yWvzy^ICfOO=DhiLq9k`o2MPu zGVl$nn=lwl6ynLaLKoI96@2K+>)7+aEPC-smX<~gW!?Ac`Qpf6h#qCj{<^91%YUun zt)%XYHz`?TU$7mtdD}p9+g)zELMffhyv-??2hnj=3%s#rHhZM;n*YAZjlwjSu(j*& z@OK^skiL*XfAsbf`>m%$X>0=gJva(8A_w8h6HfG_te53IF2L8)E4wQOiM@f3- zlzH9;cMdSXg}&7sKRt$I>yHSz9bf2t97Vxn|6A{?k#>`)c}O8C7ma|a*@Iz@sTxG(?!Xm;5=EK^Ovy~8fZg@7 zgVJGR8L%j<&b8aW+V)Xeuc+TtTXU8Q$S`l^R0&*4`%lhup3okBO=kWI^(+ul0%`ENQb+h2!w zJeA>&cG0Qa*A4hyONR!{)j+$*iL@oz5Jv8fpvCtmvBz(->Ec-vdZ4I95-%pu8ovxU z)?UWFzV{ZrW=9DglP4{>a-O9XOXBGA0d&vU7n(L1v7t6Y$=#xo1$0^CLlYsJ_RX9c zArBI+7vZz1mZT(3frwxA%(+F#mbQnnnAgS_+B%FxQdd~Fr6J9_z7Ok0`N6~UYAj*p zU-n>QIi*+2l7mw>3%V*zj;;e(UE?ye-Je5RHUcY?Gzprs*5FRFBv=&no4NSz=S$9P z=c+F|K-in9?BE@H(Q=z4GB{9#i`L}O%C$gS!%EnLExGhw+<_`{6k*a?fkB?p%wO}k z#xx4TpnCsb{2n8CHHM*h>`O2_{(cyjd(7azF0qBFXM}bArY#s;68My7>U6j%gr>e6 z2~EAi{i7vj@}^GoZoLZ)RIX+3oDE39%bHBH!zkhn!a2X5W|C z(Gr;g+Nv;!E*u^Q%hxNB`6hW<_{^7lBbC8>-zxU!k|J|GGMo$geGGrwq@bL6CEi`Q zmao}31hm^evD@5P6tQ8n?CUJ{>yjK8ByGZ~z7{-CGlKJc5e+K8pNYo_K9W-TD%S6O zot3@LVE0sp!e!@A?CbLvY+tn#l&=bcn8~gbafzb~^>%PtPm+#=9>L1uT$q@$3`uzv z9^R4-yLwC^XX|9HNa&PP^cs$m!#Os=DGRD>%)rfSHV)jo6SDWC**P{gHVABkG zZ;*k*s)YT0sWydFw8g8Wj%gywZsL3 zYxoy)?RmEh1Gd)Jf^uzyOwuY9s#aFvozMyv`aj3A-hp6W+QIZzh)Did6HdykM_o@J zX7w?jF56_%ml@UQnj}LetK4wzPL4W{on=fE@0HlN*lSwaJ(j5?^eZnNN>gtI!Phw@ ztT%8r>XZRXK5|7gQ8+UUa1KO=EJ;Y-dK6WcIkCAvouFiGI&43En$0mDizn11=)|&m zG!S^=a)JFHG;^Kks_G~>7*mCjac*SqI*>hf*uIvs&VGS3w@Zj^gu&3reKUB(y zV*V(x2Dkf6ylfoY8QseicfV)Z4r4)fTOX$x+RIiddE+QnKtc&N`&bsk9Udd(c4z-! zMwKVT=4D&?EGdSUIE zz%t>ja30^U-Hd1H4)1m^6Sy54IYmAL>~Bl3MsFc6{VE*lj^sm~MH-A+pN7G6W^mD} z3s_EyIkcCRk-Z3D&*2&TyI(U|f=xB^xw`<9fBDk7>qacJ{WMyp?!_~GR@ARx1X_QL zVP?4`d~Gu4N;7mJ)#n^){RpGy%gjh&Imby48B5NBAEP@ajAC9oVA+{==C$JvtNr>8 z9|k;P0=y-*zU2W$0;fA|ejyxQrAf`}67j-uNz!oL$?iYA$0nuj#N$PNbo^pp3A`jXfm#{sp2weW8@_(A$S@swe+IR078!=awJ`{=|ZRBSDu%-c8 z$=ETKEKSPr>BJ8+?%-nCSXRB|Gv1wjj7bU_-}>;uG^J-gs+QzIEK)HN)Ig!p}9t$VhbX`&0iNn!;{#{AE9LCn8_U< zoJ$5H*P~as2c*}jpsI2t`HsnfCytld>Lq#5Fn<<3*6@Ju&NCruQ<-pQjs_zwz&&Yq zxWkc!kR=)fJuc?d`E@V`uAN4gau0G28XK_Ic{Mjp$Y>k%NMPUR9ya2>&=-4Y1$S8_ zHd;4a=nv?`g=_NhQmT~D-7F-?#IEAole(Fn{8(zax)P0c=`kHu zH(qsyJD$$jjjMjk(Kn%UMV%T^XH^XQx^f!ql<)^oT*<0IhRxE@p>juMc&MfVreeVa zeL4!7&uZac({Xt5!!RiKvB5b9R$O1SvmIhX1yjLA8rDv`z7*_Ki)r*!3{pxmW>nx*AyF zb}!67rO)sA_7(?DX<*ZQ`dRM?BOD~;(!+&Lq=-LluH|nhr{KPst5)3Z;>~PbrZ0bbmLt@QBk=HlYgpvi1e*EXhUl#W zO!OE-^4@|DT)l_QT9yD;$GYG%VQ-nZejfjk?h35?4VIN4OUoX9<9>-EVVP(b`Wrvu z_4_X}+o8?u^3H*z6FQhU@40w=eh4_!er1^AOZtZu*)Yu*@D^A=ugeo)THs#%bTWVz zCAD%@^8e6C=t6I}KY}Xvg=5fZU4GpTe>x?w%%-1|=+VYIm|T3BJ5-Sjb06BFTeCYE z+B3N4b%d>&HVL+kTF!-@^n+TVt2Ht1IzF?Cpzno(4z+7MwY(fcPfjLKV`4IFIc)`h zU=;gsF`wiDL-Cy9U^q5di=>jap=(<(cwm2Rc)6yroxPj5&y+{UN2ofO55` z)A>4M_IY3!oRD_HAU72lKUT~nhR)~a&3V94x-y~Ta<;1QEgySSAI_{Ofh*-oBstB5 zj$AZ_=r`(gdEV7b0E)wi!=(SPD7&q4uSe;NS)-43KNLO1>&*9IJQ)Cda% z&0+ic{j5380oJJaL#O0U;q0XYGb@d8Pe}?*8b63sm<1U17K6iaKiDI1Y!4TWgfW&J z4X6$!6`daJxMMVDA2J=^L{wMSR-Cu&pq zMl+$a^qe@k~aqkd95VK_mVP+vYeLjvma!nSR5CGh#RI)O|x-xZM#h z#DBvNMia@kt(J{4Go@EE1O{pSTP{?y5Z11a!UdJWImuAa;~sm$;ZalQ_xaH@b#F2x zonOP~wGSjGnSz~?3%P};@w$gBsj*A}E;neyje)=Tdh1zKE)hg;zAR+>F8NdHrzPxL zaR;yDJdE@z7P3R;h}j?1Iji(wp`-mTYuwc+8Zae@?Va1lGAfeUhvz48&sjr~KCMVm zPotUIaVuyUr!Vw_6oABbO%gr!LbbRk8v81c8aC!|r7I6H(+jC2JHiYWZ5~f|i=;V! zk8vbQiid){5tMnvl$}*pCW)Q7Q1XH2HgB<}wFCcS4T`a3gDFsRW;s-MMwPRuHS4q!}`dpG!--H4haiYU!o0mr<| z1;eAov{3C3HuZ01`(+K_@kUR)u~i*|Q{Lltw`6!(xem2+73jgu_u@sLBFH#Lk?bcb zP+#Rk?3&oc+$NYo#Va{_A{|9OoBy$p38~N}s=+v?1ga>!$I2pGIrDWb+@8u%W+U%U z%ijvz#>P?fY{zc%^ zH}bs*0;`8ZrKok_1vY-`2X59GZHVri1<~^#b6AdQ?<++pGVe0seoo>hI(WY-S(l2w-!9T&QUf{csEA7qDXtu={-)DI+$E^Fw$U5&8~W1!eMmp)`f z61T98b#B$C8C}7UuQJj1ilc`1vm!|u*e24a;v_kGTxSWVI^{rf z=4g5Je+&5Ko`M^`6vLFH8}?1UWnZn`LnR^u@Ls~jZ0eUEYd zBkHQtS zmt23t{+X-MuSrqhtP&2#Ga})qz)!B;dy0Rx=RNn=K#HFHT)-|J)WV)9dzv?+h#t>Q zpq>9}`ILR;oc}>r(A9CkynV~rSO1U9?0XiBI+01TRvU3lz9EFwd9f>({^ADd(acX) z8*&%CL%qz&^tHtswnV8>Mu{cl&QM`S4KKt4LSCX}VJS;lA`5FPgnizlcxv$MU^erA z;t-Q}yv7cH7HApDx!9zD^7^p*!s=u<6Y6hSo;|%K( z8nC_4o3>7H0(%c3j5+XVg+bSsP{&28!e#UfrWGH{SkI+Y?1d@-YbMM!s(WFpQtZ-k8 zOALg$ItRMg(xtxi`?wr+74IAcxmjws9`oiPI=j?gg`7 zTFIRsu0?)r^4!Z3f3n(@2un9!WHyh4{_8K}$WG9r&IDaxVYfJFy!?e}uIOa;C5Jey zyuj-9KCsr52TU<#68_v1N>0|+uzGkJc;{GBRe?Ng3-O18!EvxAE|#8Xo<@hlbzW9! z$3>F;73APP5}XU%;7L~#bCcV|bBo>K=I_ldI+q5r@xfx=Wu36b_X=x*CgHVzrYJgm z9tT{$%f@%u;YHDEZnt0KU~t2S~6tn1OnF94<_jlmrC z#r#oBlo+VpZJJu?hvQ1OImmfT~*t_;Ca%~+y@NwoOK4ty?mf|LF?oUZP1 zr|}h!SbLr%MR7(HCgAFSuQafznZs$?b5{(#aGlL*b7GY{couR+V1SpYumztNvnjU? z_#p}fbk$SD)L1#}%j^(+*{Taas$_)y-CGu8nnfAPLdSH_O6Fi3O<^ZrGwaujyuMtl z6MNPRx%kWDDBx`Yta9(-k|UP0>7Hl#-U}CTUfl?=IcCSE4Tz#aa)Th%S%SQ`)uGf% zdD@$kNeY6ddv@RnHZ98$)UKp~=Q(*8=)cR$fA)9wPeq5BO|pTHRukY;VL5SHSA?~G zI{D34r;S~o_$|K9a2YFL_?c>cr%o8jpK=5<>3_9vXrkJ7E3X|ez0}_Lg`_Si0-*>V=EV>!*4BbD0}x3_l70Y9G~H^ z`LFy-bi|oL8`f1Qm^%46kezS`+_2J|BF*KdeU?W~ip!U;CY=@K~4hnvXldUGf zwUivT_rjUxr!`Z+xXgylPuzs{l@pow(-^RLTtunTcTsCw5_>=CBPvQt^Fm`0{k%4b zHnx~hA2$^YYs<;^n9w8ICIRuqd3a&3I}MAP03&N{KvD3aR%RT*fe+T=()@lbIa7)I zLu`rN3t?x^HVPVVD|f9Wh}w6|0FUk4xdHWNwAAz^mbuvy32fh#vxAuVlztrSG=PuX zxgGD_ScRkR30za{T{h3B2xG7QX7S@t|srT06#Y<u5{(rX0mY zE5||YsSCiTW(J?PanA>Lz(LM37JM^RY@AfCG$T@?o2FKBH!5bFlw{6qh31o8500ZjY;Ie7xR}Gsk_nRpWA3`06waD>zBv#(4KqK=4coXLhvrkl^Plqqa?=7TB z9Rt{yNLx%DE20A}M&zJcKx0PV!xpWZXi^ylv6VAnfTtVC%&znb=L}f%vE7*6qD39M zI+=6p8;mub!|XQCpzvq+yo~ODV)NEKWR}V!$jEXb&)!I~c}ic=a>G{^%U5B@2rX^~ zZv@4y6QQL%gxTGghn@>mX``TLmTY$hDaVCu!h`qRC7%~~zNw6kmc*0xJYOifT>(z# z3n@$J7MWA2MAmyZvi~ZS*vqBqBx!F>dSt>VwXEJA3qa#=B7#(HbzjCRLR*2oo${1ud8^(MCeuuXG^-jVXDwoe0Kzg zRih@;i4HY-q96&92giUOcHpOPW9Z5+A$wA7NPly3soguDPUKj^r!|tIoZE>Y7CQg5 zB2Qv?%LI1Dun>NRc#%WA0nJ^Q36noGu}=@>ApTA==}-Q_DufG}BPx`Qi z`yA-CF_?AdG=FB{GuAV;6VL5b<=llH$Fa#l&?98t#$Bjpr_+YQGWR0LDQIRVGt6QB zyh3t!(V+OJCg3x;3B%sr!8a?eF!zJQXkw8*eHQc?qvwNR!L9&WHarVP%o&rgPA94o>xWDf_)%9_u}Xus%!nr5sT^l&Y0$gx z&eZ>K5*)mj$SQ|CN7o(veaOX0wYfOR;9$Mf8xkj*I#%u&J^H(pS3E z$iqog`bmoP-~Gqec)78})6UeK@`AYundK?t595ZyKWxx|Q2f4?jo3Mvyt_SwXX+h_ zwrpm%6pqHm%f?gN-LYhPE)n(~@S=T7O{q&s2l`hsI?7VXWSAiYPVQ!- z9<;L3o^do^MVFGq1Gu>M4zBcNxxkna<`pUm@1f1Oe}xsAhcsda?chUNgJ6^7b@Vli zqE){Q==`>8yrq!aZtJgQgYFE(u{VX;9@?2Sb;BeGU$q`()Zzv9*p9vmxiX7CelR6r zB1-m_&@s;;f|VqTe~Pr>x{53X@0v-*)>0rl!U*cO>VT8484QX&#E!KV(5A*6TvoaZ zmdw_Hh}&;b|C=qWUd4fN;0-?IzhBIAdIaT0XNq&>TQK)VHVh6)q@PdxpfF$%yqkFf zZ_KX4H7B>Dx1J7kxqH&LXZd8czMCykEGCClLF6ebu-Sr+y5(9P^s0`Bg|#cimwoc! zPE$Ls%j96ynrTqC%#TJb2xpBsiBKeTge`nNf?d3>4t`PRJ#D6{@XlBIFmu^e?4K|b zj>J|l^#x1W+&4+Iz40~+xjmVzT+Xt0JN}}K`$6QFGpJTcg*%2q{-k+4n3!>}LE;Si zJpMoYaL*o9j>%A{*9;mh`G)>Qb1?wabed-WLP1KO5xeJci%>gQ^CC{Q5+vN|U7dx1-*cV?v7z1}AoN3sOOnyN?0Ng2? z&02*uDP!Duv~)CslrVGp@}iH$KQ;r;EPXPc8cfsWYH?+g6tsLPLsP$cW{~rPwN~E3 zCA&|vxLP4&;2yyRMqlS;r;MWk*M2gIs0Hj|pBvn)F@o4LJxuq~HvYB09hg1WV6Sfn zk&n($Dqm@iQ^!cbB1sE6{Bt}1xigQZR;JV9?uG2x$3Zl);66%EjeyG=gwB|wl2o;F zHQqfj1(r?}GJa+6xSprRVA1oPJ9l&)O6rwhaiAF-6m%}b>Q?q>#3_F3x085Tdm1_Y zJb;s>zMx|DSkhJ<2|4%7L2ue_KK#*LR;3mOA#UbmAiY+k{^%lV-ciFlZv5rf5w+g8gNcu#;T3h`>Uhn7Zh z6tZ4;CXSbqac?&KC(LNuSx`it;ggxd89DZFQ4%&Ebbw_u=VFrVYpzxEhxnFoR@r{w z40|H<14s+Mqs=~J)E_7VzV|mWh3INlxxNfqew5PtiH0P;d5{0|NQzA?<|t~StvMK}3)+d&LIkrvze1Q4(KvK$omSN75Rb1I%`l22L-sVU4FYPQRe=VmtZ2-&}+K4SDPvLX7{VeCNGt3f? z<=0QW&JG!$!*>g1AWw4=%)K&=`gD|Opv(l?@Lx0AG+{aJdWf9druXb$=U$StXQv61Rbn9j!*t;V^>xWp~G?(y*B z;x7I4q?K*DU{E|7#}26=nd1gvu*ivxQRu>1pZ9R$0psDa(qy_PCC#?&sKzsY_K0;Yp^w8dPMO0l$|f!k?}im~gM3 zkKSYpKL7jZdK;JjXE3_4#Oz#@m~D*`lP5EYZfBG@J&2b$9nDJ^??s93Iq37h&8z?a zt}<>FNz5K4k}$3oNo+C|Nl4rj{@%bzNLUJ2JBt3_tN!Q6`v3o`Gc6Lv+F}U_PflXf zT;XbHk?>aLC1%HR65d+>!{uYl73HV>-&5WboYV&^ z>MF>qYbwgC%O)kxOioiApe~=6G9w`~B|FJP@qe8yudc4BsvxPZq$s5z`9D4iK2oYm F{};@+c}@TT literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 5f83592553..438024873d 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -9,7 +9,9 @@ #include "onnx/defs/attr_proto_util.h" #include "onnx/defs/tensor_proto_util.h" + #include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/graph/gradient_builder_registry.h" #include "orttraining/core/graph/graph_augmenter.h" @@ -534,6 +536,49 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { new_attributes)}; } +IMPLEMENT_GRADIENT_BUILDER(GetConcatTrainingGradient) { + auto attributes = SrcNodeAttributes(); + ORT_ENFORCE(utils::HasInt(attributes.at("axis"))); + auto axis = attributes.at("axis").i(); + + std::vector split_attribute(GetSrcNodeInputSize()); + std::vector outputs; + bool known_shapes = true; + for (int i = 0; i < GetSrcNodeInputSize(); ++i) { + std::vector data_shape; + if (GetShape(I(i), data_shape).IsOK()) { + int64_t rank = static_cast(data_shape.size()); + int64_t axis_index = HandleNegativeAxis(axis, rank); + if (data_shape[axis_index].has_dim_value()) { + split_attribute[i] = data_shape[axis_index].dim_value(); + } else { + known_shapes = false; + } + } else { + known_shapes = false; + } + + outputs.push_back(GI(i)); + } + + std::vector new_attributes; + new_attributes.push_back(MakeAttribute("axis", axis)); + if (known_shapes) { + new_attributes.push_back(MakeAttribute("split", split_attribute)); + return std::vector{ + NodeDef("Split", + {GO(0)}, + outputs, + new_attributes)}; + } else { + return std::vector{ + NodeDef(OpDef{"SplitTraining", kMSDomain, 1}, + {GO(0), O(1)}, + outputs, + new_attributes)}; + } +} + IMPLEMENT_GRADIENT_BUILDER(GetGatherNDGradient) { auto attributes = SrcNodeAttributes(); ORT_ENFORCE(attributes.at("batch_dims").has_i()); diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 50ee1c27e9..87f007b90c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetReduceSumGradient) DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient) DECLARE_GRADIENT_BUILDER(GetPowGradient) DECLARE_GRADIENT_BUILDER(GetConcatGradient) +DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient) DECLARE_GRADIENT_BUILDER(GetReshapeGradient) DECLARE_GRADIENT_BUILDER(GetTransposeGradient) DECLARE_GRADIENT_BUILDER(GetPoolGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 7eb39e97d0..5a4146e92c 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -60,6 +60,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient); REGISTER_GRADIENT_BUILDER("Div", GetDivGradient); REGISTER_GRADIENT_BUILDER("Concat", GetConcatGradient); + REGISTER_GRADIENT_BUILDER("ConcatTraining", GetConcatTrainingGradient); REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient); REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient); REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); diff --git a/orttraining/orttraining/core/optimizer/concat_replacement.cc b/orttraining/orttraining/core/optimizer/concat_replacement.cc new file mode 100644 index 0000000000..37d302765c --- /dev/null +++ b/orttraining/orttraining/core/optimizer/concat_replacement.cc @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/concat_replacement.h" + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + const auto& concat_inputs = concat_node.MutableInputDefs(); + auto& concat_outputs = concat_node.MutableOutputDefs(); + + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + t.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(concat_inputs.size()); + + NodeArg& ip_shape_op = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("per_input_length"), &t); + + concat_outputs.push_back(&ip_shape_op); + + Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"), + "ConcatTraining", + "Concat with extra output", + concat_inputs, + concat_outputs, + &concat_node.GetAttributes(), + kMSDomain); + + // Assign provider to this new node. Provider should be same as the provider for old node. + concat_training_node.SetExecutionProviderType(concat_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, concat_training_node, concat_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} + +bool ConcatReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/concat_replacement.h b/orttraining/orttraining/core/optimizer/concat_replacement.h new file mode 100644 index 0000000000..faa8877c56 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/concat_replacement.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class ConcatReplacement + +Rewrite rule that replaces Concat with ConcatTraining, that has an additional output +used in building the gradient for Concat node. + +It is attempted to be triggered only on nodes with op type "Concat". +*/ +class ConcatReplacement : public RewriteRule { + public: + ConcatReplacement() noexcept : RewriteRule("ConcatReplacement") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Concat"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index b49e144203..7dc99e7711 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/session/inference_session.h" #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/optimizer/bias_dropout_fusion.h" +#include "orttraining/core/optimizer/concat_replacement.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/megatron_transformer.h" @@ -101,6 +102,10 @@ std::vector> GeneratePreTrainingTransformers( case TransformerLevel::Level2: { // Put ReshapeFusion as level-2 optimization after all level-1 graph rewriters are run. transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + rule_transformer = + onnxruntime::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), + compatible_eps); + rule_transformer->Register(onnxruntime::make_unique()); } break; case TransformerLevel::Level3: { diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 464ff515a7..f12fcfc802 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -248,7 +248,17 @@ inline Status GradientChecker::InitOpTesterWithGraph( for (size_t data_index = 0; data_index < y_infos.size(); data_index++) { std::string name = "output" + std::to_string(data_index); - op_session.AddOutput(name.c_str(), y_infos[data_index].shape.GetDims(), (*y_datas)[data_index]); + const std::vector& data = (*y_datas)[data_index]; + + if (y_infos[data_index].data_type == DataTypeImpl::GetTensorType()) { + std::vector int64_data(data.size()); + std::transform(data.begin(), data.end(), int64_data.begin(), [](Y_T x) { return static_cast(x); }); + op_session.AddOutput(name.c_str(), + y_infos[data_index].shape.GetDims(), + int64_data); + } else { + op_session.AddOutput(name.c_str(), y_infos[data_index].shape.GetDims(), data); + } } // Currently only allows setting int attributes to zero. TODO: Expand this for (auto attr : attributes) { @@ -568,7 +578,7 @@ inline Status GradientChecker::ComputeGradientError( } // Compute gradient error. - return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, + return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, attributes, check_not_have_gradient, check_not_have_shape_inferencing); } diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 773b4e1a8b..4c5abcfe51 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -39,55 +39,52 @@ static bool IsErrorWithinTolerance(float error, float tolerance) { EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f) static void RunReductionTests(const OpDef& op_def) { - TestDataVector test_data( - // Input X - { - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - {{4, 3, 2}}, - }, - // Input Y - { - {{1, 1, 1}}, - {{}}, - {{1, 3, 1}}, - {{2}}, - {{4, 1, 2}}, - {{4, 3}}, - {{4, 1, 2}}, - {{4}} - }, - // Attributes - { - // default - {}, - // axes = [0, 1, 2], keepdims = 0 - {MakeAttribute("axes", std::vector{0, 1, 2}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [0, 2], keepdims = 1 - {MakeAttribute("axes", std::vector{0, 2})}, - // axes = [0, 1], keepdims = 0 - {MakeAttribute("axes", std::vector{0, 1}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [1], keepdims = 1 - {MakeAttribute("axes", std::vector{1}), - MakeAttribute("keepdims", int64_t(1))}, - // axes = [2], keepdims = 0 - {MakeAttribute("axes", std::vector{2}), - MakeAttribute("keepdims", int64_t(0))}, - // axes = [-2], keepdims = 1 - {MakeAttribute("axes", std::vector{-2}), - MakeAttribute("keepdims", int64_t(1))}, - // axes = [-2, -1], keepdims = 0 - {MakeAttribute("axes", std::vector{-2, -1}), - MakeAttribute("keepdims", int64_t(0))} - }); + // Input X + { + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + {{4, 3, 2}}, + }, + // Input Y + { + {{1, 1, 1}}, + {{}}, + {{1, 3, 1}}, + {{2}}, + {{4, 1, 2}}, + {{4, 3}}, + {{4, 1, 2}}, + {{4}}}, + // Attributes + { + // default + {}, + // axes = [0, 1, 2], keepdims = 0 + {MakeAttribute("axes", std::vector{0, 1, 2}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [0, 2], keepdims = 1 + {MakeAttribute("axes", std::vector{0, 2})}, + // axes = [0, 1], keepdims = 0 + {MakeAttribute("axes", std::vector{0, 1}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [1], keepdims = 1 + {MakeAttribute("axes", std::vector{1}), + MakeAttribute("keepdims", int64_t(1))}, + // axes = [2], keepdims = 0 + {MakeAttribute("axes", std::vector{2}), + MakeAttribute("keepdims", int64_t(0))}, + // axes = [-2], keepdims = 1 + {MakeAttribute("axes", std::vector{-2}), + MakeAttribute("keepdims", int64_t(1))}, + // axes = [-2, -1], keepdims = 0 + {MakeAttribute("axes", std::vector{-2, -1}), + MakeAttribute("keepdims", int64_t(0))}}); GradientChecker gradient_checker; @@ -670,17 +667,25 @@ TEST(GradientCheckerTest, ConvGrad) { } } -TEST(GradientCheckerTest, ConcatGrad) { +static void TestConcatOpGrad(const std::string& op_type, + const std::string& domain = kOnnxDomain, + int opset_version = 9, + bool check_not_have_shape_inferencing = false) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"Concat"}; + const bool extra_input = op_type == "ConcatTraining"; + OpDef op_def{op_type, domain, opset_version}; //concat_1d { TensorShape x_shape({2}); TensorShape y_shape({6}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(0))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(0))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -688,8 +693,12 @@ TEST(GradientCheckerTest, ConcatGrad) { { TensorShape x_shape({2, 2}); TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -697,8 +706,12 @@ TEST(GradientCheckerTest, ConcatGrad) { { TensorShape x_shape({1, 2, 3}); TensorShape y_shape({1, 2, 9}); - gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(2))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({3}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x_shape, x_shape, x_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(2))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -707,8 +720,12 @@ TEST(GradientCheckerTest, ConcatGrad) { TensorShape x1_shape({2, 2}); TensorShape x2_shape({2, 4}); TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -717,12 +734,24 @@ TEST(GradientCheckerTest, ConcatGrad) { TensorShape x1_shape({2, 2}); TensorShape x2_shape({2, 4}); TensorShape y_shape({2, 6}); - gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, {y_shape}, &max_error, - {MakeAttribute("axis", int64_t(-1))}); + std::vector output = {y_shape}; + if (extra_input) output.push_back(TensorInfo({2}, false, nullptr, DataTypeImpl::GetTensorType())); + gradient_checker.ComputeGradientError(op_def, {x1_shape, x2_shape}, + output, &max_error, + {MakeAttribute("axis", int64_t(-1))}, true, + check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } } +TEST(GradientCheckerTest, ConcatGrad) { + TestConcatOpGrad("Concat"); +} + +TEST(GradientCheckerTest, ConcatTrainingGrad) { /*also test w/o shape inferencing */ + TestConcatOpGrad("ConcatTraining", kMSDomain, 1, true); +} + TEST(GradientCheckerTest, AveragePoolGrad) { float max_error; GradientChecker gradient_checker; @@ -1909,4 +1938,3 @@ TEST(GradientCheckerTest, ExpandGrad) { } // namespace onnxruntime #endif // NDEBUG - diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 2e7ae8ffe9..bea6fc97e6 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" #include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/environment.h" @@ -28,6 +29,7 @@ namespace test { namespace { constexpr auto ORIGINAL_MODEL_PATH = ORT_TSTR("testdata/test_training_model.onnx"); constexpr auto BACKWARD_MODEL_PATH = ORT_TSTR("testdata/temp_backward_model.onnx"); +constexpr auto CONCAT_MODEL_PATH = ORT_TSTR("testdata/transform/concat_trainable.onnx"); std::unordered_set GetModelOutputNames(const InferenceSession& session) { const auto outputs_result = session.GetModelOutputs(); @@ -167,6 +169,27 @@ TEST(GradientGraphBuilderTest, BuildGradientGraphTest) { } } +TEST(GradientGraphBuilderTest, BuildConcatGradientGraphTest) { + const auto config = MakeBasicTrainingConfig(); + PathString backprop_model_file; + ASSERT_STATUS_OK(BuildBackPropGraph(CONCAT_MODEL_PATH, config, backprop_model_file)); + + std::shared_ptr pModel; + ASSERT_STATUS_OK(Model::Load(backprop_model_file, pModel, nullptr, DefaultLoggingManager().DefaultLogger())); + + Graph& graph = pModel->MainGraph(); + EXPECT_FALSE(graph.GraphResolveNeeded()); + EXPECT_TRUE(graph.NumberOfNodes() > 0); + EXPECT_TRUE(graph.MaxNodeIndex() > 0); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Concat"], 0); + ASSERT_EQ(op_to_count["Split"], 0); + ASSERT_EQ(op_to_count["ConcatTraining"], 1); + ASSERT_EQ(op_to_count["SplitTraining"], 1); +} + TEST(GradientGraphBuilderTest, TrainingSession_Basic) { const auto config = MakeBasicTrainingConfig(); PathString backprop_model_file; diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 873c7ade7e..484a2f2468 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -14,6 +14,7 @@ #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "orttraining/core/optimizer/megatron_transformer.h" +#include "orttraining/core/optimizer/concat_replacement.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" #include "test/util/include/asserts.h" @@ -108,8 +109,28 @@ TEST_F(GraphTransformationTests, NonZeroShapeSetter) { ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count"); } -// MegatronF/G is defined only for training, and in msdomain. +// MegatronF/G and ConcatTraining is defined only for training, and in msdomain. #ifndef DISABLE_CONTRIB_OPS +TEST_F(GraphTransformationTests, ConcatReplacement) { + auto model_uri = MODEL_FOLDER "concat_trainable.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = onnxruntime::make_unique("ConcatReplacement"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Concat"], 0); + ASSERT_EQ(op_to_count["ConcatTraining"], 1); +} + TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; @@ -483,7 +504,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) { auto model_uri = MODEL_FOLDER "model_parallel/self_attention_megatron_basic_test.onnx"; - const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. + const int total_rank = 2; // The test graph is too small to partition to 4, so use 2 instead here. std::vector graphs; std::vector> p_models(total_rank); for (auto i = 0; i < total_rank; i++) {