From a7b6fc08f214e8ae0ff4c0a1ee639bfd1efdbc7b Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 10 Feb 2021 17:50:10 -0800 Subject: [PATCH] Support skiplayernorm fusion without beta in layernorm (#6617) * support skiplayernorm fusion without beta in layernorm * use place holder * review comments --- .../core/optimizer/skip_layer_norm_fusion.cc | 5 +++-- .../test/optimizer/graph_transform_test.cc | 17 +++++++++++++++++ .../fusion/skip_layer_norm_no_beta.onnx | Bin 0 -> 274 bytes 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta.onnx diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 3fa535d084..5c979032d7 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -217,11 +217,13 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le continue; } + NodeArg beta_place_holder("", nullptr); + // Get the inputs for the new SkipLayerNormalization node. std::vector skip_layer_norm_input_defs{p_add1->MutableInputDefs()[0], p_add1->MutableInputDefs()[1], ln_node.MutableInputDefs()[1], - ln_node.MutableInputDefs()[2]}; + ln_node.MutableInputDefs().size() == 2 ? &beta_place_holder : ln_node.MutableInputDefs()[2]}; if (matched_format == Format::Format1) { skip_layer_norm_input_defs[0] = p_add2->MutableInputDefs()[0]; @@ -241,7 +243,6 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le "fused SkipLayerNorm subgraphs ", skip_layer_norm_input_defs, ln_node.MutableOutputDefs(), {}, kMSDomain); - // Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value // will be used. NodeAttributes ln_attrs = ln_node.GetAttributes(); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ffcb01192d..c026d691f9 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2751,6 +2751,23 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { } } +TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) { + auto model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a7f084f8b215048cbf04ba36e9cc1cc424bb0cc6 GIT binary patch literal 274 zcmd;Jx9Vi#lH+1@4~|yi9(Fd zQrtd?m8nH}`9--vl|o>SACQxnlUbEml9`{UCBns$Sdm#Q@gD{l7cequ32?Ec78C>J zMb{ivH($WWC?pe{oml|aAD;#?q{JjXFFz@@BvDJ2gGGQ*Nq~W&!JdJE!2yVXxWT~* q;*cnDpzDQ%xCA&Dg?PA_I0S%L0f<>Z5>D78l7JG<*d&~o1b6_@I6^r9 literal 0 HcmV?d00001