diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 079caf405e..f6751fa8cb 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -77,6 +77,16 @@ static void AddNodes(std::vector& node_indices, } } +static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector& expected_types) { + for (const std::string& expected_type : expected_types) { + if (start == end || (*start).OpType().compare(expected_type) != 0) { + return false; + } + ++start; + } + return start == end; +} + /** Match subgraph like the following: (input_ids) / \ @@ -184,7 +194,7 @@ static bool MatchInputToConcatSubgraph( } /** Match subgraph like the following: - * + * * Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^ * / | +-----------------------+ * / v | | @@ -196,7 +206,7 @@ static bool MatchInputToConcatSubgraph( * # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze # * # or # * # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze # - * + * * Note that position gather node is the node in the bottom of above sub-graph. * Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. * Path in ** is an alternative path to check. 
@@ -213,43 +223,43 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze // --> Cast --> Unsqueeze --> Expand --> Gather std::vector parent_path_1{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Cast", {9}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, - {0, 0, "Transpose", {1}, kOnnxDomain}, - {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 1, "Expand", {8, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Cast", {9, 13}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "NonZero", {9, 13}, kOnnxDomain}, {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; // Look for Path 2 (Path 1 with no cast): std::vector parent_path_2{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, - {0, 0, "Transpose", {1}, kOnnxDomain}, - {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 1, "Expand", {8, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "NonZero", {9, 13}, kOnnxDomain}, {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; // Path 3 Pattern: // Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand std::vector parent_path_3{ - 
{0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 1, "Expand", {8, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, {0, 0, "Range", {1, 11}, kOnnxDomain}, - {0, 1, "Cast", {9}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, 1, "Cast", {9, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; // Path 4 pattern (Path 3 with no "Cast"): std::vector parent_path_4{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 1, "Expand", {8, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, {0, 0, "Range", {1, 11}, kOnnxDomain}, - {0, 1, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, 1, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; // Match one of the three path patterns. if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) && !graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) && @@ -317,8 +327,8 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( std::vector pg_edges_2; std::vector path_to_match_1{ {0, 1, "Where", {9}, kOnnxDomain}, - {0, 0, "Equal", {1, 11}, kOnnxDomain}, - {0, 0, "Reshape", {5}, kOnnxDomain}}; + {0, 0, "Equal", {1, 7, 11, 13}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13}, kOnnxDomain}}; if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) || !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) || @@ -361,7 +371,7 @@ static bool MatchPositionEmbeddingSubgraph( // Constant folding removes Shape and Expand nodes when input has static shape. // In that case just look for Gather --> Add. 
std::vector edges; - if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11}, kOnnxDomain}}, edges, logger)) { + if (!graph_utils::FindPath(add_node, true, {{0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}, edges, logger)) { return false; } Node& position_gather_node = *graph.GetNode(edges[0]->GetNode().Index()); @@ -461,6 +471,339 @@ static NodeArg* ExtractEmbedding(Graph& graph, return &node_arg; } +static void CreateEmbedLayernormNode(Graph& graph, + NodeArg* input_ids, + NodeArg* segment_ids, + NodeArg* word_embedding, + NodeArg* position_embedding, + NodeArg* segment_embedding, + NodeArg* mask, + Node& layer_norm_node, + Node& reduce_sum_node) { + // Cast input_ids, segment_ids, and mask to int32 if needed. + input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType()); + if (segment_ids != nullptr && segment_embedding != nullptr) { + segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType()); + } + mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType()); + + NodeArg place_holder("", nullptr); + if (segment_ids == nullptr && segment_embedding == nullptr) { + segment_ids = &place_holder; + segment_embedding = &place_holder; + } + + const std::vector embed_layer_norm_input_defs{ + input_ids, + segment_ids, + word_embedding, + position_embedding, + segment_embedding, + layer_norm_node.MutableInputDefs()[1], + layer_norm_node.MutableInputDefs()[2], + mask}; + + Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"), + "EmbedLayerNormalization", + "fused EmbedLayerNorm subgraphs ", + embed_layer_norm_input_defs, + {layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]}, + {}, kMSDomain); + + // Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value + // will be used. 
+ NodeAttributes ln_attrs = layer_norm_node.GetAttributes(); + NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon"); + if (epsilon != ln_attrs.end()) { + embed_layer_norm_node.AddAttribute("epsilon", epsilon->second); + } else { + embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon); + } + + // Assign provider to this new node. Provider should be same as the provider for old node. + embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType()); +} + +static bool FuseSubGraph(Graph& graph, + Node& layer_norm_add_node, + Node& layer_norm_node, + Node& reduce_sum_node, + bool& modified, + const logging::Logger& logger) { + // Trace back to find the Gather for segment embedding. + std::vector segment_embedding_path{ + {0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}; + std::vector edges; + if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) { + return false; + } + Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index()); + if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) { + return false; + } + // The first input of segment_gather_node must be 2d. 
+ NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0]; + auto sg_shape = segment_embedding->Shape(); + if (sg_shape == nullptr || sg_shape->dim_size() != 2 || + !utils::HasDimValue(sg_shape->dim()[1]) || + sg_shape->dim()[1].dim_value() <= 0) { + return false; + } + auto hidden_size = sg_shape->dim()[1].dim_value(); + + // Trace back to find Gather --> Add --> LayerNormalization + std::vector word_embedding_path{ + {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; + if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) { + return false; + } + Node& add_node = *graph.GetNode(edges[0]->GetNode().Index()); + Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index()); + if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) || + !optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) { + return false; + } + // The first input of word_gather_node must be 2d. + NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0]; + auto wg_shape = word_embedding->Shape(); + if (wg_shape == nullptr || wg_shape->dim_size() != 2 || + !utils::HasDimValue(wg_shape->dim()[1]) || + wg_shape->dim()[1].dim_value() != hidden_size) { + DEBUG_LOG("Word embedding shape not expected."); + return false; + } + + NodeArg* input_ids = word_gather_node.MutableInputDefs()[1]; + NodeArg* position_embedding = nullptr; + std::vector nodes_to_remove; + + // ORT constant folding might be applied to position embedding subgraph when input has static shape. + // Here we handle such special case that the input of add node is constant initializer. + auto add_input_name = add_node.MutableInputDefs()[1]->Name(); + if (graph_utils::IsConstantInitializer(graph, add_input_name)) { + // Check that input has static shape. 
+ auto input_shape = input_ids->Shape(); + if (input_shape == nullptr || input_shape->dim_size() != 2 || + !utils::HasDimValue(input_shape->dim()[0]) || + !utils::HasDimValue(input_shape->dim()[1])) { + DEBUG_LOG("Input is expected to have dim value in all dimensions."); + return false; + } + + int64_t batch_size = input_shape->dim()[0].dim_value(); + int64_t sequence_length = input_shape->dim()[1].dim_value(); + if (batch_size <= 0 || sequence_length <= 0) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* position_embed_tensor; + if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) { + DEBUG_LOG("Failed to get initializer tensor."); + return false; + } + // Tensor shape shall be [batch_size, sequence_length, hidden_size]. + if (position_embed_tensor->dims_size() != 3 || + position_embed_tensor->dims(0) != batch_size || + position_embed_tensor->dims(1) != sequence_length || + position_embed_tensor->dims(2) != hidden_size) { + DEBUG_LOG("Position embedding shape not matched."); + return false; + } + + // Tensor data type should be float or float16. + const auto data_type = position_embed_tensor->data_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + DEBUG_LOG("Position embedding data type shall be float or float16."); + return false; + } + + // The tensor has same data for all batches, and we extract only one batch data as position embedding. 
+ position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified); + } else { + if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) { + DEBUG_LOG("Failed to match position embedding subgraph."); + return false; + } + } + + if (position_embedding == nullptr) { + DEBUG_LOG("Failed to get position embedding weights."); + return false; + } + + auto pg_shape = position_embedding->Shape(); + if (pg_shape == nullptr || pg_shape->dim_size() != 2 || + !utils::HasDimValue(pg_shape->dim()[1]) || + pg_shape->dim()[1].dim_value() != hidden_size) { + DEBUG_LOG("Position embedding shape is not expected."); + return false; + } + + // Get input "input_ids" from node. + if (!CheckInput(input_ids, logger)) { + DEBUG_LOG("Input id is not valid. "); + return false; + } + + // Get input "segment_ids" from node. + NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1]; + if (!CheckInput(segment_ids, logger)) { + DEBUG_LOG("Segment id is not valid. "); + return false; + } + + // Get input "mask" from "ReduceSum" node. + NodeArg* mask = reduce_sum_node.MutableInputDefs()[0]; + if (!CheckInput(mask, logger)) { + DEBUG_LOG("Mask is not valid. "); + return false; + } + + if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != + utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) { + DEBUG_LOG("Input_ids and segment id should have the same shape. "); + return false; + } + if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != + utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) { + DEBUG_LOG("Input_ids and mask should have the same shape. 
"); + return false; + } + + NodeArg* gamma = layer_norm_node.MutableInputDefs()[1]; + NodeArg* beta = layer_norm_node.MutableInputDefs()[2]; + if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) { + DEBUG_LOG("Gamma should be of shape (hidden_size). "); + return false; + } + + if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) { + DEBUG_LOG("Beta should be of shape (hidden_size). "); + return false; + } + + CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, + mask, layer_norm_node, reduce_sum_node); + + nodes_to_remove.push_back(word_gather_node.Index()); + nodes_to_remove.push_back(segment_gather_node.Index()); + nodes_to_remove.push_back(add_node.Index()); + nodes_to_remove.push_back(reduce_sum_node.Index()); + nodes_to_remove.push_back(layer_norm_add_node.Index()); + nodes_to_remove.push_back(layer_norm_node.Index()); + + for (const auto& index : nodes_to_remove) { + Node* node = graph.GetNode(index); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + } + + return true; +} + +// DistilBert's pattern does not have segment embedding +static bool FuseSubGraphDistilBert(Graph& graph, + Node& layer_norm_add_node, + Node& layer_norm_node, + Node& reduce_sum_node, + const logging::Logger& logger) { + std::vector word_embedding_path{ + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; + std::vector edges; + if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) { + return false; + } + Node& word_gather_node = *graph.GetNode(edges[0]->GetNode().Index()); + if (!optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) { + return false; + } + // The first input of word_gather_node must be 2d. 
+ NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0]; + auto wg_shape = word_embedding->Shape(); + if (wg_shape == nullptr || wg_shape->dim_size() != 2 || + !utils::HasDimValue(wg_shape->dim()[1])) { + DEBUG_LOG("Word embedding shape not expected."); + return false; + } + + int64_t hidden_size = wg_shape->dim()[1].dim_value(); + + Node& add_node = layer_norm_add_node; + + NodeArg* input_ids = word_gather_node.MutableInputDefs()[1]; + NodeArg* position_embedding = nullptr; + std::vector nodes_to_remove; + + // ORT constant folding might be applied to position embedding subgraph when input has static shape. + // Here we handle such special case that the input of add node is constant initializer. + auto add_input_name = add_node.MutableInputDefs()[1]->Name(); + if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) { + DEBUG_LOG("Failed to match position embedding subgraph."); + return false; + } + + if (position_embedding == nullptr) { + DEBUG_LOG("Failed to get position embedding weights."); + return false; + } + + auto pg_shape = position_embedding->Shape(); + if (pg_shape == nullptr || pg_shape->dim_size() != 2 || + !utils::HasDimValue(pg_shape->dim()[1]) || + pg_shape->dim()[1].dim_value() != hidden_size) { + DEBUG_LOG("Position embedding shape is not expected."); + return false; + } + + // Get input "input_ids" from node. + if (!CheckInput(input_ids, logger)) { + DEBUG_LOG("Input id is not valid. "); + return false; + } + + // Get input "mask" from "ReduceSum" node. + NodeArg* mask = reduce_sum_node.MutableInputDefs()[0]; + if (!CheckInput(mask, logger)) { + DEBUG_LOG("Mask is not valid. "); + return false; + } + + if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != + utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) { + DEBUG_LOG("Input_ids and mask should have the same shape. 
"); + return false; + } + + NodeArg* gamma = layer_norm_node.MutableInputDefs()[1]; + NodeArg* beta = layer_norm_node.MutableInputDefs()[2]; + if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) { + DEBUG_LOG("Gamma should be of shape (hidden_size). "); + return false; + } + + if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) { + DEBUG_LOG("Beta should be of shape (hidden_size). "); + return false; + } + + CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr, + mask, layer_norm_node, reduce_sum_node); + + nodes_to_remove.push_back(word_gather_node.Index()); + nodes_to_remove.push_back(add_node.Index()); + nodes_to_remove.push_back(reduce_sum_node.Index()); + nodes_to_remove.push_back(layer_norm_node.Index()); + + for (const auto& index : nodes_to_remove) { + Node* node = graph.GetNode(index); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + } + + return true; +} /** Embed Layer Normalization will fuse embeddings and mask processing into one node : The embeddings before conversion: @@ -512,208 +855,17 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l } Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index()); - // Trace back to find the Gather for segment embedding. - std::vector segment_embedding_path{ - {0, 1, "Gather", {1, 11, 13}, kOnnxDomain}}; - if (!graph_utils::FindPath(layer_norm_add_node, true, segment_embedding_path, edges, logger)) { - continue; - } - Node& segment_gather_node = *graph.GetNode(edges[0]->GetNode().Index()); - if (!optimizer_utils::CheckOutputEdges(graph, segment_gather_node, 1)) { - continue; - } - // The first input of segment_gather_node must be 2d. 
- NodeArg* segment_embedding = segment_gather_node.MutableInputDefs()[0]; - auto sg_shape = segment_embedding->Shape(); - if (sg_shape == nullptr || sg_shape->dim_size() != 2 || - !utils::HasDimValue(sg_shape->dim()[1]) || - sg_shape->dim()[1].dim_value() <= 0) { - continue; - } - auto hidden_size = sg_shape->dim()[1].dim_value(); - - // Trace back to find Gather --> Add --> LayerNormalization - std::vector word_embedding_path{ - {0, 0, "Add", {7, 13}, kOnnxDomain}, - {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; - if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) { - continue; - } - Node& add_node = *graph.GetNode(edges[0]->GetNode().Index()); - Node& word_gather_node = *graph.GetNode(edges[1]->GetNode().Index()); - if (!optimizer_utils::CheckOutputEdges(graph, add_node, 1) || - !optimizer_utils::CheckOutputEdges(graph, word_gather_node, 1)) { - continue; - } - // The first input of word_gather_node must be 2d. - NodeArg* word_embedding = word_gather_node.MutableInputDefs()[0]; - auto wg_shape = word_embedding->Shape(); - if (wg_shape == nullptr || wg_shape->dim_size() != 2 || - !utils::HasDimValue(wg_shape->dim()[1]) || - wg_shape->dim()[1].dim_value() != hidden_size) { - DEBUG_LOG("Word embedding shape not expected."); - continue; - } - - NodeArg* input_ids = word_gather_node.MutableInputDefs()[1]; - NodeArg* position_embedding = nullptr; - std::vector nodes_to_remove; - - // ORT constant folding might be applied to position embedding subgraph when input has static shape. - // Here we handle such special case that the input of add node is constant initializer. - auto add_input_name = add_node.MutableInputDefs()[1]->Name(); - if (graph_utils::IsConstantInitializer(graph, add_input_name)) { - // Check that input has static shape. 
- auto input_shape = input_ids->Shape(); - if (input_shape->dim_size() != 2 || - !utils::HasDimValue(input_shape->dim()[0]) || - !utils::HasDimValue(input_shape->dim()[1])) { - DEBUG_LOG("Input is expected to have dim value in all dimensions."); - continue; + if (IsNeighborNodeExpectedTypes(layer_norm_add_node.InputNodesBegin(), layer_norm_add_node.InputNodesEnd(), {"Gather", "Gather"})) { + // DistilBert: the Add feeding LayerNormalization has exactly two Gather inputs (no segment embedding). + if (FuseSubGraphDistilBert(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, logger)) { + modified = true; } - - int64_t batch_size = input_shape->dim()[0].dim_value(); - int64_t sequence_length = input_shape->dim()[1].dim_value(); - if (batch_size <= 0 || sequence_length <= 0) { - continue; - } - - const ONNX_NAMESPACE::TensorProto* position_embed_tensor; - if (!graph.GetInitializedTensor(add_input_name, position_embed_tensor)) { - DEBUG_LOG("Failed to get initializer tensor."); - continue; - } - // Tensor shape shall be [batch_size, sequence_length, hidden_size]. - if (position_embed_tensor->dims_size() != 3 || - position_embed_tensor->dims(0) != batch_size || - position_embed_tensor->dims(1) != sequence_length || - position_embed_tensor->dims(2) != hidden_size) { - DEBUG_LOG("Position embedding shape not matched."); - continue; - } - - // Tensor data type should be float or float16. - const auto data_type = position_embed_tensor->data_type(); - if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - DEBUG_LOG("Position embedding data type shall be float or float16."); - continue; - } - - // The tensor has same data for all batches, and we extract only one batch data as position embedding. 
- position_embedding = ExtractEmbedding(graph, batch_size, sequence_length, hidden_size, position_embed_tensor, modified); } else { - if (!MatchPositionEmbeddingSubgraph(graph, add_node, input_ids, logger, nodes_to_remove, position_embedding)) { - DEBUG_LOG("Failed to match position embedding subgraph."); - continue; + if (FuseSubGraph(graph, layer_norm_add_node, layer_norm_node, reduce_sum_node, modified, logger)) { + modified = true; } } - - if (position_embedding == nullptr) { - DEBUG_LOG("Failed to get position embedding weights."); - continue; - } - - auto pg_shape = position_embedding->Shape(); - if (pg_shape == nullptr || pg_shape->dim_size() != 2 || - !utils::HasDimValue(pg_shape->dim()[1]) || - pg_shape->dim()[1].dim_value() != hidden_size) { - DEBUG_LOG("Position embedding shape is not expected."); - continue; - } - - // Get input "input_ids" from node. - if (!CheckInput(input_ids, logger)) { - DEBUG_LOG("Input id is not valid. "); - continue; - } - - // Get input "segment_ids" from node. - NodeArg* segment_ids = segment_gather_node.MutableInputDefs()[1]; - if (!CheckInput(segment_ids, logger)) { - DEBUG_LOG("Segment id is not valid. "); - continue; - } - - // Get input "mask" from "ReduceSum" node. - NodeArg* mask = reduce_sum_node.MutableInputDefs()[0]; - if (!CheckInput(mask, logger)) { - DEBUG_LOG("Mask is not valid. "); - continue; - } - - if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != - utils::GetTensorShapeFromTensorShapeProto(*(segment_ids->Shape()))) { - DEBUG_LOG("Input_ids and segment id should have the same shape. "); - continue; - } - if (utils::GetTensorShapeFromTensorShapeProto(*(input_ids->Shape())) != - utils::GetTensorShapeFromTensorShapeProto(*(mask->Shape()))) { - DEBUG_LOG("Input_ids and mask should have the same shape. 
"); - continue; - } - - NodeArg* gamma = layer_norm_node.MutableInputDefs()[1]; - NodeArg* beta = layer_norm_node.MutableInputDefs()[2]; - if (gamma->Shape() == nullptr || gamma->Shape()->dim()[0].dim_value() != hidden_size) { - DEBUG_LOG("Gamma should be of shape (hidden_size). "); - continue; - } - - if (beta->Shape() == nullptr || beta->Shape()->dim()[0].dim_value() != hidden_size) { - DEBUG_LOG("Beta should be of shape (hidden_size). "); - continue; - } - - // Cast input_ids, segment_ids, and mask to int32 if needed. - input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType()); - segment_ids = CastToInt32(graph, segment_ids, layer_norm_node.GetExecutionProviderType()); - mask = CastToInt32(graph, mask, layer_norm_node.GetExecutionProviderType()); - - const std::vector embed_layer_norm_input_defs{ - input_ids, - segment_ids, - word_embedding, - position_embedding, - segment_embedding, - layer_norm_node.MutableInputDefs()[1], - layer_norm_node.MutableInputDefs()[2], - mask}; - Node& embed_layer_norm_node = graph.AddNode(graph.GenerateNodeName("EmbedLayerNormalization"), - "EmbedLayerNormalization", - "fused EmbedLayerNorm subgraphs ", - embed_layer_norm_input_defs, - {layer_norm_node.MutableOutputDefs()[0], reduce_sum_node.MutableOutputDefs()[0]}, - {}, kMSDomain); - - // Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value - // will be used. - NodeAttributes ln_attrs = layer_norm_node.GetAttributes(); - NodeAttributes::const_iterator epsilon = ln_attrs.find("epsilon"); - if (epsilon != ln_attrs.end()) { - embed_layer_norm_node.AddAttribute("epsilon", epsilon->second); - } else { - embed_layer_norm_node.AddAttribute("epsilon", contrib::kDefaultEmbedLayerNormEpsilon); - } - - // Assign provider to this new node. Provider should be same as the provider for old node. 
- embed_layer_norm_node.SetExecutionProviderType(layer_norm_node.GetExecutionProviderType()); - - nodes_to_remove.push_back(word_gather_node.Index()); - nodes_to_remove.push_back(segment_gather_node.Index()); - nodes_to_remove.push_back(add_node.Index()); - nodes_to_remove.push_back(reduce_sum_node.Index()); - nodes_to_remove.push_back(layer_norm_add_node.Index()); - nodes_to_remove.push_back(layer_norm_node.Index()); - - for (const auto& index : nodes_to_remove) { - Node* node = graph.GetNode(index); - graph_utils::RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(node->Index()); - } - modified = true; } - return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 3e0694d3a7..a2dd1e9953 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2611,6 +2611,28 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); } +//DistilBert +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { + auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret1 = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret1.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + 
EXPECT_EQ(op_to_count["ReduceSum"], 0); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx new file mode 100644 index 0000000000..2d10dc96a8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index af9953ac62..9fcd169f3a 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -278,9 +278,82 @@ def GenerateModel6(model_name): model = helper.make_model(graph) onnx.save(model, model_name) +def GenerateModel7(model_name): + batch_size = 2 + hidden_size = 4 + attention_heads = 2 + sequence_length = 3 + + nodes = [ + helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0), + + helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"), + helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"), + helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"), + helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), + helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"), + helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0), + + helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"), + helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], 
["layernorm_out"], + "layernorm", + axis=-1, + epsilon=0.000009999999747378752), + + helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), + + helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], + "att", + domain="com.microsoft", + num_heads=attention_heads), + helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), + helper.make_node("Add", ["matmul_out", "add_bias"], ["add2_out"], "add2"), + helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3") + ] + + qkv_weights = [ + 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0 + ] + + initializers = [ # initializers + helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), + helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), + helper.make_tensor('start', TensorProto.INT64, [], [0]), + helper.make_tensor('delta', TensorProto.INT64, [], [1]), + helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), + 
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + ] + + graph = helper.make_graph( + nodes, + "EmbedLayerNorm_format7", #name + [ # inputs + helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]), + ], + [ # outputs + helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) GenerateModel3('embed_layer_norm_format3.onnx', True) GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) GenerateModel5('embed_layer_norm_format5.onnx') GenerateModel6('embed_layer_norm_format6.onnx') +GenerateModel7('embed_layer_norm_format7.onnx') #distilbert GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')