EmbedLayerNormalization fusion (#2452)

Embed Layer Normalization Fusion
This commit is contained in:
liuziyue 2019-11-28 14:03:58 -08:00 committed by GitHub
parent 60208463a9
commit 0edd4ef6ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 341 additions and 5 deletions

View file

@ -0,0 +1,270 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/initializer.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/graph/graph_utils.h"
#include "float.h"
#define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
// Add a Cast to convert Input from int64 to int32.
static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_type) {
const TensorShapeProto* input_shape = input->Shape();
TypeProto input_int32;
input_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
auto dim0 = input_int32.mutable_tensor_type()->mutable_shape()->add_dim();
*dim0 = input_shape->dim(0);
auto dim1 = input_int32.mutable_tensor_type()->mutable_shape()->add_dim();
*dim1 = input_shape->dim(1);
auto& cast32 = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(input->Name() + "_Int32"), &input_int32);
Node& node = graph.AddNode(graph.GenerateNodeName(input->Name() + "_Cast"),
"Cast",
"Cast Input from int64 to int32",
{input},
{&cast32},
nullptr,
kOnnxDomain);
// Add attribute: "to" = 6
ONNX_NAMESPACE::AttributeProto to;
to.set_name("to");
to.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
to.set_i(static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
node.AddAttribute("to", to);
node.SetExecutionProviderType(provider_type);
return &cast32;
}
static NodeArg* CheckInput(Graph& graph, NodeArg* input, ProviderType provider_type, const logging::Logger& logger) {
// Validate input shape (batch_size, sequence_length) and data type.
// Note that batch_size and sequence_length could be symbolic.
const TensorShapeProto* input_shape = input->Shape();
if (input_shape == nullptr || input_shape->dim_size() != 2 || input->Type() == nullptr) {
DEBUG_LOG("Mask shape is unknown or not 2D, or data type unknown");
return nullptr;
}
auto data_type = input->TypeAsProto()->tensor_type().elem_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
data_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) {
DEBUG_LOG("Input data type is not int32 or int64");
return nullptr;
}
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
return CastToInt32(graph, input, provider_type);
}
return input;
}
/**
Embed Layer Normalization will fuse embeddings and mask processing into one node :
The embeddings before conversion:
(input_ids) --------> Gather ----------+ (segment_ids)
| | |
| v v
+--> Shape --> Expand -> Gather---->Add Gather
| ^ | |
| | v v
+---(optional graph) SkipLayerNormalization
*/
Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto* p_layer_norm = graph.GetNode(node_index);
if (p_layer_norm == nullptr)
continue; // we removed the node as part of an earlier fusion
Node& layer_norm_node = *p_layer_norm;
ORT_RETURN_IF_ERROR(Recurse(layer_norm_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(layer_norm_node, "LayerNormalization", {1}, kOnnxDomain) ||
!graph_utils::IsSupportedProvider(layer_norm_node, GetCompatibleExecutionProviders())) {
continue;
}
// Find Attention after SkipLayerNormalization
const Node* p_attention = graph_utils::FirstChildByType(layer_norm_node, "Attention");
// Stop EmbedLayerNormalization fusion if Attention is not found.
if (p_attention == nullptr) {
return Status::OK();
}
Node& attention_node = *graph.GetNode(p_attention->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(attention_node, "Attention", {1}, kMSDomain) ||
!graph_utils::IsSupportedProvider(attention_node, GetCompatibleExecutionProviders())) {
continue;
}
// Find ReduceSum --> Attention
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(attention_node, true, {{0, 3, "ReduceSum", {1, 11}, kOnnxDomain}}, edges, logger)) {
continue;
}
Node& reduce_sum_node = *graph.GetNode(edges[0]->GetNode().Index());
// Find Add --> LayerNormalization
if (!graph_utils::FindPath(layer_norm_node, true, {{0, 0, "Add", {7}, kOnnxDomain}}, edges, logger)) {
continue;
}
Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index());
// Traceback the SkipLayerNormalization node to find Gather --> SkipLayerNormalization
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
{0, 1, "Gather", {1, 11}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
continue;
}
Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
if (segment_gather_node.GetOutputEdgesCount() != 1) {
continue;
}
// The first input of segment_gather_node must be 2d.
auto sg_shape = segment_gather_node.MutableInputDefs()[0]->Shape();
if (sg_shape != nullptr && sg_shape->dim_size() != 2) {
continue;
}
// Traceback the SkipLayerNormalization node to find Gather --> Add --> SkipLayerNormalization
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Add", {7}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
continue;
}
Node& add_node = *graph.GetNode(edges[0]->GetNode().Index());
Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index());
if (add_node.GetOutputEdgesCount() != 1 || word_gather_node.GetOutputEdgesCount() != 1) {
continue;
}
// The first input of word_gather_node must be 2d.
auto wg_shape = word_gather_node.MutableInputDefs()[0]->Shape();
if (wg_shape != nullptr && wg_shape->dim_size() != 2) {
continue;
}
// Traceback the Add node to find (Shape --> Expand -->) Gather --> Add.
// Constant folding removes Shape and Expand nodes when input does not have symbolic shape. In that
// case just look for Gather --> Add.
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path{
{0, 1, "Gather", {1, 11}, kOnnxDomain}};
if (!graph_utils::FindPath(add_node, true, position_embedding_path, edges, logger)) {
continue;
}
Node& position_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
if (position_gather_node.GetOutputEdgesCount() != 1) {
continue;
}
// The first input of position_gather_node must be 2d.
auto pg_shape = position_gather_node.MutableInputDefs()[0]->Shape();
if (pg_shape != nullptr && pg_shape->dim_size() != 2) {
continue;
}
// Match Shape --> Expand path if needed.
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 1, "Shape", {1}, kOnnxDomain}};
Node* p_expand_node = nullptr;
Node* p_shape_node = nullptr;
if (graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
if (edges[0]->GetNode().GetOutputEdgesCount() == 1 && edges[1]->GetNode().GetOutputEdgesCount() == 1) {
p_expand_node = graph.GetNode(edges[0]->GetNode().Index());
p_shape_node = graph.GetNode(edges[1]->GetNode().Index());
}
}
// Get input "input_ids" from node.
NodeArg* input_ids = CheckInput(graph, word_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
if (input_ids == nullptr) {
DEBUG_LOG("Input id is not valid. ");
continue;
}
// Get input "segment_ids" from node.
NodeArg* segment_ids = CheckInput(graph, segment_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
if (segment_ids == nullptr) {
DEBUG_LOG("Segment id is not valid. ");
continue;
}
// Get input "mask" from "ReduceSum" node.
NodeArg* mask = CheckInput(graph, reduce_sum_node.MutableInputDefs()[0], layer_norm_node.GetExecutionProviderType(), logger);
if (mask == nullptr) {
DEBUG_LOG("Mask is not valid. ");
continue;
}
const std::vector<NodeArg*> embed_layer_norm_input_defs{
input_ids,
segment_ids,
mask,
word_gather_node.MutableInputDefs()[0],
position_gather_node.MutableInputDefs()[0],
segment_gather_node.MutableInputDefs()[0],
layer_norm_node.MutableInputDefs()[1],
layer_norm_node.MutableInputDefs()[2]};
Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"),
"EmbedLayerNormalization",
"fused EmbedLayerNorm subgraphs ",
embed_layer_norm_input_defs,
{layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]},
{}, kMSDomain);
// Assign provider to this new node. Provider should be same as the provider for old node.
embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType());
// move input edges to gather (first in list) across to the embed_layer_norm_node.
// move output definitions and output edges to embed_layer_norm_node.
// remove all the other nodes.
std::vector<NodeIndex> nodes_to_remove;
if (p_shape_node != nullptr && p_expand_node != nullptr) {
// Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand
if (p_expand_node != nullptr) {
Node& expand_node = *graph.GetNode(p_expand_node->Index());
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path{
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Cast", {9}, kOnnxDomain},
{0, 0, "Squeeze", {1}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain},
};
if (graph_utils::FindPath(expand_node, true, expand_parent_path, edges, logger)) {
for (size_t i = 0; i < edges.size(); i++) {
if (edges[i]->GetNode().GetOutputEdgesCount() != 1) {
nodes_to_remove.clear();
break;
}
nodes_to_remove.push_back(edges[i]->GetNode().Index());
}
}
}
nodes_to_remove.push_back(p_shape_node->Index());
nodes_to_remove.push_back(p_expand_node->Index());
}
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(position_gather_node.Index());
nodes_to_remove.push_back(segment_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
nodes_to_remove.push_back(reduce_sum_node.Index());
nodes_to_remove.push_back(layer_norm_add_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class EmbedLayerNormFusion
Rewrite graph fusing embeddings and mask processing into one node.
*/
class EmbedLayerNormFusion : public GraphTransformer {
public:
EmbedLayerNormFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("EmbedLayerNormFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -23,6 +23,7 @@
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/mlas/inc/mlas.h"
@ -128,6 +129,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<AttentionFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<EmbedLayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasGelu>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cpu_cuda_execution_providers));

View file

@ -8,12 +8,9 @@
namespace onnxruntime {
/**
@Class LayerNormFusion
@Class SkipLayerNormFusion
Rewrite graph fusing Layer Normalization subgraph to a single LayerNormalization node.
The formula corresponding to LayerNorm activation subgraph:
(x - mean(x, axis)) / sqrt(var(x, axis)) * scale + bias, where x is the input.
Rewrite graph fusing Add + Layer Normalization subgraph to a single SkipLayerNormalization node.
*/
class SkipLayerNormFusion : public GraphTransformer {

View file

@ -19,6 +19,7 @@
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/identity_elimination.h"
@ -1253,6 +1254,48 @@ TEST(GraphTransformationTests, SkipLayerNormFusionTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx");
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx");
}
TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Gather"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["ReduceSum"] == 0);
ASSERT_TRUE(op_to_count["Attention"] == 1);
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 0);
ASSERT_TRUE(op_to_count["EmbedLayerNormalization"] == 1);
}
TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format2.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Shape"] == 0);
ASSERT_TRUE(op_to_count["Expand"] == 0);
ASSERT_TRUE(op_to_count["Gather"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["ReduceSum"] == 0);
ASSERT_TRUE(op_to_count["Attention"] == 1);
ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == 0);
ASSERT_TRUE(op_to_count["EmbedLayerNormalization"] == 1);
}
#endif
} // namespace test