diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 08c402bf66..54c49db059 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, logger)); InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); + FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, layer_boundary_ln_nodes); // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 525e3b4b8d..40fa2fc5cc 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + // + // Note 2: Here we use default typo order (which tries to BFS from the outputs, + // so the nearest node to graph output will be visited last). So in reversed default typo order, + // the neareast node to graph output will be visited first. + // Imagine there is a such subgraph + // input1 input2 input3 + // \ | / + // multiple layers + // | + // node M + // labels-------|----- + // \ | | + // node1 | | + // \ | | + // node2 / | + // \ / | + // node loss / + // | / + // YieldOp node1_recompute + // | / + // \ node2 recompute + // \ / + // node loss_grad + // | + // critical grad path + // + // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added + // at last because we do this following reversed topological order. Then node1_recompute node will have lowest + // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then + // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. + // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes + // in this case. + + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 12c83591c0..76b3325f36 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node, const auto current_node_input_index = input_edge.GetDstArgIndex(); if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != input_arg_indices.end()) { + // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. + auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); + if (output_shape) { + bool all_constant_dim = true; + int64_t num_elem = 1; + for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { + if (!output_shape->dim(k).has_dim_value()) { + all_constant_dim = false; + num_elem *= output_shape->dim(k).dim_value(); + } + } + if (all_constant_dim && num_elem < 1 * 1024 * 1024) { + // Skip this input index. + continue; + } + } NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 04f2679ac7..c88a0f05d3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, const logging::Logger&, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes( std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); + // Ignore those nodes after YieldOp. + if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + nodes_to_check.push_back(&(*node_it)); + } } + bool unexpected_failure = false; + bool found_softmax = false; + bool found_layernorm = false; + ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes( visited_nodes.insert(next_node); if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; + found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; + if (found_layernorm) { + // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. + unexpected_failure = true; + break; + } + found_layernorm = true; // don't trace further + next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); + continue; } else { for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + if (found_layernorm && + node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { + continue; + } nodes_to_check.push_back(&(*node_it)); } } } + + if (unexpected_failure) { + layer_boundary_ln_nodes.clear(); + break; + } + + if (found_softmax) { + layer_boundary_ln_nodes.insert(&node); + } else if (!found_layernorm) { + // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, + // we also consider it as boundary node. + layer_boundary_ln_nodes.insert(&node); + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index f2cfd640b0..b58d822124 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const logging::Logger& logger, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer