Support EmbedLayerNorm fusion for DistilBert (#4928)

* checkin embedlayernorm fusion for distilbert

* move function from optimizer_utils

* review comments
This commit is contained in:
Ye Wang 2020-08-26 21:46:31 -07:00 committed by GitHub
parent 00fe718264
commit 792ed44537
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 475 additions and 228 deletions

View file

@ -77,6 +77,16 @@ static void AddNodes(std::vector<NodeIndex>& node_indices,
}
}
static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector<std::string>& expected_types) {
for (const std::string& expected_type : expected_types) {
if (start == end || (*start).OpType().compare(expected_type) != 0) {
return false;
}
++start;
}
return start == end;
}
/** Match subgraph like the following:
(input_ids)
/ \
@ -184,7 +194,7 @@ static bool MatchInputToConcatSubgraph(
}
/** Match subgraph like the following:
*
*
* Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^
* / | +-----------------------+
* / v | |
@ -196,7 +206,7 @@ static bool MatchInputToConcatSubgraph(
* # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze #
* # or #
* # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze #
*
*
* Note that position gather node is the node in the bottom of above sub-graph.
* Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found.
* Path in ** is an alternative path to check.
@ -213,43 +223,43 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
// Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze
// --> Cast --> Unsqueeze --> Expand --> Gather
std::vector<graph_utils::EdgeEndToMatch> parent_path_1{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Cast", {9}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 1, "Expand", {8, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Cast", {9, 13}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
{0, 0, "NonZero", {9, 13}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
// Look for Path 2 (Path 1 with no cast):
std::vector<graph_utils::EdgeEndToMatch> parent_path_2{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11}, kOnnxDomain},
{0, 0, "Transpose", {1}, kOnnxDomain},
{0, 0, "NonZero", {9}, kOnnxDomain},
{0, 1, "Expand", {8, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
{0, 0, "NonZero", {9, 13}, kOnnxDomain},
{0, 0, "ConstantOfShape", {9}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
// Path 3 Pattern:
// Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand
std::vector<graph_utils::EdgeEndToMatch> parent_path_3{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 1, "Expand", {8, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Range", {1, 11}, kOnnxDomain},
{0, 1, "Cast", {9}, kOnnxDomain},
{0, 0, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 1, "Cast", {9, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
// Path 4 pattern (Path 3 with no "Cast"):
std::vector<graph_utils::EdgeEndToMatch> parent_path_4{
{0, 1, "Expand", {8}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
{0, 1, "Expand", {8, 13}, kOnnxDomain},
{0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain},
{0, 0, "Range", {1, 11}, kOnnxDomain},
{0, 1, "Gather", {1, 11}, kOnnxDomain},
{0, 0, "Shape", {1}, kOnnxDomain}};
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain},
{0, 0, "Shape", {1, 13}, kOnnxDomain}};
// Match one of the three path patterns.
if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) &&
!graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) &&
@ -317,8 +327,8 @@ static bool MatchPositionEmbeddingSubgraphsFromGather(
std::vector<const Node::EdgeEnd*> pg_edges_2;
std::vector<graph_utils::EdgeEndToMatch> path_to_match_1{
{0, 1, "Where", {9}, kOnnxDomain},
{0, 0, "Equal", {1, 11}, kOnnxDomain},
{0, 0, "Reshape", {5}, kOnnxDomain}};
{0, 0, "Equal", {1, 7, 11, 13}, kOnnxDomain},
{0, 0, "Reshape", {5, 13}, kOnnxDomain}};
if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) {
if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) ||
!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) ||
@ -361,7 +371,7 @@ static bool MatchPositionEmbeddingSubgraph(
// Constant folding removes Shape and Expand nodes when input has static shape.
// In that case just look for Gather --> Add.
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11}, kOnnxDomain}}, edges, logger)) {
if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}, edges, logger)) {
return false;
}
Node& position_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
@ -461,6 +471,339 @@ static NodeArg* ExtractEmbedding(Graph& graph,
return &node_arg;
}
static void CreateEmbedLayernormNode(Graph& graph,
NodeArg* input_ids,
NodeArg* segment_ids,
NodeArg* word_embedding,
NodeArg* position_embedding,
NodeArg* segment_embedding,
NodeArg* mask,
Node& layer_norm_node,
Node& reduce_sum_node) {
// Cast input_ids, segment_ids, and mask to int32 if needed.
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
if (segment_ids != nullptr && segment_embedding != nullptr) {
segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
}
mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
NodeArg place_holder("", nullptr);
if (segment_ids == nullptr && segment_embedding == nullptr) {
segment_ids = &place_holder;
segment_embedding = &place_holder;
}
const std::vector<NodeArg*> embed_layer_norm_input_defs{
input_ids,
segment_ids,
word_embedding,
position_embedding,
segment_embedding,
layer_norm_node.MutableInputDefs()[1],
layer_norm_node.MutableInputDefs()[2],
mask};
Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"),
"EmbedLayerNormalization",
"fused EmbedLayerNorm subgraphs ",
embed_layer_norm_input_defs,
{layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]},
{}, kMSDomain);
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
// will be used.
NodeAttributes ln_attrs = layer_norm_node.GetAttributes();
NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon");
if (epsilon != ln_attrs.end()) {
embed_layer_norm_node.AddAttribute("epsilon", epsilon->second);
} else {
embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon);
}
// Assign provider to this new node. Provider should be same as the provider for old node.
embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType());
}
static bool FuseSubGraph(Graph& graph,
Node& layer_norm_add_node,
Node& layer_norm_node,
Node& reduce_sum_node,
bool& modified,
const logging::Logger& logger) {
// Trace back to find the Gather for segment embedding.
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
return false;
}
Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) {
return false;
}
// The first input of segment_gather_node must be 2d.
NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
auto sg_shape = segment_embedding->Shape();
if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
!utils::HasDimValue(sg_shape->dim()[1]) ||
sg_shape->dim()[1].dim_value() <= 0) {
return false;
}
auto hidden_size = sg_shape->dim()[1].dim_value();
// Trace back to find Gather --> Add --> LayerNormalization
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Add", {7, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
return false;
}
Node& add_node = *graph.GetNode(edges[0]->GetNode().Index());
Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index());
if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
return false;
}
// The first input of word_gather_node must be 2d.
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
auto wg_shape = word_embedding->Shape();
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
!utils::HasDimValue(wg_shape->dim()[1]) ||
wg_shape->dim()[1].dim_value() != hidden_size) {
DEBUG_LOG("Word embedding shape not expected.");
return false;
}
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
NodeArg* position_embedding = nullptr;
std::vector<NodeIndex> nodes_to_remove;
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
// Here we handle such special case that the input of add node is constant initializer.
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
if (graph_utils::IsConstantInitializer(graph, add_input_name)) {
// Check that input has static shape.
auto input_shape = input_ids->Shape();
if (input_shape->dim_size() != 2 ||
!utils::HasDimValue(input_shape->dim()[0]) ||
!utils::HasDimValue(input_shape->dim()[1])) {
DEBUG_LOG("Input is expected to have dim value in all dimensions.");
return false;
}
int64_t batch_size = input_shape->dim()[0].dim_value();
int64_t sequence_length = input_shape->dim()[1].dim_value();
if (batch_size <= 0 || sequence_length <= 0) {
return false;
}
const ONNX_NAMESPACE::TensorProto* position_embed_tensor;
if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) {
DEBUG_LOG("Failed to get initializer tensor.");
return false;
}
// Tensor shape shall be [batch_size, sequence_length, hidden_size].
if (position_embed_tensor->dims_size() != 3 ||
position_embed_tensor->dims(0) != batch_size ||
position_embed_tensor->dims(1) != sequence_length ||
position_embed_tensor->dims(2) != hidden_size) {
DEBUG_LOG("Position embedding shape not matched.");
return false;
}
// Tensor data type should be float or float16.
const auto data_type = position_embed_tensor->data_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
DEBUG_LOG("Position embedding data type shall be float or float16.");
return false;
}
// The tensor has same data for all batches, and we extract only one batch data as position embedding.
position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified);
} else {
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
DEBUG_LOG("Failed to match position embedding subgraph.");
return false;
}
}
if (position_embedding == nullptr) {
DEBUG_LOG("Failed to get position embedding weights.");
return false;
}
auto pg_shape = position_embedding->Shape();
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
!utils::HasDimValue(pg_shape->dim()[1]) ||
pg_shape->dim()[1].dim_value() != hidden_size) {
DEBUG_LOG("Position embedding shape is not expected.");
return false;
}
// Get input "input_ids" from node.
if (!CheckInput(input_ids, logger)) {
DEBUG_LOG("Input id is not valid. ");
return false;
}
// Get input "segment_ids" from node.
NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
if (!CheckInput(segment_ids, logger)) {
DEBUG_LOG("Segment id is not valid. ");
return false;
}
// Get input "mask" from "ReduceSum" node.
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
if (!CheckInput(mask, logger)) {
DEBUG_LOG("Mask is not valid. ");
return false;
}
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
DEBUG_LOG("Input_ids and segment id should have the same shape. ");
return false;
}
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
DEBUG_LOG("Input_ids and mask should have the same shape. ");
return false;
}
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
return false;
}
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Beta should be of shape (hidden_size). ");
return false;
}
CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding,
mask, layer_norm_node, reduce_sum_node);
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(segment_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
nodes_to_remove.push_back(reduce_sum_node.Index());
nodes_to_remove.push_back(layer_norm_add_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
return true;
}
// DistilBert's pattern does not have segment embedding
static bool FuseSubGraphDistilBert(Graph& graph,
Node& layer_norm_add_node,
Node& layer_norm_node,
Node& reduce_sum_node,
const logging::Logger& logger) {
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
std::vector<const Node::EdgeEnd*> edges;
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
return false;
}
Node& word_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
if (!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
return false;
}
// The first input of word_gather_node must be 2d.
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
auto wg_shape = word_embedding->Shape();
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
!utils::HasDimValue(wg_shape->dim()[1])) {
DEBUG_LOG("Word embedding shape not expected.");
return false;
}
int64_t hidden_size = wg_shape->dim()[1].dim_value();
Node& add_node = layer_norm_add_node;
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
NodeArg* position_embedding = nullptr;
std::vector<NodeIndex> nodes_to_remove;
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
// Here we handle such special case that the input of add node is constant initializer.
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
DEBUG_LOG("Failed to match position embedding subgraph.");
return false;
}
if (position_embedding == nullptr) {
DEBUG_LOG("Failed to get position embedding weights.");
return false;
}
auto pg_shape = position_embedding->Shape();
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
!utils::HasDimValue(pg_shape->dim()[1]) ||
pg_shape->dim()[1].dim_value() != hidden_size) {
DEBUG_LOG("Position embedding shape is not expected.");
return false;
}
// Get input "input_ids" from node.
if (!CheckInput(input_ids, logger)) {
DEBUG_LOG("Input id is not valid. ");
return false;
}
// Get input "mask" from "ReduceSum" node.
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
if (!CheckInput(mask, logger)) {
DEBUG_LOG("Mask is not valid. ");
return false;
}
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
DEBUG_LOG("Input_ids and mask should have the same shape. ");
return false;
}
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
return false;
}
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Beta should be of shape (hidden_size). ");
return false;
}
CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr,
mask, layer_norm_node, reduce_sum_node);
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
nodes_to_remove.push_back(reduce_sum_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
return true;
}
/**
Embed Layer Normalization will fuse embeddings and mask processing into one node :
The embeddings before conversion:
@ -512,208 +855,17 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
}
Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index());
// Trace back to find the Gather for segment embedding.
std::vector<graph_utils::EdgeEndToMatch> segment_embedding_path{
{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) {
continue;
}
Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index());
if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) {
continue;
}
// The first input of segment_gather_node must be 2d.
NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0];
auto sg_shape = segment_embedding->Shape();
if (sg_shape == nullptr || sg_shape->dim_size() != 2 ||
!utils::HasDimValue(sg_shape->dim()[1]) ||
sg_shape->dim()[1].dim_value() <= 0) {
continue;
}
auto hidden_size = sg_shape->dim()[1].dim_value();
// Trace back to find Gather --> Add --> LayerNormalization
std::vector<graph_utils::EdgeEndToMatch> word_embedding_path{
{0, 0, "Add", {7, 13}, kOnnxDomain},
{0, 0, "Gather", {1, 11, 13}, kOnnxDomain}};
if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) {
continue;
}
Node& add_node = *graph.GetNode(edges[0]->GetNode().Index());
Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index());
if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) {
continue;
}
// The first input of word_gather_node must be 2d.
NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0];
auto wg_shape = word_embedding->Shape();
if (wg_shape == nullptr || wg_shape->dim_size() != 2 ||
!utils::HasDimValue(wg_shape->dim()[1]) ||
wg_shape->dim()[1].dim_value() != hidden_size) {
DEBUG_LOG("Word embedding shape not expected.");
continue;
}
NodeArg* input_ids = word_gather_node.MutableInputDefs()[1];
NodeArg* position_embedding = nullptr;
std::vector<NodeIndex> nodes_to_remove;
// ORT constant folding might be applied to position embedding subgraph when input has static shape.
// Here we handle such special case that the input of add node is constant initializer.
auto add_input_name = add_node.MutableInputDefs()[1]->Name();
if (graph_utils::IsConstantInitializer(graph, add_input_name)) {
// Check that input has static shape.
auto input_shape = input_ids->Shape();
if (input_shape->dim_size() != 2 ||
!utils::HasDimValue(input_shape->dim()[0]) ||
!utils::HasDimValue(input_shape->dim()[1])) {
DEBUG_LOG("Input is expected to have dim value in all dimensions.");
continue;
if (IsNeighborNodeExpectedTypes(layer_norm_add_node.InputEdgesBegin(), layer_norm_add_node.InputNodesEnd(), {"Gather", "Gather"})) {
//DistilBert
if (FuseSubGraphDistilBert(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, logger)) {
modified = true;
}
int64_t batch_size = input_shape->dim()[0].dim_value();
int64_t sequence_length = input_shape->dim()[1].dim_value();
if (batch_size <= 0 || sequence_length <= 0) {
continue;
}
const ONNX_NAMESPACE::TensorProto* position_embed_tensor;
if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) {
DEBUG_LOG("Failed to get initializer tensor.");
continue;
}
// Tensor shape shall be [batch_size, sequence_length, hidden_size].
if (position_embed_tensor->dims_size() != 3 ||
position_embed_tensor->dims(0) != batch_size ||
position_embed_tensor->dims(1) != sequence_length ||
position_embed_tensor->dims(2) != hidden_size) {
DEBUG_LOG("Position embedding shape not matched.");
continue;
}
// Tensor data type should be float or float16.
const auto data_type = position_embed_tensor->data_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
DEBUG_LOG("Position embedding data type shall be float or float16.");
continue;
}
// The tensor has same data for all batches, and we extract only one batch data as position embedding.
position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified);
} else {
if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) {
DEBUG_LOG("Failed to match position embedding subgraph.");
continue;
if (FuseSubGraph(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, modified, logger)) {
modified = true;
}
}
if (position_embedding == nullptr) {
DEBUG_LOG("Failed to get position embedding weights.");
continue;
}
auto pg_shape = position_embedding->Shape();
if (pg_shape == nullptr || pg_shape->dim_size() != 2 ||
!utils::HasDimValue(pg_shape->dim()[1]) ||
pg_shape->dim()[1].dim_value() != hidden_size) {
DEBUG_LOG("Position embedding shape is not expected.");
continue;
}
// Get input "input_ids" from node.
if (!CheckInput(input_ids, logger)) {
DEBUG_LOG("Input id is not valid. ");
continue;
}
// Get input "segment_ids" from node.
NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1];
if (!CheckInput(segment_ids, logger)) {
DEBUG_LOG("Segment id is not valid. ");
continue;
}
// Get input "mask" from "ReduceSum" node.
NodeArg* mask = reduce_sum_node.MutableInputDefs()[0];
if (!CheckInput(mask, logger)) {
DEBUG_LOG("Mask is not valid. ");
continue;
}
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) {
DEBUG_LOG("Input_ids and segment id should have the same shape. ");
continue;
}
if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) !=
utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) {
DEBUG_LOG("Input_ids and mask should have the same shape. ");
continue;
}
NodeArg* gamma = layer_norm_node.MutableInputDefs()[1];
NodeArg* beta = layer_norm_node.MutableInputDefs()[2];
if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Gamma should be of shape (hidden_size). ");
continue;
}
if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) {
DEBUG_LOG("Beta should be of shape (hidden_size). ");
continue;
}
// Cast input_ids, segment_ids, and mask to int32 if needed.
input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType());
segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType());
mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType());
const std::vector<NodeArg*> embed_layer_norm_input_defs{
input_ids,
segment_ids,
word_embedding,
position_embedding,
segment_embedding,
layer_norm_node.MutableInputDefs()[1],
layer_norm_node.MutableInputDefs()[2],
mask};
Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"),
"EmbedLayerNormalization",
"fused EmbedLayerNorm subgraphs ",
embed_layer_norm_input_defs,
{layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]},
{}, kMSDomain);
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
// will be used.
NodeAttributes ln_attrs = layer_norm_node.GetAttributes();
NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon");
if (epsilon != ln_attrs.end()) {
embed_layer_norm_node.AddAttribute("epsilon", epsilon->second);
} else {
embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon);
}
// Assign provider to this new node. Provider should be same as the provider for old node.
embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType());
nodes_to_remove.push_back(word_gather_node.Index());
nodes_to_remove.push_back(segment_gather_node.Index());
nodes_to_remove.push_back(add_node.Index());
nodes_to_remove.push_back(reduce_sum_node.Index());
nodes_to_remove.push_back(layer_norm_add_node.Index());
nodes_to_remove.push_back(layer_norm_node.Index());
for (const auto& index : nodes_to_remove) {
Node* node = graph.GetNode(index);
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
}
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -2611,6 +2611,28 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) {
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
}
//DistilBert
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<EmbedLayerNormFusion>(), TransformerLevel::Level2);
auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
ASSERT_TRUE(ret1.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1);
EXPECT_EQ(op_to_count["Attention"], 1);
EXPECT_EQ(op_to_count["Cast"], 2);
EXPECT_EQ(op_to_count["Shape"], 0);
EXPECT_EQ(op_to_count["Gather"], 0);
EXPECT_EQ(op_to_count["Unsqueeze"], 0);
EXPECT_EQ(op_to_count["ReduceSum"], 0);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx";
std::shared_ptr<Model> p_model;

View file

@ -278,9 +278,82 @@ def GenerateModel6(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel7(model_name):
batch_size = 2
hidden_size = 4
attention_heads = 2
sequence_length = 3
nodes = [
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"),
helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"),
helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"),
helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"),
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0),
helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"),
helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
"layernorm",
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
"att",
domain="com.microsoft",
num_heads=attention_heads),
helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"),
helper.make_node("Add", ["matmul_out", "add_bias"], ["add2_out"], "add2"),
helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3")
]
qkv_weights = [
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
]
initializers = [ # initializers
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
helper.make_tensor('start', TensorProto.INT64, [], [0]),
helper.make_tensor('delta', TensorProto.INT64, [], [1]),
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size],
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
]
graph = helper.make_graph(
nodes,
"EmbedLayerNorm_format7", #name
[ # inputs
helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]),
helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]),
],
[ # outputs
helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
],
initializers)
model = helper.make_model(graph)
onnx.save(model, model_name)
GenerateModel3('embed_layer_norm_format3.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
GenerateModel5('embed_layer_norm_format5.onnx')
GenerateModel6('embed_layer_norm_format6.onnx')
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')