From cd876720d930fcea33c2c3f3bb9717ae115c007c Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 24 Jan 2020 13:47:34 -0800 Subject: [PATCH] Only fuse when output count of add is 1 (#2884) * Only fuse when output count of add is 1 * add unit test for add with multi output --- .../core/optimizer/skip_layer_norm_fusion.cc | 6 ++++-- .../test/optimizer/graph_transform_test.cc | 17 ++++++++++------- .../fusion/skip_layer_norm_format1.onnx | Bin 764 -> 764 bytes .../skip_layer_norm_format1_partial.onnx | Bin 0 -> 795 bytes .../fusion/skip_layer_norm_format2.onnx | Bin 764 -> 764 bytes .../skip_layer_norm_format2_partial.onnx | Bin 0 -> 795 bytes .../fusion/skip_layer_norm_format3.onnx | Bin 695 -> 695 bytes .../skip_layer_norm_format3_no_fusion.onnx | Bin 0 -> 723 bytes .../transform/fusion/skip_layer_norm_gen.py | 11 +++++++++-- 9 files changed, 23 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index ebb565af4a..c76781d28f 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -25,7 +25,8 @@ static bool IsSupportedDataType(const Node& node) { static bool CheckFirstAdd(Node& add, ProviderType providertype) { if (providertype != add.GetExecutionProviderType() || - !IsSupportedDataType(add)) { + !IsSupportedDataType(add) || + add.GetOutputEdgesCount() != 1) { return false; } @@ -58,7 +59,8 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) { // The 2nd input should be a 1D constant value static bool CheckSecondAdd(Node& add, ProviderType providertype) { if (providertype != add.GetExecutionProviderType() || - !IsSupportedDataType(add)) { + !IsSupportedDataType(add) || + add.GetOutputEdgesCount() != 1) { return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ba3b01e655..12d59168e3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1272,7 +1272,7 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { } } -static void TestSkipLayerNormFusion(const std::basic_string& file_path) { +static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, int skip_ln_count) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); @@ -1285,19 +1285,22 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Add"] == add_count ); ASSERT_TRUE(op_to_count["Sub"] == 0); ASSERT_TRUE(op_to_count["ReduceMean"] == 0); ASSERT_TRUE(op_to_count["Pow"] == 0); ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 1); + ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count ); + ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count ); } TEST(GraphTransformationTests, SkipLayerNormFusionTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx"); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx"); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx"); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1 ); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0 ); } TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) { diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx index 4e72ab0dd5f90c554581658e40af3619321f1d7d..b0a260e8f45bc8e1ace795d8d5b3278c51be8939 100644 GIT binary patch delta 13 Ucmeyv`iGT?gKZV!Z delta 13 Ucmeyv`iGT?gLNa*PbNk#03fyl%K!iX diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e01484fc6635b8c598fc32c64ebd6e22daa2843b GIT binary patch literal 795 zcmbVJJx}965Vd22&47fpa5{ttAcH`dhB)0xmkWk6M2K`W)*3I7MQjIu5a{v)uDT8? ze#rfV@S9+EHx5lif#qF2z4tsbZ)`69`nL!jSQsRJln^5e8-7|An?}aM@`g2RxDw!k z(ZpE?oj$U>{a!ult7H%!mlz!JDK!MPcA1W#3l+&T$THIm?sqaA@!@6PrJFR#3!gPyV$yKB zADtsxI}JnFg&9rE{5~tD84%fe4{tc+A;bb;W*|lwJxC`4EOj0|z#c60X&h6iL}J9{ zKt@GiYZv7J>QIrO2dtn(J=^5toXByuK|MtwqPZBiDSM#al=_YsSJ0q(ME-bD09VfQ z)ISZxNLXLCM__AR)+axDchRszpIFvSv)I4Iixy3XSgNfniqiIG|H)b7+~oh?WAAf^ zPTkazwiIRjhuiVRGj4k_xw9qZx5?!9+oaQad46soZh0c=7b@nx{)y2E0n1d$3e-#s U&n_OF3*I~-K;dkj;MJ_<5A?a<@c;k- literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx index 501bf2a5e9d1eba9164932db7dcbae181a4ab5ab..91fdb98282ce0d5674ae4eba5b47d504c3a7259e 100644 GIT binary patch delta 13 Ucmeyv`iGT?gKZV!Z delta 13 Ucmeyv`iGT?gLNa*PbNk#03fyl%K!iX diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa293af6332b41a395e87ffc3552b47b3a99a856 GIT binary patch literal 795 zcmbVJO;6iE5Vd0myHg~rEr<#=ZDbIr?u9rvBtkGfh6s^cFV-3_k(Jm^{gJdcegG$K zaN>vX6T)w5XLsYkA>x4LT|K=wo_TL5(Q?6&0$)Z0@32I{GBEEFPtyl8+c z=Xn~hR3l+SIUa$n_1TbY4Mx$VNB^>{pJuUtjTbH24zW~QR}`h|E&r4AiF5O(+aKP` z4xNUnBW)?l%?U5Z7tc-ClgXVgQhuJ#kDuqg&cpqE2XV&}QQuH8@2_o)MhIA@O4gxf WT6mgxbS`+SgaC!JRf1QuKK=*pY~b+# literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx index 8259df0b2b2b6407ffbcfcc8f8263ace56c4a25b..e5d2729e85931d120d6f7cdea45d95a8e13b1746 100644 GIT binary patch delta 13 Ucmdnax}BAYgKZGcfpgvPN^!y;%Mxq7>MK>tB5hZbXTRF+xrY`j|cX5@OU(TvbUUdj~ zHfKQ{Mq~77n?zxVxs(v-(p@x$E|et8z}plV%>QIKX6E_5%a2i*q&97uM0LDkJ~BhP zbme)l2Ng}!0*uH`W<9`+7gI{w8DWsxs1=AWSY