diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index ebb565af4a..c76781d28f 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -25,7 +25,8 @@ static bool IsSupportedDataType(const Node& node) { static bool CheckFirstAdd(Node& add, ProviderType providertype) { if (providertype != add.GetExecutionProviderType() || - !IsSupportedDataType(add)) { + !IsSupportedDataType(add) || + add.GetOutputEdgesCount() != 1) { return false; } @@ -58,7 +59,8 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) { // The 2nd input should be a 1D constant value static bool CheckSecondAdd(Node& add, ProviderType providertype) { if (providertype != add.GetExecutionProviderType() || - !IsSupportedDataType(add)) { + !IsSupportedDataType(add) || + add.GetOutputEdgesCount() != 1) { return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ba3b01e655..12d59168e3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1272,7 +1272,7 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { } } -static void TestSkipLayerNormFusion(const std::basic_string& file_path) { +static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, int skip_ln_count) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph = p_model->MainGraph(); @@ -1285,19 +1285,22 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Add"] == add_count ); ASSERT_TRUE(op_to_count["Sub"] == 0); ASSERT_TRUE(op_to_count["ReduceMean"] == 0); ASSERT_TRUE(op_to_count["Pow"] == 0); ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 1); + ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count ); + ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count ); } TEST(GraphTransformationTests, SkipLayerNormFusionTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx"); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx"); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx"); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1 ); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1 ); + TestSkipLayerNormFusion( MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0 ); } TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) { diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx index 4e72ab0dd5..b0a260e8f4 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx new file mode 100644 index 0000000000..e01484fc66 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1_partial.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx index 501bf2a5e9..91fdb98282 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx new file mode 100644 index 0000000000..fa293af633 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2_partial.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx index 8259df0b2b..e5d2729e85 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx new file mode 100644 index 0000000000..4e61e1439c Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3_no_fusion.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py index 1d8afa4b58..b547d650d8 100644 --- a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py @@ -8,7 +8,7 @@ class Format(Enum): Format2=2, Format3=3 -def GenerateModel(format, model_name): +def GenerateModel(format, model_name, multi_output_add = False): nodes = [ # LayerNorm subgraph helper.make_node("ReduceMean", ["ln_in"], ["rd1_out"], "reduce1", axes=[-1], keepdims=1), helper.make_node("Sub", ["ln_in", "rd1_out"], ["sb1_out"], "sub1"), @@ -42,6 +42,10 @@ def GenerateModel(format, model_name): elif format is Format.Format3: nodes.extend([helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"),]) + if multi_output_add: + neg_input = "ln_in" if format is Format.Format3 else "add3_out" + nodes.extend([helper.make_node("Neg", [neg_input], ["neg_out"], "neg")]) + graph = helper.make_graph( nodes, "SkipLayerNorm_format3", #name @@ -60,4 +64,7 @@ def GenerateModel(format, model_name): GenerateModel(Format.Format1, 'skip_layer_norm_format1.onnx') GenerateModel(Format.Format2, 'skip_layer_norm_format2.onnx') -GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx') \ No newline at end of file +GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx') +GenerateModel(Format.Format1, 'skip_layer_norm_format1_partial.onnx', True) +GenerateModel(Format.Format2, 'skip_layer_norm_format2_partial.onnx', True) +GenerateModel(Format.Format3, 'skip_layer_norm_format3_no_fusion.onnx', True) \ No newline at end of file