mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Fix missing subgraph candidates for recompute (#19077)
### Fix missing subgraph candidates for recompute For subgraphs for example `MatMul+Transpose+Reshape`, since the ending node is a Reshape, in ORT, it is reusing input buffers. Currently, the subgraph detection logic has defect, as a result, those subgraphs will be missing as recompute candidates. Also append a few more node types for recompute support. TODO: add unit test later. This PR is needed for a customer model now.
This commit is contained in:
parent
0a0ef958eb
commit
d03e477b90
7 changed files with 83 additions and 41 deletions
|
|
@ -485,12 +485,15 @@ void ListAllCombinations(const InlinedVector<InlinedVector<InlinedVector<std::sh
|
|||
return;
|
||||
}
|
||||
|
||||
for (const auto& plans : all_possible_node_optimization_plans[index]) {
|
||||
for (const auto& plan : plans) {
|
||||
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
|
||||
new_combination.push_back(plan);
|
||||
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
|
||||
}
|
||||
const InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>&
|
||||
plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index];
|
||||
// For the index-th reused buffer, iterate all possible complete plans.
|
||||
for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) {
|
||||
const auto& plan_combination = plan_combination_list_at_cur_index[i];
|
||||
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
|
||||
// Append the chosen complete plan and continue exploring the next reused buffer by index + 1.
|
||||
new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end());
|
||||
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
|
||||
}
|
||||
|
||||
MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations");
|
||||
|
|
@ -520,17 +523,28 @@ void IterateNodeOptimizationPlan(const std::shared_ptr<NodeOptimizationPlanBase>
|
|||
}
|
||||
|
||||
InlinedVector<InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>>
|
||||
all_possible_node_optimization_plans;
|
||||
all_possible_node_optimization_plans.resize(plan->reuse_buffers.size());
|
||||
all_possible_node_optimization_plans(plan->reuse_buffers.size());
|
||||
|
||||
size_t i = 0;
|
||||
for (const auto& p : plan->reuse_buffers) {
|
||||
MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first));
|
||||
IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]);
|
||||
// If the resued node is part of current node optimization plan, then we just add current combination to the result.
|
||||
if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) {
|
||||
const auto& recompute_subgraph =
|
||||
dynamic_cast<NodeRecomputePlan*>(plan.get())->GetNodesInTopoOrder();
|
||||
if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) {
|
||||
all_possible_node_optimization_plans[i].push_back(current_combination);
|
||||
}
|
||||
}
|
||||
|
||||
if (all_possible_node_optimization_plans[i].size() == 0) {
|
||||
IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]);
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations);
|
||||
ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations);
|
||||
|
||||
MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,35 +15,6 @@
|
|||
|
||||
namespace onnxruntime::optimizer::memory_optimizer {
|
||||
|
||||
std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const {
|
||||
std::string saving_str;
|
||||
for (auto output_index : activation_output_indices_) {
|
||||
// If the output is reusing other node's buffer, then no memory saving.
|
||||
if (reuse_buffers.find(output_index) != reuse_buffers.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& output_def = node->OutputDefs()[output_index];
|
||||
MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
|
||||
ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
|
||||
DataTypeImpl::ToString(ml_data_type));
|
||||
const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
|
||||
ORT_ENFORCE(nullptr != tensor_type_base);
|
||||
MLDataType elt_type = tensor_type_base->GetElementType();
|
||||
const auto byte_count_per_element = elt_type->Size();
|
||||
if (!saving_str.empty()) {
|
||||
saving_str += " + ";
|
||||
}
|
||||
saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " +
|
||||
std::to_string(byte_count_per_element) + " * " +
|
||||
std::to_string(GetSaveRatio()) + ")";
|
||||
}
|
||||
if (saving_str.empty()) {
|
||||
return saving_str;
|
||||
}
|
||||
return "(" + saving_str + ")";
|
||||
}
|
||||
|
||||
Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer,
|
||||
const OrtValueNameIdxMap& ortvalue_name_to_idx_map,
|
||||
const SequentialExecutionPlan& p_seq_exec_plan) {
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class NodeOptimizationPlanBase {
|
|||
/**
|
||||
* Get a symbolic string to represent the memory saving for this optimization plan.
|
||||
*/
|
||||
std::string GetMemorySavingSymbolicString() const;
|
||||
virtual std::string GetMemorySavingSymbolicString() const = 0;
|
||||
|
||||
std::string GetActivationOutputDimParamString(size_t index) const {
|
||||
ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(),
|
||||
|
|
|
|||
|
|
@ -72,12 +72,14 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
|
|||
{"Add", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
{"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
{"Div", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
{"Equal", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
{"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
{"Sub", AllowedRecomputeNodeConfig{{0, 1}}},
|
||||
|
||||
// Data layout
|
||||
/// The shape input is trivial whether it exists or not in backward.
|
||||
{"Reshape", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"Shape", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"Squeeze", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"Transpose", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"Unsqueeze", AllowedRecomputeNodeConfig{{0}}},
|
||||
|
|
@ -92,6 +94,7 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
|
|||
{"Expand", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"FastGelu", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"Gelu", AllowedRecomputeNodeConfig{{0}}},
|
||||
{"QuickGelu", AllowedRecomputeNodeConfig{{0}}},
|
||||
|
||||
// Ternary elementwise
|
||||
{"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},
|
||||
|
|
|
|||
|
|
@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {
|
|||
|
||||
std::string GetNodesInTopoOrderStr() const;
|
||||
|
||||
std::string GetMemorySavingSymbolicString() const override {
|
||||
std::string saving_str;
|
||||
for (auto output_index : GetActivationOutputIndices()) {
|
||||
// If the output is reusing other node's buffer, then no memory saving.
|
||||
std::string cur_output_saving_str;
|
||||
|
||||
bool is_reused = reuse_buffers.find(output_index) != reuse_buffers.end();
|
||||
bool is_src_node_in_cur_node_subgraph = false;
|
||||
if (is_reused) {
|
||||
// Here we assume the src_node is the real owner of the buffer, so we don't need trace further.
|
||||
const auto* src_node = reuse_buffers.at(output_index).first;
|
||||
is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(),
|
||||
nodes_in_topological_order_.end(),
|
||||
src_node) != nodes_in_topological_order_.end();
|
||||
}
|
||||
|
||||
if (!is_reused || is_src_node_in_cur_node_subgraph) {
|
||||
// For is_src_node_in_cur_node_subgraph is True, still use the output to calculate the saving, because
|
||||
// reusing buffer is the same size.
|
||||
const auto& output_def = node->OutputDefs()[output_index];
|
||||
MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
|
||||
ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
|
||||
DataTypeImpl::ToString(ml_data_type));
|
||||
const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
|
||||
ORT_ENFORCE(nullptr != tensor_type_base);
|
||||
MLDataType elt_type = tensor_type_base->GetElementType();
|
||||
const auto byte_count_per_element = elt_type->Size();
|
||||
cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " +
|
||||
std::to_string(byte_count_per_element) + " * " +
|
||||
std::to_string(GetSaveRatio());
|
||||
} else {
|
||||
cur_output_saving_str = "0";
|
||||
}
|
||||
|
||||
if (!saving_str.empty()) {
|
||||
saving_str += " + ";
|
||||
}
|
||||
|
||||
saving_str = "(" + cur_output_saving_str + ")";
|
||||
}
|
||||
|
||||
ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name());
|
||||
return "(" + saving_str + ")";
|
||||
}
|
||||
|
||||
private:
|
||||
bool compromise_recompute_;
|
||||
InlinedVector<const Node*> nodes_in_topological_order_;
|
||||
|
|
|
|||
|
|
@ -243,7 +243,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# requires PRIORITY_BASED order to work properly. So we use PRIORITY_BASED order when recompute is enabled.
|
||||
session_options.execution_order = (
|
||||
onnxruntime.ExecutionOrder.PRIORITY_BASED
|
||||
if self._runtime_options.memory_optimizer_config != ""
|
||||
if self._runtime_options.memory_optimizer_is_enabled()
|
||||
else onnxruntime.ExecutionOrder.DEFAULT
|
||||
)
|
||||
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
|
|
|
|||
|
|
@ -399,3 +399,12 @@ class _RuntimeOptions:
|
|||
|
||||
if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
|
||||
self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
|
||||
|
||||
def memory_optimizer_is_enabled(self) -> bool:
|
||||
"""Check whether memory optimizer is enabled."""
|
||||
if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED:
|
||||
return len(self.memory_optimizer_config) > 0
|
||||
elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue