diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index d26ddf86c4..75b380eebd 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -21,6 +21,7 @@
 #include "core/optimizer/add_gelu_fusion.h"
 #include "core/optimizer/gelu_fusion.h"
 #include "core/optimizer/layer_norm_fusion.h"
+#include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/mlas/inc/mlas.h"
 #include "core/session/inference_session.h"
 
@@ -125,7 +126,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
 
       std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
       transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(cuda_execution_providers));
-
+      transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cuda_execution_providers));
 #endif
     } break;
 
diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
new file mode 100644
index 0000000000..d28705a4d6
--- /dev/null
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/skip_layer_norm_fusion.h"
+#include "core/graph/graph_utils.h"
+#include "float.h"
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+namespace onnxruntime {
+
+// Tensor types accepted for every input of the nodes being fused.
+static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};
+
+// Returns true only when every input of `node` has a known type listed in
+// supported_data_types. An input without type information counts as unsupported
+// instead of being dereferenced.
+static bool IsSupportedDataType(const Node& node) {
+  for (const auto& input_arg : node.InputDefs()) {
+    const auto* input_type = input_arg->Type();
+    if (input_type == nullptr ||
+        std::find(supported_data_types.begin(), supported_data_types.end(),
+                  *input_type) == supported_data_types.end()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Returns true when both shapes are known, are 3-D, and agree on every dimension.
+// Two dimensions agree when both carry the same concrete value, or both carry the same
+// non-empty symbolic name. dim_value() alone is not enough: it yields 0 for a dimension
+// that has no concrete value, so unrelated symbolic or unknown dimensions would compare equal.
+static bool HaveSame3DShape(const TensorShapeProto* lhs, const TensorShapeProto* rhs) {
+  if (lhs == nullptr || rhs == nullptr) {
+    return false;
+  }
+  if (lhs->dim_size() != 3 || rhs->dim_size() != 3) {
+    return false;
+  }
+  for (int i = 0; i < 3; i++) {
+    const auto& lhs_dim = lhs->dim(i);
+    const auto& rhs_dim = rhs->dim(i);
+    if (lhs_dim.has_dim_value() && rhs_dim.has_dim_value()) {
+      if (lhs_dim.dim_value() != rhs_dim.dim_value()) {
+        return false;
+      }
+    } else if (!lhs_dim.has_dim_param() || !rhs_dim.has_dim_param() ||
+               lhs_dim.dim_param().empty() || lhs_dim.dim_param() != rhs_dim.dim_param()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/**
+Skip Layer Normalization will fuse Add + LayerNormalization into one node.
+
+An "Add" node is fused with the "LayerNormalization" node consuming its output when:
+  - the Add has exactly one output edge and runs on a compatible execution provider,
+  - both Add inputs are 3-D and have the same shape (see HaveSame3DShape),
+  - the LayerNormalization runs on the same provider, normalizes the Add result (its input 0)
+    and has inputs 1 and 2 available,
+  - every input of both nodes is tensor(float16) or tensor(float).
+
+The resulting "SkipLayerNormalization" node (kMSDomain) takes the two Add inputs followed by
+LayerNormalization inputs 1 and 2.
+*/
+Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+  std::vector<std::reference_wrapper<Node>> nodes_to_remove;
+  for (auto node_index : node_topology_list) {
+    nodes_to_remove.clear();
+    auto* p_add = graph.GetNode(node_index);
+    if (p_add == nullptr)
+      continue;  // we removed the node as part of an earlier fusion.
+
+    Node& add_node = *p_add;
+    ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level));
+
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
+        !graph_utils::IsSupportedProvider(add_node, GetCompatibleExecutionProviders()) ||
+        add_node.GetOutputEdgesCount() != 1 ||
+        !IsSupportedDataType(add_node)) {
+      continue;
+    }
+
+    // "Add" inputs have to be 3d and of same dimensions.
+    if (!HaveSame3DShape(add_node.MutableInputDefs()[0]->Shape(),
+                         add_node.MutableInputDefs()[1]->Shape())) {
+      continue;
+    }
+
+    // Find "LayerNormalization" node after the "Add".
+    Node& ln_node = *graph.GetNode(add_node.OutputNodesBegin()->Index());
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(ln_node, "LayerNormalization", {1}) ||
+        ln_node.GetExecutionProviderType() != add_node.GetExecutionProviderType() ||
+        !IsSupportedDataType(ln_node)) {
+      continue;
+    }
+
+    // Inputs 1 and 2 of the "LayerNormalization" node are read below, so both must exist, and
+    // the "Add" result has to be the tensor being normalized (input 0), not one of those two.
+    if (ln_node.InputDefs().size() < 3 ||
+        ln_node.InputDefs()[0] != add_node.OutputDefs()[0]) {
+      continue;
+    }
+
+    nodes_to_remove.push_back(add_node);
+    nodes_to_remove.push_back(ln_node);
+
+    // Get the inputs for the new SkipLayerNormalization node.
+    const std::vector<NodeArg*> skip_layer_norm_input_defs{add_node.MutableInputDefs()[0],
+                                                           add_node.MutableInputDefs()[1],
+                                                           ln_node.MutableInputDefs()[1],
+                                                           ln_node.MutableInputDefs()[2]};
+    Node& skip_layer_norm_node = graph.AddNode(graph.GenerateNodeName("SkipLayerNormalization"),
+                                               "SkipLayerNormalization",
+                                               "fused SkipLayerNorm subgraphs ",
+                                               skip_layer_norm_input_defs,
+                                               {}, {}, kMSDomain);
+
+    // Assign provider to this new node. Provider should be same as the provider for old node.
+    skip_layer_norm_node.SetExecutionProviderType(add_node.GetExecutionProviderType());
+
+    // Move input edges of the "Add" node (first in list) across to skip_layer_norm_node,
+    // move output definitions and output edges of the "LayerNormalization" node (last in list)
+    // to skip_layer_norm_node, and remove both original nodes.
+    graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, skip_layer_norm_node);
+
+    modified = true;
+  }
+  return Status::OK();
+}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.h b/onnxruntime/core/optimizer/skip_layer_norm_fusion.h
new file mode 100644
index 0000000000..1052235dfd
--- /dev/null
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class SkipLayerNormFusion
+
+Rewrite graph fusing an Add node followed by a LayerNormalization node into a single
+SkipLayerNormalization node.
+
+The formula corresponding to the fused subgraph:
+(x - mean(x, axis)) / sqrt(var(x, axis)) * scale + bias, where x is the sum of the two Add inputs.
+
+*/
+class SkipLayerNormFusion : public GraphTransformer {
+ public:
+  SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("SkipLayerNormFusion", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 0a568eb829..d71c400d32 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -17,6 +17,7 @@
 #include "core/optimizer/add_gelu_fusion.h"
 #include "core/optimizer/gelu_fusion.h"
 #include "core/optimizer/layer_norm_fusion.h"
+#include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/graph_transformer.h"
 #include "core/optimizer/graph_transformer_mgr.h"
 #include "core/optimizer/identity_elimination.h"
@@ -913,6 +914,31 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
     }
   }
 }
+
+// Registers LayerNormFusion and SkipLayerNormFusion at Level2 and checks that a single
+// SkipLayerNormalization node is left in place of the original operators.
+TEST(GraphTransformationTests, SkipLayerNormFusionTest) {
+  string model_uri = MODEL_FOLDER + "fusion/skip_layer_norm.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(onnxruntime::make_unique<LayerNormFusion>(), TransformerLevel::Level2);
+  graph_transformation_mgr.Register(onnxruntime::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2);
+  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2);
+  ASSERT_TRUE(ret.IsOK());
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Div"] == 0);
+  ASSERT_TRUE(op_to_count["Add"] == 0);
+  ASSERT_TRUE(op_to_count["Sub"] == 0);
+  ASSERT_TRUE(op_to_count["ReduceMean"] == 0);
+  ASSERT_TRUE(op_to_count["Pow"] == 0);
+  ASSERT_TRUE(op_to_count["Sqrt"] == 0);
+  ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
+  ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 1);
+}
 #endif
 
 }  // namespace test
diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm.onnx
new file mode 100644
index 0000000000..3017b5e43d
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm.onnx differ