diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 54c49db059..3d0fa942fd 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -257,10 +257,15 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, is_forward_nodes, logger)); - InlinedHashSet layer_boundary_ln_nodes; + InlinedVector layer_boundary_ln_nodes; FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, yield_op_order_in_topological_sort, layer_boundary_ln_nodes); + if (probe_config.enable_transformer_layer_as_boundary && layer_boundary_ln_nodes.size() == 0) { + LOGS(logger, WARNING) << "No transformer layer boundary nodes found, this might cause memory optimization " + "not working as expected. Please check the model and the configuration."; + } + // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { const Node* p_node = graph_viewer.GetNode(node_ids[i]); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index b421eb2ab3..37ac1c4950 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -386,6 +386,26 @@ const InlinedHashMap& GetAllowedRecompu {1, {}}, }, }, + { + utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have SimplifiedLayerNormalization, + // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain. + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain), + { + {1, {}}, + }, + }, { utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), { @@ -691,7 +711,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation) { @@ -709,13 +729,14 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap auto output_name = node.OutputDefs()[output_index]->Name(); auto consumers = graph_viewer.GetConsumerNodes(output_name); for (auto& consumer : consumers) { - if (layer_boundary_ln_nodes.find(consumer) != layer_boundary_ln_nodes.end()) { + if (std::find(layer_boundary_ln_nodes.begin(), layer_boundary_ln_nodes.end(), consumer) != + layer_boundary_ln_nodes.end()) { int dest_in_index = optimizer_utils::IndexOfNodeInput(*consumer, *node.OutputDefs()[output_index]); if (dest_in_index == 0) { - LOGS(logger, INFO) << "Node " << node.Name() << "(" << node.OpType() - << ") is a Attention+MLP layer boundary node, " - << "its stashed activation outputs are used by LayerNormalization's inputs, " - << "we don't need to recompute it."; + MO_LOG_DEBUG_INFO(logger, "Node " + node.Name() + "(" + node.OpType() + + ") is a Attention+MLP layer boundary node, " + + "its stashed activation outputs are used by LayerNormalization's inputs, " + + "we don't need to recompute it."); return nullptr; } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index ab114d9701..ac1021f5eb 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -164,7 +164,7 @@ std::unique_ptr CheckNodeForRecompute(const GraphViewer& grap node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, - const InlinedHashSet& layer_boundary_ln_nodes, + const InlinedVector& layer_boundary_ln_nodes, const logging::Logger& logger, bool compromise_stashed_activation, bool& can_compromise_stashed_activation); diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index c88a0f05d3..3bcfbd324b 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include @@ -16,43 +17,139 @@ namespace onnxruntime::optimizer::memory_optimizer { +namespace { + +bool IsLayerNormNode(const Node& node) { + static const std::set layer_norm_ops = { + "LayerNormalization", + "SkipLayerNormalization", + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + }; + return layer_norm_ops.find(node.OpType()) != layer_norm_ops.end(); +} + +bool IsSoftmaxNode(const Node& node) { + static const std::set softmax_ops = { + "Softmax", + "BiasSoftmax", + }; + return softmax_ops.find(node.OpType()) != softmax_ops.end(); +} + +std::tuple IsResidualNodeArg(const GraphViewer& graph_viewer, const NodeArg* node_arg) { + auto consumers = graph_viewer.GetConsumerNodes(node_arg->Name()); + if (2 > consumers.size()) { + return std::make_tuple(false, nullptr, nullptr); + } + + // Find the Add node from the consumer list. + const Node* add_node = nullptr; + const Node* other_node = nullptr; + for (const auto* consumer : consumers) { + if (consumer->OpType() == "Add") { + add_node = consumer; + } else if (IsLayerNormNode(*consumer)) { + other_node = consumer; + } + } + + return std::make_tuple(add_node != nullptr && other_node != nullptr, add_node, other_node); +} +} // namespace + +/* + One classical layer includes 1). input layer norm, 2). attention, 3). residual add + (input layer norm input + attention output), 4). post attention layer norm feedforward, and 5). residual add + (post attention layer norm input + feedforward out). + + The pattern graph looks like below for each transformer layer (taking the example of MistralDecoderLayer): + | + Embedding + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization (layer boudary node) + | | + | | + | MistralAttention + | | + | | + |____________________Add + | + ----------------------| + | | + | | + | SimplifiedLayerNormalization + | | + | | + | MultipleLayerPerception + | | + | | + |____________________Add + | + (new layer) + ----------------------| + | | + | SimplifiedLayerNormalization + .... +*/ void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, - const logging::Logger&, + const logging::Logger& logger, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes) { + InlinedVector& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, // until find a node that is Softmax or BiasSoftmax or another LayerNormalization. // If the found node is Softmax or BiasSoftmax, the LayerNormalization node as ATTENTION. // If the found node is another LayerNormalization, the LayerNormalization node as MLP. - const InlinedHashSet softmax_ops{"Softmax", "BiasSoftmax"}; - const InlinedHashSet layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; layer_boundary_ln_nodes.clear(); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); for (auto node_index : node_topology_list) { auto& node = *graph_viewer.GetNode(node_index); - if (layernorm_ops.find(node.OpType()) == layernorm_ops.end()) { + if (!IsLayerNormNode(node)) { + continue; + } + const NodeArg* input_arg = node.InputDefs()[0]; + + // IsResidualNodeArg checks input_arg + auto [is_residual_node_arg, add_node, other_node] = IsResidualNodeArg(graph_viewer, input_arg); + if (!is_residual_node_arg) { + MO_LOG_DEBUG_INFO(logger, "Not a residual node arg " + input_arg->Name()); continue; } + // At this point, there should not be any recompute node, so we don't need check the node existence in + // node_index_to_its_order_in_topological_sort_map. + ptrdiff_t attention_residual_add_node_order = + node_index_to_its_order_in_topological_sort_map.at(add_node->Index()); + ptrdiff_t attention_residual_ln_node_order = + node_index_to_its_order_in_topological_sort_map.at(other_node->Index()); + if (attention_residual_add_node_order >= yield_op_order_in_topological_sort || + attention_residual_ln_node_order >= yield_op_order_in_topological_sort) { + MO_LOG_DEBUG_INFO(logger, "Not a valid residual node arg " + input_arg->Name()); + continue; + } + + // Search all forward nodes that is before `add_node` in topo order, unless we find a softmax or + // nodes_to_check is empty. std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { // Ignore those nodes after YieldOp. - if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { nodes_to_check.push_back(&(*node_it)); } } - bool unexpected_failure = false; - bool found_softmax = false; - bool found_layernorm = false; - ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -62,41 +159,21 @@ void FindLayerBoundaryLayerNormNodes( } visited_nodes.insert(next_node); - if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - found_softmax = true; - } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - if (found_layernorm) { - // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. - unexpected_failure = true; - break; - } - found_layernorm = true; // don't trace further - next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); - continue; - } else { - for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { - // Stop if the node is after next Layernorm node in execution order. - if (found_layernorm && - node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { - continue; - } + if (IsSoftmaxNode(*next_node)) { + MO_LOG_DEBUG_INFO(logger, "Found layer boundary node " + node.Name() + " with its input arg: " + + input_arg->Name()); + layer_boundary_ln_nodes.push_back(&node); + break; + } + + for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + auto order = node_index_to_its_order_in_topological_sort_map.at(node_it->Index()); + if (order < yield_op_order_in_topological_sort && order < attention_residual_add_node_order) { nodes_to_check.push_back(&(*node_it)); } } } - - if (unexpected_failure) { - layer_boundary_ln_nodes.clear(); - break; - } - - if (found_softmax) { - layer_boundary_ln_nodes.insert(&node); - } else if (!found_layernorm) { - // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, - // we also consider it as boundary node. - layer_boundary_ln_nodes.insert(&node); - } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index b58d822124..a72e5a0af9 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -23,6 +23,6 @@ void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const ptrdiff_t& yield_op_order_in_topological_sort, - InlinedHashSet& layer_boundary_ln_nodes); + InlinedVector& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index 22f1da1327..360095dea6 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -29,6 +29,7 @@ #include "orttraining/core/optimizer/memory_optimizer/common.h" #include "orttraining/core/optimizer/memory_optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" +#include "orttraining/core/optimizer/memory_optimizer/transformer_specific.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -312,5 +313,45 @@ TEST(MemoryOptimizerTests, TransformerPerLayerRecompute) { } } +TEST(MemoryOptimizerTests, TransformerLayerDetectionTest) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "3layer_bloom_optimized_training.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + GraphViewer graph_viewer(graph); + + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + + // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. + ptrdiff_t yield_op_order_in_topological_sort = -1; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + if (p_node->OpType() == "YieldOp") { + // There are multiple YieldOps in the graph。 + ASSERT_EQ(yield_op_order_in_topological_sort, -1); + yield_op_order_in_topological_sort = static_cast(i); + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast(i); + } + + InlinedVector layer_boundary_ln_node; + optimizer::memory_optimizer::FindLayerBoundaryLayerNormNodes(graph_viewer, *logger, + node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, + layer_boundary_ln_node); + + ASSERT_EQ(layer_boundary_ln_node.size(), 3); + ASSERT_EQ(layer_boundary_ln_node[0]->Name(), "LayerNormalization_token_0"); + ASSERT_EQ(layer_boundary_ln_node[1]->Name(), "LayerNormalization_token_6"); + ASSERT_EQ(layer_boundary_ln_node[2]->Name(), "LayerNormalization_token_12"); +} + } // namespace test } // namespace onnxruntime