From d03e477b9026a97d22dba64cd00b4614603671e5 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 11 Jan 2024 12:50:55 +0800 Subject: [PATCH] Fix missing subgraph candidates for recompute (#19077) ### Fix missing subgraph candidates for recompute For subgraphs for example `MatMul+Transpose+Reshape`, since the ending node is a Reshape, in ORT, it is reusing input buffers. Currently, the subgraph detection logic has defect, as a result, those subgraphs will be missing as recompute candidates. Also append a few more node types for recompute support. TODO: add unit test later. This PR is needed for a customer model now. --- .../memory_optimizer/memory_insight.cc | 34 +++++++++----- .../memory_optimizer/optimization_planner.cc | 29 ------------ .../memory_optimizer/optimization_planner.h | 2 +- .../memory_optimizer/recompute_analysis.cc | 3 ++ .../memory_optimizer/recompute_analysis.h | 45 +++++++++++++++++++ .../ortmodule/_graph_execution_manager.py | 2 +- .../python/training/ortmodule/options.py | 9 ++++ 7 files changed, 83 insertions(+), 41 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 9b77832abb..3fbdd5da7b 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -485,12 +485,15 @@ void ListAllCombinations(const InlinedVector> new_combination = current_combination; - new_combination.push_back(plan); - ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); - } + const InlinedVector>>& + plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index]; + // For the index-th reused buffer, iterate all possible complete plans. + for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) { + const auto& plan_combination = plan_combination_list_at_cur_index[i]; + InlinedVector> new_combination = current_combination; + // Append the chosen complete plan and continue exploring the next reused buffer by index + 1. + new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end()); + ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); } MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations"); @@ -520,17 +523,28 @@ void IterateNodeOptimizationPlan(const std::shared_ptr } InlinedVector>>> - all_possible_node_optimization_plans; - all_possible_node_optimization_plans.resize(plan->reuse_buffers.size()); + all_possible_node_optimization_plans(plan->reuse_buffers.size()); size_t i = 0; for (const auto& p : plan->reuse_buffers) { MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first)); - IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]); + // If the resued node is part of current node optimization plan, then we just add current combination to the result. + if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) { + const auto& recompute_subgraph = + dynamic_cast(plan.get())->GetNodesInTopoOrder(); + if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) { + all_possible_node_optimization_plans[i].push_back(current_combination); + } + } + + if (all_possible_node_optimization_plans[i].size() == 0) { + IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]); + } + ++i; } - ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations); + ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations); MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId()); } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc index 64e99a4a0b..4ce896c535 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -15,35 +15,6 @@ namespace onnxruntime::optimizer::memory_optimizer { -std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const { - std::string saving_str; - for (auto output_index : activation_output_indices_) { - // If the output is reusing other node's buffer, then no memory saving. - if (reuse_buffers.find(output_index) != reuse_buffers.end()) { - continue; - } - - const auto& output_def = node->OutputDefs()[output_index]; - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); - ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", - DataTypeImpl::ToString(ml_data_type)); - const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); - ORT_ENFORCE(nullptr != tensor_type_base); - MLDataType elt_type = tensor_type_base->GetElementType(); - const auto byte_count_per_element = elt_type->Size(); - if (!saving_str.empty()) { - saving_str += " + "; - } - saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " + - std::to_string(byte_count_per_element) + " * " + - std::to_string(GetSaveRatio()) + ")"; - } - if (saving_str.empty()) { - return saving_str; - } - return "(" + saving_str + ")"; -} - Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, const OrtValueNameIdxMap& ortvalue_name_to_idx_map, const SequentialExecutionPlan& p_seq_exec_plan) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h index c585b2810b..789f530b29 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h @@ -83,7 +83,7 @@ class NodeOptimizationPlanBase { /** * Get a symbolic string to represent the memory saving for this optimization plan. */ - std::string GetMemorySavingSymbolicString() const; + virtual std::string GetMemorySavingSymbolicString() const = 0; std::string GetActivationOutputDimParamString(size_t index) const { ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(), diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 52dea571a1..12c83591c0 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -72,12 +72,14 @@ const InlinedHashMap& GetAllowedRecompu {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Equal", AllowedRecomputeNodeConfig{{0, 1}}}, {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, // Data layout /// The shape input is trivial whether it exists or not in backward. {"Reshape", AllowedRecomputeNodeConfig{{0}}}, + {"Shape", AllowedRecomputeNodeConfig{{0}}}, {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, {"Transpose", AllowedRecomputeNodeConfig{{0}}}, {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, @@ -92,6 +94,7 @@ const InlinedHashMap& GetAllowedRecompu {"Expand", AllowedRecomputeNodeConfig{{0}}}, {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, {"Gelu", AllowedRecomputeNodeConfig{{0}}}, + {"QuickGelu", AllowedRecomputeNodeConfig{{0}}}, // Ternary elementwise {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h index d969383531..ab114d9701 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase { std::string GetNodesInTopoOrderStr() const; + std::string GetMemorySavingSymbolicString() const override { + std::string saving_str; + for (auto output_index : GetActivationOutputIndices()) { + // If the output is reusing other node's buffer, then no memory saving. + std::string cur_output_saving_str; + + bool is_reused = reuse_buffers.find(output_index) != reuse_buffers.end(); + bool is_src_node_in_cur_node_subgraph = false; + if (is_reused) { + // Here we assume the src_node is the real owner of the buffer, so we don't need trace further. + const auto* src_node = reuse_buffers.at(output_index).first; + is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(), + nodes_in_topological_order_.end(), + src_node) != nodes_in_topological_order_.end(); + } + + if (!is_reused || is_src_node_in_cur_node_subgraph) { + // For is_src_node_in_cur_node_subgraph is True, still use the output to calculate the saving, because + // reusing buffer is the same size. + const auto& output_def = node->OutputDefs()[output_index]; + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); + ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", + DataTypeImpl::ToString(ml_data_type)); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + const auto byte_count_per_element = elt_type->Size(); + cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " + + std::to_string(byte_count_per_element) + " * " + + std::to_string(GetSaveRatio()); + } else { + cur_output_saving_str = "0"; + } + + if (!saving_str.empty()) { + saving_str += " + "; + } + + saving_str = "(" + cur_output_saving_str + ")"; + } + + ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name()); + return "(" + saving_str + ")"; + } + private: bool compromise_recompute_; InlinedVector nodes_in_topological_order_; diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 76943b9548..853eab61b4 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -243,7 +243,7 @@ class GraphExecutionManager(GraphExecutionInterface): # requires PRIORITY_BASED order to work properly. So we use PRIORITY_BASED order when recompute is enabled. session_options.execution_order = ( onnxruntime.ExecutionOrder.PRIORITY_BASED - if self._runtime_options.memory_optimizer_config != "" + if self._runtime_options.memory_optimizer_is_enabled() else onnxruntime.ExecutionOrder.DEFAULT ) # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index a93f6413b7..bfa38efb34 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -399,3 +399,12 @@ class _RuntimeOptions: if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ: self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1 + + def memory_optimizer_is_enabled(self) -> bool: + """Check whether memory optimizer is enabled.""" + if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED: + return len(self.memory_optimizer_config) > 0 + elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + return True + + return False