EmbedLayerNormalization Fusion Improvement (#2553)

Embedding layer norm fusion improvements: add more checks before fusing. Validate that the word, position, and segment embedding tables have known, matching hidden sizes; validate input_ids, segment_ids, and mask shapes and types up front and cast to int32 only after all checks pass; verify that a constant position-ids initializer holds values evenly spaced by 1 starting from 0 in each row; and fully match the symbolic Shape -> Expand position path before queuing its nodes for removal.
liuziyue 2019-12-07 23:14:26 -08:00 committed by Tianlei Wu
parent 0f12346d76
commit 200f4b4ea6
4 changed files with 154 additions and 56 deletions

onnxruntime/core/optimizer/embed_layer_norm_fusion.cc

@@ -3,6 +3,8 @@
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/embed_layer_norm_fusion.h"
 #include "core/graph/graph_utils.h"
+#include "core/optimizer/utils.h"
+#include "core/framework/tensorprotoutils.h"
 #include "float.h"
 #define DEBUG_LOG(x) LOGS(logger, VERBOSE) << x
@@ -13,6 +15,10 @@ namespace onnxruntime {
 // Add a Cast to convert Input from int64 to int32.
 static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_type) {
+  auto data_type = input->TypeAsProto()->tensor_type().elem_type();
+  if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+    return input;
+  }
   const TensorShapeProto* input_shape = input->Shape();
   TypeProto input_int32;
   input_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
@@ -41,26 +47,22 @@ static NodeArg* CastToInt32(Graph& graph, NodeArg* input, ProviderType provider_
   return &cast32;
 }
-static NodeArg* CheckInput(Graph& graph, NodeArg* input, ProviderType provider_type, const logging::Logger& logger) {
+static bool CheckInput(NodeArg* input, const logging::Logger& logger) {
   // Validate input shape (batch_size, sequence_length) and data type.
   // Note that batch_size and sequence_length could be symbolic.
   const TensorShapeProto* input_shape = input->Shape();
   if (input_shape == nullptr || input_shape->dim_size() != 2 || input->Type() == nullptr) {
-    DEBUG_LOG("Mask shape is unknown or not 2D, or data type unknown");
-    return nullptr;
+    DEBUG_LOG("Input shape is unknown or not 2D, or data type unknown");
+    return false;
   }
   auto data_type = input->TypeAsProto()->tensor_type().elem_type();
   if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 &&
       data_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) {
     DEBUG_LOG("Input data type is not int32 or int64");
-    return nullptr;
+    return false;
   }
-  if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
-    return CastToInt32(graph, input, provider_type);
-  }
-  return input;
+  return true;
 }
 /**
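
Note on the CheckInput refactor above: the old helper both validated an input and mutated the graph (returning a freshly inserted Cast for int64 inputs), so a later bailout in the fusion could leave an unused Cast node behind. The new code separates the two concerns: validate every input first, then cast once all fusion preconditions hold. Below is a minimal standalone sketch of that flow, using simplified stand-in types (TensorInfo and ElemType are illustrative, not the real Graph/NodeArg API).

#include <cstdio>
#include <vector>

enum class ElemType { kInt32, kInt64, kFloat };

// Stand-in for a NodeArg's type/shape info; -1 marks a symbolic
// dimension such as batch_size or sequence_length.
struct TensorInfo {
  std::vector<int> dims;
  ElemType elem_type;
};

// Mirrors the new CheckInput: pure validation, no graph mutation.
static bool CheckInput(const TensorInfo& input) {
  if (input.dims.size() != 2) {
    return false;  // must be rank-2: (batch_size, sequence_length)
  }
  return input.elem_type == ElemType::kInt32 ||
         input.elem_type == ElemType::kInt64;
}

// Mirrors the updated CastToInt32: early-out when already int32.
static TensorInfo CastToInt32(const TensorInfo& input) {
  if (input.elem_type == ElemType::kInt32) {
    return input;  // the real code returns the NodeArg unchanged
  }
  TensorInfo out = input;
  out.elem_type = ElemType::kInt32;  // the real code inserts a Cast node here
  return out;
}

int main() {
  TensorInfo input_ids{{-1, 128}, ElemType::kInt64};
  // Validate first; cast only after every fusion precondition holds,
  // so an early bailout can no longer leave a stray Cast in the graph.
  if (CheckInput(input_ids)) {
    TensorInfo ids32 = CastToInt32(input_ids);
    std::printf("cast ok: %d\n", ids32.elem_type == ElemType::kInt32);
  }
  return 0;
}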
@@ -124,8 +126,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
         continue;
       }
       // The first input of segment_gather_node must be 2d.
-      auto sg_shape = segment_gather_node.MutableInputDefs()[0]->Shape();
-      if (sg_shape != nullptr && sg_shape->dim_size() != 2) {
+      NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
+      auto sg_shape = segment_embedding->Shape();
+      if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
+          !utils::HasDimValue(sg_shape->dim()[1]) ||
+          sg_shape->dim()[1].dim_value() <= 0) {
         continue;
       }
@@ -142,8 +147,11 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
         continue;
       }
       // The first input of word_gather_node must be 2d.
-      auto wg_shape = word_gather_node.MutableInputDefs()[0]->Shape();
-      if (wg_shape != nullptr && wg_shape->dim_size() != 2) {
+      NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
+      auto wg_shape = word_embedding->Shape();
+      if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
+          !utils::HasDimValue(wg_shape->dim()[1]) ||
+          wg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) {
         continue;
       }
@@ -160,51 +168,157 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
         continue;
       }
       // The first input of position_gather_node must be 2d.
-      auto pg_shape = position_gather_node.MutableInputDefs()[0]->Shape();
-      if (pg_shape != nullptr && pg_shape->dim_size() != 2) {
+      NodeArg* position_embedding = position_gather_node.MutableInputDefs()[0];
+      auto pg_shape = position_embedding->Shape();
+      if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
+          !utils::HasDimValue(pg_shape->dim()[1]) ||
+          pg_shape->dim()[1].dim_value() != sg_shape->dim()[1].dim_value()) {
         continue;
       }
-      // Match Shape --> Expand path if needed.
-      std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
-          {0, 1, "Expand", {8}, kOnnxDomain},
-          {0, 1, "Shape", {1}, kOnnxDomain}};
+      // Check the second input of position gather. If it's not initializer, check for two paths.
       Node* p_expand_node = nullptr;
       Node* p_shape_node = nullptr;
-      if (graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
-        if (edges[0]->GetNode().GetOutputEdgesCount() == 1 && edges[1]->GetNode().GetOutputEdgesCount() == 1) {
-          p_expand_node = graph.GetNode(edges[0]->GetNode().Index());
-          p_shape_node = graph.GetNode(edges[1]->GetNode().Index());
+      std::vector<const Node::EdgeEnd*> pg_edges;
+      bool isValidEmbedSubNode = true;
+      if (graph_utils::IsConstantInitializer(graph, position_gather_node.MutableInputDefs()[1]->Name())) {
+        // Check if the second input of position gather is a tensor with values evenly spaced by 1 starting from 0.
+        std::vector<int64_t> data;
+        auto expected_shape = word_gather_node.MutableInputDefs()[1]->Shape();
+        if (!optimizer_utils::AppendTensorFromInitializer(graph, *(position_gather_node.MutableInputDefs()[1]), data) ||
+            !utils::HasDimValue(expected_shape->dim()[0]) ||
+            !utils::HasDimValue(expected_shape->dim()[1]) ||
+            static_cast<int>(data.size()) != expected_shape->dim()[0].dim_value() * expected_shape->dim()[1].dim_value()) {
+          continue;
+        }
+        int64_t expected_value = 0;
+        for (size_t i = 0; i < data.size(); i++) {
+          if (data[i] != expected_value) {
+            isValidEmbedSubNode = false;
+            break;
+          }
+          expected_value++;
+          if (expected_value >= static_cast<int64_t>(expected_shape->dim()[1].dim_value())) {
+            expected_value = 0;
+          }
+        }
+      } else {
+        // Match two paths.
+        // Match Shape --> Expand path if needed.
+        std::vector<NodeIndex> position_parent_nodes;
+        std::vector<graph_utils::EdgeEndToMatch> position_embedding_path_symbolic{
+            {0, 1, "Expand", {8}, kOnnxDomain},
+            {0, 1, "Shape", {1}, kOnnxDomain}};
+        if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) {
+          continue;
+        }
+        if (edges[0]->GetNode().GetOutputEdgesCount() != 1 && edges[1]->GetNode().GetOutputEdgesCount() != 1) {
+          continue;
+        }
+        p_expand_node = graph.GetNode(edges[0]->GetNode().Index());
+        p_shape_node = graph.GetNode(edges[1]->GetNode().Index());
+        // Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand
+        Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index());
+        Node& shape_node_1 = *graph.GetNode(edges[1]->GetNode().Index());
+        std::vector<graph_utils::EdgeEndToMatch> pg_parent_path{
+            {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
+            {0, 0, "Cast", {9}, kOnnxDomain},
+            {0, 0, "Squeeze", {1}, kOnnxDomain},
+            {0, 0, "Transpose", {1}, kOnnxDomain},
+            {0, 0, "NonZero", {9}, kOnnxDomain},
+            {0, 0, "ConstantOfShape", {9}, kOnnxDomain},
+            {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
+            {0, 0, "Gather", {1, 11}, kOnnxDomain},
+            {0, 0, "Shape", {1}, kOnnxDomain},
+        };
+        if (!graph_utils::FindPath(expand_node, true, pg_parent_path, pg_edges, logger)) {
+          continue;
+        }
+        for (size_t i = 0; i < pg_edges.size(); i++) {
+          if (pg_edges[i]->GetNode().GetOutputEdgesCount() != 1) {
+            isValidEmbedSubNode = false;
+            break;
+          }
+        }
+        // Check if the second input of the Gather node in the path has a constant input of 1.
+        Node& gather_node = *graph.GetNode(pg_edges[pg_edges.size() - 2]->GetNode().Index());
+        if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(gather_node.InputDefs()[1]), int64_t(1), true)) {
+          DEBUG_LOG("Second input of Gather should be a constant with value 1. ");
+          continue;
+        }
+        // Check if the two paths of position gather lead to the same input.
+        Node& shape_node_2 = *graph.GetNode(pg_edges[pg_edges.size() - 1]->GetNode().Index());
+        if (shape_node_1.MutableInputDefs()[0] != shape_node_2.MutableInputDefs()[0]) {
+          continue;
+        }
+        // Check if the parent of "shape" is the parent of "word gather".
+        if (shape_node_1.MutableInputDefs()[0] != word_gather_node.MutableInputDefs()[1]) {
+          continue;
+        }
       }
+      if (!isValidEmbedSubNode) {
+        continue;
+      }
       // Get input "input_ids" from node.
-      NodeArg* input_ids = CheckInput(graph, word_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
-      if (input_ids == nullptr) {
+      NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
+      if (!CheckInput(input_ids, logger)) {
         DEBUG_LOG("Input id is not valid. ");
         continue;
       }
       // Get input "segment_ids" from node.
-      NodeArg* segment_ids = CheckInput(graph, segment_gather_node.MutableInputDefs()[1], layer_norm_node.GetExecutionProviderType(), logger);
-      if (segment_ids == nullptr) {
+      NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
+      if (!CheckInput(segment_ids, logger)) {
         DEBUG_LOG("Segment id is not valid. ");
         continue;
       }
       // Get input "mask" from "ReduceSum" node.
-      NodeArg* mask = CheckInput(graph, reduce_sum_node.MutableInputDefs()[0], layer_norm_node.GetExecutionProviderType(), logger);
-      if (mask == nullptr) {
+      NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
+      if (!CheckInput(mask, logger)) {
         DEBUG_LOG("Mask is not valid. ");
         continue;
       }
+      if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
+          utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
+        DEBUG_LOG("Input_ids and segment id should have the same shape. ");
+        continue;
+      }
+      if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
+          utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
+        DEBUG_LOG("Input_ids and mask should have the same shape. ");
+        continue;
+      }
+      NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
+      NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
+      if (gamma->Shape() == nullptr ||
+          gamma->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) {
+        DEBUG_LOG("Gamma should be of shape (hidden_size). ");
+        continue;
+      }
+      if (beta->Shape() == nullptr ||
+          beta->Shape()->dim()[0].dim_value() != word_embedding->Shape()->dim()[1].dim_value()) {
+        DEBUG_LOG("Beta should be of shape (hidden_size). ");
+        continue;
+      }
+      // Cast input_ids, segment_ids, and mask to int32 if needed.
+      input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
+      segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
+      mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
       const std::vector<NodeArg*> embed_layer_norm_input_defs{
           input_ids,
           segment_ids,
-          word_gather_node.MutableInputDefs()[0],
-          position_gather_node.MutableInputDefs()[0],
-          segment_gather_node.MutableInputDefs()[0],
+          word_embedding,
+          position_embedding,
+          segment_embedding,
           layer_norm_node.MutableInputDefs()[1],
           layer_norm_node.MutableInputDefs()[2],
           mask};
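
The initializer branch above accepts a constant position-ids tensor only when its flattened (batch_size, sequence_length) data counts 0, 1, ..., seq_len-1 in every row. Below is a compact sketch of the same test detached from the Graph APIs (IsRowWiseIota is a hypothetical name used only for illustration).

#include <cassert>
#include <cstdint>
#include <vector>

// Returns true when data, viewed as (batch, seq_len), is 0,1,...,seq_len-1
// in every row -- the only constant position-ids layout the fusion accepts.
static bool IsRowWiseIota(const std::vector<int64_t>& data, int64_t seq_len) {
  if (seq_len <= 0 || data.size() % static_cast<size_t>(seq_len) != 0) {
    return false;
  }
  int64_t expected = 0;
  for (int64_t v : data) {
    if (v != expected) {
      return false;
    }
    if (++expected == seq_len) {
      expected = 0;  // reset the counter at each new row
    }
  }
  return true;
}

int main() {
  assert(IsRowWiseIota({0, 1, 2, 0, 1, 2}, 3));   // batch=2, seq_len=3
  assert(!IsRowWiseIota({0, 1, 2, 1, 2, 3}, 3));  // second row does not restart at 0
  return 0;
}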
@@ -222,31 +336,10 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       // move output definitions and output edges to embed_layer_norm_node.
       // remove all the other nodes.
       std::vector<NodeIndex> nodes_to_remove;
+      for (size_t i = 0; i < pg_edges.size(); i++) {
+        nodes_to_remove.push_back(pg_edges[i]->GetNode().Index());
+      }
       if (p_shape_node != nullptr && p_expand_node != nullptr) {
-        // Match Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand
-        if (p_expand_node != nullptr) {
-          Node& expand_node = *graph.GetNode(p_expand_node->Index());
-          std::vector<graph_utils::EdgeEndToMatch> expand_parent_path{
-              {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
-              {0, 0, "Cast", {9}, kOnnxDomain},
-              {0, 0, "Squeeze", {1}, kOnnxDomain},
-              {0, 0, "Transpose", {1}, kOnnxDomain},
-              {0, 0, "NonZero", {9}, kOnnxDomain},
-              {0, 0, "ConstantOfShape", {9}, kOnnxDomain},
-              {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
-              {0, 0, "Gather", {1, 11}, kOnnxDomain},
-              {0, 0, "Shape", {1}, kOnnxDomain},
-          };
-          if (graph_utils::FindPath(expand_node, true, expand_parent_path, edges, logger)) {
-            for (size_t i = 0; i < edges.size(); i++) {
-              if (edges[i]->GetNode().GetOutputEdgesCount() != 1) {
-                nodes_to_remove.clear();
-                break;
-              }
-              nodes_to_remove.push_back(edges[i]->GetNode().Index());
-            }
-          }
-        }
         nodes_to_remove.push_back(p_shape_node->Index());
         nodes_to_remove.push_back(p_expand_node->Index());
       }
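
Taken together, the new shape checks pin the whole pattern to a single hidden size: the segment table's dim[1] must be a known positive value, the word and position tables must match it, and the LayerNorm gamma/beta must be 1-D of the same length. Below is a standalone sketch of that rule under simplified assumptions (Shape and HiddenSizesMatch are illustrative stand-ins, not the real TensorShapeProto API).

#include <cassert>
#include <vector>

// Stand-in for a tensor shape; -1 marks a symbolic or unknown dimension.
struct Shape {
  std::vector<long long> dims;
};

// All three embedding tables are (rows, hidden) and must agree on hidden;
// the LayerNorm scale and bias must be 1-D vectors of that same length.
static bool HiddenSizesMatch(const Shape& word, const Shape& position,
                             const Shape& segment, const Shape& gamma,
                             const Shape& beta) {
  auto hidden_of = [](const Shape& s) -> long long {
    return (s.dims.size() == 2 && s.dims[1] > 0) ? s.dims[1] : -1;
  };
  const long long hidden = hidden_of(segment);
  if (hidden < 0) {
    return false;  // hidden size must be a known, positive value
  }
  if (hidden_of(word) != hidden || hidden_of(position) != hidden) {
    return false;
  }
  return gamma.dims.size() == 1 && gamma.dims[0] == hidden &&
         beta.dims.size() == 1 && beta.dims[0] == hidden;
}

int main() {
  // BERT-base style sizes: vocab 30522, max position 512, 2 segments.
  Shape word{{30522, 768}}, position{{512, 768}}, segment{{2, 768}};
  Shape gamma{{768}}, beta{{768}};
  assert(HiddenSizesMatch(word, position, segment, gamma, beta));
  return 0;
}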

onnxruntime/test/optimizer/graph_transform_test.cc

@@ -1317,6 +1317,11 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) {
   ASSERT_TRUE(op_to_count["Shape"] == 0);
   ASSERT_TRUE(op_to_count["Expand"] == 0);
   ASSERT_TRUE(op_to_count["Gather"] == 0);
+  ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
+  ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0);
+  ASSERT_TRUE(op_to_count["NonZero"] == 0);
+  ASSERT_TRUE(op_to_count["Transpose"] == 0);
+  ASSERT_TRUE(op_to_count["Squeeze"] == 0);
   ASSERT_TRUE(op_to_count["Add"] == 0);
   ASSERT_TRUE(op_to_count["ReduceSum"] == 0);
   ASSERT_TRUE(op_to_count["Attention"] == 1);
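
The added assertions confirm that fusion consumes every node of the symbolic position path (Shape -> Gather -> Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> Cast -> Unsqueeze -> Expand) rather than leaving stragglers behind. A sketch of the op-counting idiom behind op_to_count, assuming a simple tally over node op types (CountOps here is a stand-in for the test suite's helper):

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Tally nodes by op type; a stand-in for the test suite's op-count helper.
static std::map<std::string, int> CountOps(const std::vector<std::string>& op_types) {
  std::map<std::string, int> op_to_count;
  for (const auto& op : op_types) {
    ++op_to_count[op];
  }
  return op_to_count;
}

int main() {
  // A fully fused graph keeps only the fused ops; every node of the
  // Shape->...->Expand position path must have been removed.
  auto op_to_count = CountOps({"EmbedLayerNormalization", "Attention"});
  assert(op_to_count["NonZero"] == 0);
  assert(op_to_count["Expand"] == 0);
  assert(op_to_count["Attention"] == 1);
  return 0;
}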