From 2092bebc782b69c2a7a973fc76cb8099ad1da94a Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 29 Mar 2024 17:44:38 +0800 Subject: [PATCH] Fix transformer layer detection for recompute (#20106) ### Fix transformer layer detection for recompute The original logic failed to detect the layer boundary node in the Mistral model. This PR simplifies the search by matching a stronger pattern, so that it is flexible enough to cover different transformer variants. Also add a unit test. Add a warning when the user enables layerwise recompute but no layer boundary nodes are found. --- .../memory_optimizer/memory_insight.cc | 7 +- .../memory_optimizer/recompute_analysis.cc | 33 +++- .../memory_optimizer/recompute_analysis.h | 2 +- .../memory_optimizer/transformer_specific.cc | 159 +++++++++++++----- .../memory_optimizer/transformer_specific.h | 2 +- .../test/optimizer/memory_optimizer_test.cc | 41 +++++ 6 files changed, 194 insertions(+), 50 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 54c49db059..3d0fa942fd 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -257,10 +257,15 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, is_forward_nodes, logger)); - InlinedHashSet layer_boundary_ln_nodes; + InlinedVector layer_boundary_ln_nodes; FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, yield_op_order_in_topological_sort, layer_boundary_ln_nodes); + if (probe_config.enable_transformer_layer_as_boundary && layer_boundary_ln_nodes.size() == 0) { + LOGS(logger, WARNING) << "No transformer layer boundary nodes found, this might cause memory optimization " + "not working as expected. 
Please check the model and the configuration."; + } + // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { const Node* p_node = graph_viewer.GetNode(node_ids[i]); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index b421eb2ab3..37ac1c4950 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -386,6 +386,26 @@ const InlinedHashMap& GetAllowedRecompu {1, {}}, }, }, + { + utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have SimplifiedLayerNormalization, + // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain. + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, { utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), { @@ -691,7 +711,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation) { @@ -709,13 +729,14 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap auto output_name = node.OutputDefs()[output_index]->Name(); auto consumers = graph_viewer.GetConsumerNodes(output_name); for (auto& consumer : consumers) { - if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) { + if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), 
consumer) != + layer_boundary_ln_nodes.end()) { int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); if (dest_in_index == 0) { - LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType() - << ") is a Attention+MLP layer boundary node, " - << "its stashed activation outputs are used by LayerNormalization's inputs, " - << "we don't need to recompute it."; + MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() + + ") is a Attention+MLP layer boundary node, " + + "its stashed activation outputs are used by LayerNormalization's inputs, " + + "we don't need to recompute it."); return nullptr; } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index ab114d9701..ac1021f5eb 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -164,7 +164,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index c88a0f05d3..3bcfbd324b 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include +#include #include #include @@ -16,43 +17,139 @@ namespace onnxruntime::optimizer::memory_optimizer { +namespace { + +bool IsLayerNormNode(const Node& node) { + static const std::set layer_norm_ops = { + "LayerNormalization", + "SkipLayerNormalization", + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + }; + return layer_norm_ops.find(node.OpType()) != layer_norm_ops.end(); +} + +bool IsSoftmaxNode(const Node& node) { + static const std::set softmax_ops = { + "Softmax", + "BiasSoftmax", + }; + return softmax_ops.find(node.OpType()) != softmax_ops.end(); +} + +std::tuple IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) { + auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name()); + if (2 > consumers.size()) { + return std::make_tuple(false, nullptr, nullptr); + } + + // Find the Add node from the consumer list. + const Node* add_node = nullptr; + const Node* other_node = nullptr; + for (const auto* consumer : consumers) { + if (consumer->OpType() == "Add") { + add_node = consumer; + } else if (IsLayerNormNode(*consumer)) { + other_node = consumer; + } + } + + return std::make_tuple(add_node != nullptr && other_node != nullptr, add_node, other_node); +} +} // namespace + +/* + One classical layer includes 1). input layer norm, 2). attention, 3). residual add + (input layer norm input + attention output), 4). post attention layer norm feedforward, and 5). residual add + (post attention layer norm input + feedforward out). 
+ + The pattern graph looks like below for each transformer layer (taking the example of MistralDecoderLayer): + | + Embedding + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization (layer boundary node) + | | + | | + | MistralAttention + | | + | | + |____________________Add + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization + | | + | | + | MultiLayerPerceptron + | | + | | + |____________________Add + | + (new layer) + ----------------------| + | | + | SimplifiedLayerNormalization + .... +*/ void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, - const logging::Logger&, + const logging::Logger& logger, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes) { + InlinedVector& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. // If the found node is another LayerNormalization, the LayerNormalization node as MLP. 
- const InlinedHashSet softmax_ops{"Softmax", "BiasSoftmax"}; - const InlinedHashSet layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; layer_boundary_ln_nodes.clear(); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); for (auto node_index : node_topology_list) { auto& node = *graph_viewer.GetNode(node_index); - if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) { + if (!IsLayerNormNode(node)) { + continue; + } + const NodeArg* input_arg = node.InputDefs()[0]; + + // IsResidualNodeArg checks input_arg + auto [is_residual_node_arg, add_node, other_node] = IsResidualNodeArg(graph_viewer, input_arg); + if (!is_residual_node_arg) { + MO_LOG_DEBUG_INFO(logger, "Not a residual node arg " + input_arg->Name()); continue; } + // At this point, there should not be any recompute node, so we don't need check the node existence in + // node_index_to_its_order_in_topological_sort_map. + ptrdiff_t attention_residual_add_node_order = + node_index_to_its_order_in_topological_sort_map.at(add_node->Index()); + ptrdiff_t attention_residual_ln_node_order = + node_index_to_its_order_in_topological_sort_map.at(other_node->Index()); + if (attention_residual_add_node_order >= yield_op_order_in_topological_sort || + attention_residual_ln_node_order >= yield_op_order_in_topological_sort) { + MO_LOG_DEBUG_INFO(logger, "Not a valid residual node arg " + input_arg->Name()); + continue; + } + + // Search all forward nodes that is before `add_node` in topo order, unless we find a softmax or + // nodes_to_check is empty. std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { // Ignore those nodes after YieldOp. 
- if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { nodes_to_check.push_back(&(*node_it)); } } - bool unexpected_failure = false; - bool found_softmax = false; - bool found_layernorm = false; - ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -62,41 +159,21 @@ void FindLayerBoundaryLayerNormNodes( } visited_nodes.insert(next_node); - if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - found_softmax = true; - } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - if (found_layernorm) { - // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. - unexpected_failure = true; - break; - } - found_layernorm = true; // don't trace further - next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); - continue; - } else { - for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { - // Stop if the node is after next Layernorm node in execution order. - if (found_layernorm && - node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { - continue; - } + if (IsSoftmaxNode(*next_node)) { + MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " + + input_arg->Name()); + layer_boundary_ln_nodes.push_back(&node); + break; + } + + for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. 
+ auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { nodes_to_check.push_back(&(*node_it)); } } } - - if (unexpected_failure) { - layer_boundary_ln_nodes.clear(); - break; - } - - if (found_softmax) { - layer_boundary_ln_nodes.insert(&node); - } else if (!found_layernorm) { - // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, - // we also consider it as boundary node. - layer_boundary_ln_nodes.insert(&node); - } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index b58d822124..a72e5a0af9 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -23,6 +23,6 @@ void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes); + InlinedVector& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 22f1da1327..360095dea6 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -29,6 +29,7 @@ #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" +#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" using namespace std; using 
namespace ONNX_NAMESPACE; @@ -312,5 +313,45 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { } } +TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + GraphViewer graph_viewer(graph); + + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + + // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. + ptrdiff_t yield_op_order_in_topological_sort = -1; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + if (p_node->OpType() == "YieldOp") { + // There must not be multiple YieldOps in the graph. + ASSERT_EQ(yield_op_order_in_topological_sort, -1); + yield_op_order_in_topological_sort = static_cast(i); + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast(i); + } + + InlinedVector layer_boundary_ln_node; + optimizer::memory_optimizer::FindLayerBoundaryLayerNormNodes(graph_viewer, *logger, + node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, + layer_boundary_ln_node); + + ASSERT_EQ(layer_boundary_ln_node.size(), 3); + ASSERT_EQ(layer_boundary_ln_node[0]->Name(), "LayerNormalization_token_0"); + ASSERT_EQ(layer_boundary_ln_node[1]->Name(), "LayerNormalization_token_6"); + ASSERT_EQ(layer_boundary_ln_node[2]->Name(), "LayerNormalization_token_12"); +} + } // namespace test } // namespace onnxruntime