diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index a4370960fe..c567b38d1b 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -753,19 +753,26 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec } bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { - std::queue q; - std::vector nodes_to_remove; - q.push(&start_node); + std::queue q; + std::unordered_set removed_nodes; + q.push(start_node.Index()); + + bool is_start_node(true); // From the current node, remove nodes bottom-up util it reaches a node with multiple outputs/graph output. - while (q.size() != 0) { - const Node& cur_node = *(q.front()); + while (!q.empty()) { + NodeIndex cur_node_index = q.front(); q.pop(); + + if (removed_nodes.find(cur_node_index) != removed_nodes.end()) { + continue; + } // Each eligible node in the subgraph must have less than one output edge and no output should be // the graph output + const Node& cur_node = *graph.GetNode(cur_node_index); if (cur_node.GetOutputEdgesCount() > 1 || !graph.GetNodeOutputsInGraphOutputs(cur_node).empty()) { continue; } - nodes_to_remove.push_back(cur_node.Index()); + // push the parents of current node to the queue. for (unsigned int i = 0; i < cur_node.InputDefs().size(); ++i) { const std::string& input_name = GetNodeInputName(cur_node, i); @@ -773,19 +780,25 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { // skip initializers and graph inputs continue; } - q.push(GetInputNode(cur_node, i)); + const Node* parent_node = GetInputNode(cur_node, i); + q.push(parent_node->Index()); + } + + if (is_start_node || cur_node.GetOutputEdgesCount() == 0) { + Node* cur_node_p = graph.GetNode(cur_node_index); + RemoveNodeOutputEdges(graph, *cur_node_p); + graph.RemoveNode(cur_node_index); + + removed_nodes.insert(cur_node_index); + is_start_node = false; } } - if (nodes_to_remove.size() <= 0) { + + if (removed_nodes.size() == 0) { // Nothing to remove return false; } - // Remove nodes that are not used anymore. - for (const auto& node_index : nodes_to_remove) { - Node* node = graph.GetNode(node_index); - RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(node->Index()); - } + return true; } diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 31a88be543..5757100cf1 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -65,18 +65,6 @@ static bool CheckInput(NodeArg* input, const logging::Logger& logger) { return true; } -static void AddNodes(std::vector& node_indices, - const std::vector& edges) { - for (size_t i = 0; i < edges.size(); i++) { - auto item = edges[i]->GetNode().Index(); - // Avoid duplication. - if (std::find(node_indices.begin(), node_indices.end(), item) != node_indices.end()) { - continue; - } - node_indices.push_back(item); - } -} - static bool IsNeighborNodeExpectedTypes(Node::NodeConstIterator start, const Node::NodeConstIterator end, const std::vector& expected_types) { for (const std::string& expected_type : expected_types) { if (start == end || (*start).OpType().compare(expected_type) != 0) { @@ -121,9 +109,7 @@ static bool MatchInputToConcatSubgraph( const NodeArg* input_ids, const int index, const logging::Logger& logger, - std::vector& subgraph_node_indices, const NodeIndex expected_gather_node_1_index) { - subgraph_node_indices.clear(); std::vector expand_parent_path1{ {0, index, "Concat", {4, 11}, kOnnxDomain}, {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, @@ -157,8 +143,6 @@ static bool MatchInputToConcatSubgraph( return false; } - AddNodes(subgraph_node_indices, edges); - std::vector concat_parent_path{ {0, 1, "Unsqueeze", {1, 11}, kOnnxDomain}, {0, 0, "Gather", {1, 11}, kOnnxDomain}, @@ -211,7 +195,6 @@ static bool MatchInputToConcatSubgraph( } } - AddNodes(subgraph_node_indices, edges); return true; } @@ -238,9 +221,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( Graph& graph, const Node& position_gather_node, const NodeArg* input_ids, - const logging::Logger& logger, - std::vector& subgraph_node_indices) { - subgraph_node_indices.clear(); + const logging::Logger& logger) { std::vector pg_edges; // Look for Path 1: // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze @@ -348,8 +329,6 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( DEBUG_LOG("The parent of shape nodes are expected to be input_ids."); return false; } - - subgraph_node_indices.push_back(shape_node_index); } else { // gather_output_edges_count == 2 // Match optional Reshape -> Equal -> Where -> Expand // | | @@ -373,20 +352,17 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( return false; } // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node. - if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) { + if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, gather_node.Index())) { DEBUG_LOG("Failed to match position subgraph."); return false; } - AddNodes(subgraph_node_indices, pg_edges_2); - } else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) { + } else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, gather_node.Index())) { // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node. DEBUG_LOG("Failed to match position subgraph."); return false; } } - AddNodes(subgraph_node_indices, pg_edges); - return true; } @@ -438,11 +414,12 @@ static bool MatchPositionEmbeddingSubgraph( } } } else { - if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { + if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger)) { return false; } } + subgraph_node_indices.clear(); subgraph_node_indices.push_back(position_gather_node.Index()); return true; } @@ -507,7 +484,6 @@ static void CreateEmbedLayernormNode(Graph& graph, NodeArg* word_embedding, NodeArg* position_embedding, NodeArg* segment_embedding, - Node& layer_norm_node) { // Cast input_ids and segment_ids to int32 if needed. input_ids = CastToInt32(graph, input_ids, layer_norm_node.GetExecutionProviderType()); @@ -528,7 +504,6 @@ static void CreateEmbedLayernormNode(Graph& graph, position_embedding, segment_embedding, layer_norm_node.MutableInputDefs()[1], - layer_norm_node.MutableInputDefs()[2]}; auto& mask_index = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mask_index"), nullptr); @@ -705,6 +680,12 @@ static bool FuseSubGraph(Graph& graph, CreateEmbedLayernormNode(graph, input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, layer_norm_node); + if (!nodes_to_remove.empty()) { + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0])); + } + + nodes_to_remove.clear(); + nodes_to_remove.push_back(word_gather_node.Index()); nodes_to_remove.push_back(segment_gather_node.Index()); nodes_to_remove.push_back(add_node.Index()); @@ -712,7 +693,7 @@ static bool FuseSubGraph(Graph& graph, nodes_to_remove.push_back(layer_norm_add_node.Index()); nodes_to_remove.push_back(layer_norm_node.Index()); - for (const auto& index : nodes_to_remove) { + for (const NodeIndex index : nodes_to_remove) { Node* node = graph.GetNode(index); graph_utils::RemoveNodeOutputEdges(graph, *node); graph.RemoveNode(node->Index()); @@ -725,7 +706,6 @@ static bool FuseSubGraph(Graph& graph, static bool FuseSubGraphDistilBert(Graph& graph, Node& layer_norm_add_node, Node& layer_norm_node, - const logging::Logger& logger) { std::vector word_embedding_path{ {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; @@ -796,12 +776,18 @@ static bool FuseSubGraphDistilBert(Graph& graph, CreateEmbedLayernormNode(graph, input_ids, nullptr, word_embedding, position_embedding, nullptr, layer_norm_node); + if (!nodes_to_remove.empty()) { + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *graph.GetNode(nodes_to_remove[0])); + } + + nodes_to_remove.clear(); + nodes_to_remove.push_back(word_gather_node.Index()); nodes_to_remove.push_back(add_node.Index()); nodes_to_remove.push_back(layer_norm_node.Index()); - for (const auto& index : nodes_to_remove) { + for (const NodeIndex index : nodes_to_remove) { Node* node = graph.GetNode(index); graph_utils::RemoveNodeOutputEdges(graph, *node); graph.RemoveNode(node->Index()); diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 5b92d149e4..6ceafb2c06 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -264,7 +264,7 @@ def run_tensorflow(use_gpu, model_names, model_class, precision, batch_sizes, se for model_name in model_names: config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) - model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True) + model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True) tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 8dc8fb468b..0fa9e2d555 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -211,19 +211,16 @@ def modelclass_dispatcher(model_name, custom_model_class): return "AutoModel" -def load_pretrained_model(model_name, config, cache_dir, custom_model_class, if_tf_model=False): +def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False): model_class_name = modelclass_dispatcher(model_name, custom_model_class) - - if model_class_name == "GPT2ModelNoPastState": - return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir) - + if model_class_name == "GPT2ModelNoPastState": if is_tf_model: raise NotImplementedError("TFGPT2ModelNoPastState is currently not supported.") else: return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir) - if if_tf_model: + if is_tf_model: model_class_name = 'TF' + model_class_name transformers_module = __import__("transformers", fromlist=[model_class_name]) @@ -329,7 +326,7 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) - model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, if_tf_model=True) + model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class, is_tf_model=True) model._saved_model_inputs_spec = None diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f4db4e0f30..5bea40ee75 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2799,7 +2799,7 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); EXPECT_EQ(op_to_count["Attention"], 1); EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Shape"], 1); EXPECT_EQ(op_to_count["Gather"], 2); EXPECT_EQ(op_to_count["Unsqueeze"], 2); EXPECT_EQ(op_to_count["ReduceSum"], 1);