mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
EmbedLayerNormalization Fusion Improvement (#2553)
Embedding layer norm fusion improvements - add more checks
This commit is contained in:
parent
0f12346d76
commit
200f4b4ea6
4 changed files with 154 additions and 56 deletions
|
|
@ -3,6 +3,8 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "float.h"
|
||||
|
||||
#define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x
|
||||
|
|
@ -13,6 +15,10 @@ namespace onnxruntime {
|
|||
|
||||
// Add a Cast to convert Input from int64 to int32.
|
||||
static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_type) {
|
||||
auto data_type = input->TypeAsProto()->tensor_type().elem_type();
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
|
||||
return input;
|
||||
}
|
||||
const TensorShapeProto* input_shape = input->Shape();
|
||||
TypeProto input_int32;
|
||||
input_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
|
|
@ -41,26 +47,22 @@ static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_
|
|||
return &cast32;
|
||||
}
|
||||
|
||||
static NodeArg* CheckInput(Graph& graph, NodeArg* input, ProviderType provider_type, const logging::Logger& logger) {
|
||||
static bool CheckInput(NodeArg* input, const logging::Logger& logger) {
|
||||
// Validate input shape (batch_size, sequence_length) and data type.
|
||||
// Note that batch_size and sequence_length could be symbolic.
|
||||
const TensorShapeProto* input_shape = input->Shape();
|
||||
if (input_shape == nullptr || input_shape->dim_size() != 2 || input->Type() == nullptr) {
|
||||
DEBUG_LOG("Mask shape is unknown or not 2D, or data type unknown");
|
||||
return nullptr;
|
||||
DEBUG_LOG("Input shape is unknown or not 2D, or data type unknown");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto data_type = input->TypeAsProto()->tensor_type().elem_type();
|
||||
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
|
||||
data_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) {
|
||||
DEBUG_LOG("Input data type is not int32 or int64");
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
return CastToInt32(graph, input, provider_type);
|
||||
}
|
||||
return input;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -124,8 +126,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue;
|
||||
}
|
||||
// The first input of segment_gather_node must be 2d.
|
||||
auto sg_shape = segment_gather_node.MutableInputDefs()[0]->Shape();
|
||||
if (sg_shape != nullptr && sg_shape->dim_size() != 2) {
|
||||
NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
|
||||
auto sg_shape = segment_embedding->Shape();
|
||||
if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(sg_shape->dim()[1]) ||
|
||||
sg_shape->dim()[1].dim_value() <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -142,8 +147,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue;
|
||||
}
|
||||
// The first input of word_gather_node must be 2d.
|
||||
auto wg_shape = word_gather_node.MutableInputDefs()[0]->Shape();
|
||||
if (wg_shape != nullptr && wg_shape->dim_size() != 2) {
|
||||
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
|
||||
auto wg_shape = word_embedding->Shape();
|
||||
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(wg_shape->dim()[1]) ||
|
||||
wg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -160,51 +168,157 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
continue;
|
||||
}
|
||||
// The first input of position_gather_node must be 2d.
|
||||
auto pg_shape = position_gather_node.MutableInputDefs()[0]->Shape();
|
||||
if (pg_shape != nullptr && pg_shape->dim_size() != 2) {
|
||||
NodeArg* position_embedding = position_gather_node.MutableInputDefs()[0];
|
||||
auto pg_shape = position_embedding->Shape();
|
||||
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
|
||||
!utils::HasDimValue(pg_shape->dim()[1]) ||
|
||||
pg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Match Shape --> Expand path if needed.
|
||||
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 1, "Shape", {1}, kOnnxDomain}};
|
||||
// Check the second input of position gather. If it's not initializer, check for two paths.
|
||||
Node* p_expand_node = nullptr;
|
||||
Node* p_shape_node = nullptr;
|
||||
if (graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
|
||||
if (edges[0]->GetNode().GetOutputEdgesCount() == 1 && edges[1]->GetNode().GetOutputEdgesCount() == 1) {
|
||||
p_expand_node = graph.GetNode(edges[0]->GetNode().Index());
|
||||
p_shape_node = graph.GetNode(edges[1]->GetNode().Index());
|
||||
std::vector<const Node::EdgeEnd*> pg_edges;
|
||||
bool isValidEmbedSubNode = true;
|
||||
if (graph_utils::IsConstantInitializer(graph, position_gather_node.MutableInputDefs()[1]->Name())) {
|
||||
// Check if the second input of position gather is a tensor with values evenly spaced by 1 starting from 0.
|
||||
std::vector<int64_t> data;
|
||||
auto expected_shape = word_gather_node.MutableInputDefs()[1]->Shape();
|
||||
if (!optimizer_utils::AppendTensorFromInitializer(graph, *(position_gather_node.MutableInputDefs()[1]), data)
|
||||
|| !utils::HasDimValue(expected_shape->dim()[0])
|
||||
|| !utils::HasDimValue(expected_shape->dim()[1])
|
||||
|| static_cast<int>(data.size()) != expected_shape->dim()[0].dim_value() * expected_shape->dim()[1].dim_value()) {
|
||||
continue;
|
||||
}
|
||||
int64_t expected_value = 0;
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
if (data[i] != expected_value) {
|
||||
isValidEmbedSubNode = false;
|
||||
break;
|
||||
}
|
||||
expected_value++;
|
||||
if (expected_value >= static_cast<int64_t>(expected_shape->dim()[1].dim_value())) {
|
||||
expected_value = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Match two paths.
|
||||
// Match Shape --> Expand path if needed.
|
||||
std::vector<NodeIndex> position_parent_nodes;
|
||||
std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
|
||||
{0, 1, "Expand", {8}, kOnnxDomain},
|
||||
{0, 1, "Shape", {1}, kOnnxDomain}};
|
||||
if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
if (edges[0]->GetNode().GetOutputEdgesCount() != 1 && edges[1]->GetNode().GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
p_expand_node = graph.GetNode(edges[0]->GetNode().Index());
|
||||
p_shape_node = graph.GetNode(edges[1]->GetNode().Index());
|
||||
// Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand
|
||||
Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index());
|
||||
Node& shape_node_1 = *graph.GetNode(edges[1]->GetNode().Index());
|
||||
std::vector<graph_utils::EdgeEndToMatch> pg_parent_path{
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain},
|
||||
};
|
||||
if (!graph_utils::FindPath(expand_node, true, pg_parent_path, pg_edges, logger)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 0; i < pg_edges.size(); i++) {
|
||||
if (pg_edges[i]->GetNode().GetOutputEdgesCount() != 1) {
|
||||
isValidEmbedSubNode = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check if the second input of the Gather node in the path has a constant input of 1
|
||||
Node& gather_node = *graph.GetNode(pg_edges[pg_edges.size() - 2]->GetNode().Index());
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(gather_node.InputDefs()[1]), int64_t(1), true)) {
|
||||
DEBUG_LOG("Second input of Gather should be a constant with value 1. ");
|
||||
|
||||
continue;
|
||||
}
|
||||
// Check if the two paths of position gather lead to the same input.
|
||||
Node& shape_node_2 = *graph.GetNode(pg_edges[pg_edges.size() - 1]->GetNode().Index());
|
||||
if (shape_node_1.MutableInputDefs()[0] != shape_node_2.MutableInputDefs()[0]) {
|
||||
continue;
|
||||
}
|
||||
// Check if the parent of "shape" is the parent of "word gather"
|
||||
if (shape_node_1.MutableInputDefs()[0] != word_gather_node.MutableInputDefs()[1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
}
|
||||
if (!isValidEmbedSubNode) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "input_ids" from node.
|
||||
NodeArg* input_ids = CheckInput(graph, word_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
|
||||
if (input_ids == nullptr) {
|
||||
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
|
||||
if (!CheckInput(input_ids, logger)) {
|
||||
DEBUG_LOG("Input id is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "segment_ids" from node.
|
||||
NodeArg* segment_ids = CheckInput(graph, segment_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
|
||||
if (segment_ids == nullptr) {
|
||||
NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
|
||||
if (!CheckInput(segment_ids, logger)) {
|
||||
DEBUG_LOG("Segment id is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get input "mask" from "ReduceSum" node.
|
||||
NodeArg* mask = CheckInput(graph, reduce_sum_node.MutableInputDefs()[0], layer_norm_node.GetExecutionProviderType(), logger);
|
||||
if (mask == nullptr) {
|
||||
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
|
||||
if (!CheckInput(mask, logger)) {
|
||||
DEBUG_LOG("Mask is not valid. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and segment id should have the same shape. ");
|
||||
continue;
|
||||
}
|
||||
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
|
||||
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
|
||||
DEBUG_LOG("Input_ids and mask should have the same shape. ");
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
|
||||
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
|
||||
if (gamma->Shape() == nullptr
|
||||
|| gamma->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) {
|
||||
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (beta->Shape() == nullptr
|
||||
|| beta->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) {
|
||||
DEBUG_LOG("Beta should be of shape (hidden_size). ");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Cast input_ids, segment_ids, and mask to int32 if needed.
|
||||
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
|
||||
segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
|
||||
mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
|
||||
|
||||
const std::vector<NodeArg*> embed_layer_norm_input_defs{
|
||||
input_ids,
|
||||
segment_ids,
|
||||
word_gather_node.MutableInputDefs()[0],
|
||||
position_gather_node.MutableInputDefs()[0],
|
||||
segment_gather_node.MutableInputDefs()[0],
|
||||
word_embedding,
|
||||
position_embedding,
|
||||
segment_embedding,
|
||||
layer_norm_node.MutableInputDefs()[1],
|
||||
layer_norm_node.MutableInputDefs()[2],
|
||||
mask};
|
||||
|
|
@ -222,31 +336,10 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
// move output definitions and output edges to embed_layer_norm_node.
|
||||
// remove all the other nodes.
|
||||
std::vector<NodeIndex> nodes_to_remove;
|
||||
for (size_t i = 0; i < pg_edges.size(); i++) {
|
||||
nodes_to_remove.push_back(pg_edges[i]->GetNode().Index());
|
||||
}
|
||||
if (p_shape_node != nullptr && p_expand_node != nullptr) {
|
||||
// Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand
|
||||
if (p_expand_node != nullptr) {
|
||||
Node& expand_node = *graph.GetNode(p_expand_node->Index());
|
||||
std::vector<graph_utils::EdgeEndToMatch> expand_parent_path{
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Cast", {9}, kOnnxDomain},
|
||||
{0, 0, "Squeeze", {1}, kOnnxDomain},
|
||||
{0, 0, "Transpose", {1}, kOnnxDomain},
|
||||
{0, 0, "NonZero", {9}, kOnnxDomain},
|
||||
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
|
||||
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Gather", {1, 11}, kOnnxDomain},
|
||||
{0, 0, "Shape", {1}, kOnnxDomain},
|
||||
};
|
||||
if (graph_utils::FindPath(expand_node, true, expand_parent_path, edges, logger)) {
|
||||
for (size_t i = 0; i < edges.size(); i++) {
|
||||
if (edges[i]->GetNode().GetOutputEdgesCount() != 1) {
|
||||
nodes_to_remove.clear();
|
||||
break;
|
||||
}
|
||||
nodes_to_remove.push_back(edges[i]->GetNode().Index());
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes_to_remove.push_back(p_shape_node->Index());
|
||||
nodes_to_remove.push_back(p_expand_node->Index());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1317,6 +1317,11 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) {
|
|||
ASSERT_TRUE(op_to_count["Shape"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Expand"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Gather"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
|
||||
ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0);
|
||||
ASSERT_TRUE(op_to_count["NonZero"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Transpose"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Squeeze"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["ReduceSum"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Attention"] == 1);
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue