Skip layer norm transform (#2350)

* skip layer normalization transformer
This commit is contained in:
liuziyue 2019-11-13 13:46:09 -08:00 committed by GitHub
parent 8ed2928dd5
commit ffa2812587
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 159 additions and 1 deletions

View file

@ -21,6 +21,7 @@
#include "core/optimizer/add_gelu_fusion.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/inference_session.h"
@ -125,7 +126,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
transformers.emplace_back(onnxruntime::make_unique<AddGeluFusion>(cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cuda_execution_providers));
#endif
} break;

View file

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/initializer.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/graph/graph_utils.h"
#include "float.h"
#include <deque>
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
// LayerNorm supports limited data types.
// The table is read-only, so it is declared const to rule out accidental mutation.
static const std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)"};

/**
Checks whether every input of a node has one of the data types listed in supported_data_types.

An input is treated as unsupported when it has no type information (Type() yields nullptr),
so a graph without complete type inference simply skips the fusion instead of dereferencing
a null pointer.

@param node  Node whose input definitions are inspected.
@return true if all inputs have a supported type, false otherwise.
*/
static bool IsSupportedDataType(const Node& node) {
  for (const auto& input_arg : node.InputDefs()) {
    if (input_arg == nullptr) {
      return false;
    }

    const auto data_type = input_arg->Type();
    if (data_type == nullptr) {
      // Type is unknown: be conservative and report the node as unsupported.
      return false;
    }

    if (std::find(supported_data_types.begin(), supported_data_types.end(),
                  *data_type) == supported_data_types.end()) {
      return false;
    }
  }
  return true;
}
/**
Skip Layer Normalization will fuse Add + LayerNormalization into one node.

Pattern matched (nodes are visited in topological order):

    Add(a, b) --> LayerNormalization(X, scale, bias)

which is replaced by a single node in kMSDomain:

    SkipLayerNormalization(a, b, scale, bias)

The fusion is applied only when all of the following hold:
  - the Add is version 7, runs on a compatible execution provider, has exactly one output edge
    and only float / float16 inputs;
  - both Add inputs have a known 3-D shape and their dimensions match pairwise;
  - the single consumer of the Add is a LayerNormalization (version 1) on the same execution
    provider, with supported input types, that consumes the Add output as its first input and
    provides scale and bias inputs.

@param graph        Graph that is transformed in place.
@param modified     Set to true when a fusion is applied; this function never sets it to false.
@param graph_level  Forwarded unchanged to Recurse().
@return Status::OK(), or the error returned by Recurse() for one of the visited nodes.
*/
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::vector<std::reference_wrapper<Node>> nodes_to_remove;

  for (auto node_index : node_topology_list) {
    auto* p_add = graph.GetNode(node_index);
    if (p_add == nullptr)
      continue;  // we removed the node as part of an earlier fusion.

    Node& add_node = *p_add;
    ORT_RETURN_IF_ERROR(Recurse(add_node, modified, graph_level));

    // NOTE(review): this only looks at output edges. Whether an Add whose output is also a graph
    // output must be excluded here depends on FinalizeNodeFusion - confirm.
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
        !graph_utils::IsSupportedProvider(add_node, GetCompatibleExecutionProviders()) ||
        add_node.GetOutputEdgesCount() != 1 ||
        !IsSupportedDataType(add_node)) {
      continue;
    }

    // Check the input dimensions of the "Add" node.
    const TensorShapeProto* add_input1_shape = add_node.MutableInputDefs()[0]->Shape();
    const TensorShapeProto* add_input2_shape = add_node.MutableInputDefs()[1]->Shape();
    if (add_input1_shape == nullptr || add_input2_shape == nullptr) {
      continue;
    }

    // "Add" inputs have to be 3d.
    if (add_input1_shape->dim_size() != 3 || add_input2_shape->dim_size() != 3) {
      continue;
    }

    // "Add" inputs have to be of same dimensions.
    // dim_value() returns 0 for a dimension that is symbolic or unset, so comparing dim_value()
    // alone would treat two unrelated unknown dimensions as equal. A pair of dimensions matches
    // only if both carry the same concrete value, or both carry the same non-empty symbolic name.
    bool is_valid_input = true;
    for (int i = 0; i < 3; i++) {
      const auto& dim1 = add_input1_shape->dim(i);
      const auto& dim2 = add_input2_shape->dim(i);

      const bool same_value = dim1.has_dim_value() && dim2.has_dim_value() &&
                              dim1.dim_value() == dim2.dim_value();
      const bool same_param = dim1.has_dim_param() && dim2.has_dim_param() &&
                              !dim1.dim_param().empty() &&
                              dim1.dim_param() == dim2.dim_param();

      if (!same_value && !same_param) {
        is_valid_input = false;
        break;
      }
    }
    if (!is_valid_input) {
      continue;
    }

    // Find "LayerNormalization" node after the "Add".
    Node* p_ln = graph.GetNode(add_node.OutputNodesBegin()->Index());
    if (p_ln == nullptr) {
      continue;
    }

    Node& ln_node = *p_ln;
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(ln_node, "LayerNormalization", {1}) ||
        ln_node.GetExecutionProviderType() != add_node.GetExecutionProviderType() ||
        !IsSupportedDataType(ln_node)) {
      continue;
    }

    // The fused node takes scale (input 1) and bias (input 2) from the LayerNormalization node,
    // so both must be present, and the "Add" result must be the data input (input 0) rather than
    // scale or bias.
    if (ln_node.InputDefs().size() < 3 ||
        add_node.OutputDefs().empty() ||
        ln_node.InputDefs()[0]->Name() != add_node.OutputDefs()[0]->Name()) {
      continue;
    }

    // Get the inputs for the new SkipLayerNormalization node.
    const std::vector<NodeArg*> skip_layer_norm_input_defs{add_node.MutableInputDefs()[0],
                                                           add_node.MutableInputDefs()[1],
                                                           ln_node.MutableInputDefs()[1],
                                                           ln_node.MutableInputDefs()[2]};
    Node& skip_layer_norm_node = graph.AddNode(graph.GenerateNodeName("SkipLayerNormalization"),
                                               "SkipLayerNormalization",
                                               "fused SkipLayerNorm subgraphs ",
                                               skip_layer_norm_input_defs,
                                               {}, {}, kMSDomain);

    // Assign provider to this new node. Provider should be same as the provider for old node.
    skip_layer_norm_node.SetExecutionProviderType(add_node.GetExecutionProviderType());

    // The nodes being replaced, in order: the "Add" first, the "LayerNormalization" last.
    nodes_to_remove.clear();
    nodes_to_remove.push_back(add_node);
    nodes_to_remove.push_back(ln_node);

    // move input edges to add_node (first in list) across to skip_layer_norm_node.
    // move output definitions and output edges from ln_node (last in list) to skip_layer_norm_node.
    // remove add_node and ln_node.
    graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, skip_layer_norm_node);

    modified = true;
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class SkipLayerNormFusion
Rewrite graph fusing an Add node followed by a LayerNormalization node into a single
SkipLayerNormalization node.
The formula of the LayerNormalization that is being fused, for reference:
(x - mean(x, axis)) / sqrt(var(x, axis)) * scale + bias, where x is the input.
*/
class SkipLayerNormFusion : public GraphTransformer {
public:
// compatible_execution_providers is forwarded to the GraphTransformer base class together with
// the transformer name "SkipLayerNormFusion"; it defaults to an empty set.
SkipLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("SkipLayerNormFusion", compatible_execution_providers) {}
// Applies the fusion to `graph`; the implementation lives in the corresponding source file.
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
};
} // namespace onnxruntime

View file

@ -17,6 +17,7 @@
#include "core/optimizer/add_gelu_fusion.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/identity_elimination.h"
@ -913,6 +914,29 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) {
}
}
}
// Loads fusion/skip_layer_norm.onnx, runs LayerNormFusion followed by SkipLayerNormFusion at
// Level2, and checks that the graph is left with exactly one SkipLayerNormalization node and
// none of the ops that made up the original subgraph.
// ASSERT_EQ and streamed status messages are used so that a failure reports the actual op count
// or the underlying error instead of just "false".
TEST(GraphTransformationTests, SkipLayerNormFusionTest) {
  string model_uri = MODEL_FOLDER + "fusion/skip_layer_norm.onnx";
  std::shared_ptr<Model> p_model;
  const auto load_status = Model::Load(model_uri, p_model);
  ASSERT_TRUE(load_status.IsOK()) << load_status.ErrorMessage();
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  // LayerNormFusion is registered first: it produces the LayerNormalization node that
  // SkipLayerNormFusion then combines with the preceding Add.
  graph_transformation_mgr.Register(onnxruntime::make_unique<LayerNormFusion>(), TransformerLevel::Level2);
  graph_transformation_mgr.Register(onnxruntime::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2);
  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2);
  ASSERT_TRUE(ret.IsOK()) << ret.ErrorMessage();

  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Div"], 0);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["Sub"], 0);
  ASSERT_EQ(op_to_count["ReduceMean"], 0);
  ASSERT_EQ(op_to_count["Pow"], 0);
  ASSERT_EQ(op_to_count["Sqrt"], 0);
  ASSERT_EQ(op_to_count["LayerNormalization"], 0);
  ASSERT_EQ(op_to_count["SkipLayerNormalization"], 1);
}
#endif
} // namespace test

Binary file not shown.