Fix missing subgraph candidates for recompute (#19077)

### Fix missing subgraph candidates for recompute

For subgraphs such as `MatMul+Transpose+Reshape`, the ending node is a
Reshape, which in ORT reuses its input buffer.

Currently, the subgraph detection logic has a defect; as a result, those
subgraphs are missed as recompute candidates.

Also add a few more node types for recompute support.

TODO: add unit test later. This PR is needed for a customer model now.
This commit is contained in:
pengwa 2024-01-11 12:50:55 +08:00 committed by GitHub
parent 0a0ef958eb
commit d03e477b90
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 83 additions and 41 deletions

View file

@ -485,12 +485,15 @@ void ListAllCombinations(const InlinedVector<InlinedVector<InlinedVector<std::sh
return;
}
for (const auto& plans : all_possible_node_optimization_plans[index]) {
for (const auto& plan : plans) {
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
new_combination.push_back(plan);
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
}
const InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>&
plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index];
// For the index-th reused buffer, iterate all possible complete plans.
for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) {
const auto& plan_combination = plan_combination_list_at_cur_index[i];
InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>> new_combination = current_combination;
// Append the chosen complete plan and continue exploring the next reused buffer by index + 1.
new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end());
ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations);
}
MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations");
@ -520,17 +523,28 @@ void IterateNodeOptimizationPlan(const std::shared_ptr<NodeOptimizationPlanBase>
}
InlinedVector<InlinedVector<InlinedVector<std::shared_ptr<NodeOptimizationPlanBase>>>>
all_possible_node_optimization_plans;
all_possible_node_optimization_plans.resize(plan->reuse_buffers.size());
all_possible_node_optimization_plans(plan->reuse_buffers.size());
size_t i = 0;
for (const auto& p : plan->reuse_buffers) {
MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first));
IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]);
// If the reused node is part of the current node optimization plan, then we just add the current combination to the result.
if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) {
const auto& recompute_subgraph =
dynamic_cast<NodeRecomputePlan*>(plan.get())->GetNodesInTopoOrder();
if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) {
all_possible_node_optimization_plans[i].push_back(current_combination);
}
}
if (all_possible_node_optimization_plans[i].size() == 0) {
IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]);
}
++i;
}
ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations);
ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations);
MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId());
}

View file

@ -15,35 +15,6 @@
namespace onnxruntime::optimizer::memory_optimizer {
/**
 * Build a symbolic expression string describing the memory saved by this plan.
 *
 * One term of the form "(<dim-param string> * <bytes per element> * <save ratio>)" is
 * produced for every index in activation_output_indices_ that has no entry in
 * reuse_buffers; the terms are joined with " + ".
 *
 * @return The joined terms wrapped in an outer pair of parentheses, or an empty string
 *         when no output contributes a term.
 */
std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const {
  std::string saving_str;
  for (auto output_index : activation_output_indices_) {
    // If the output is reusing other node's buffer, then no memory saving.
    if (reuse_buffers.find(output_index) != reuse_buffers.end()) {
      continue;
    }

    const auto& output_def = node->OutputDefs()[output_index];
    MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
    ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
                DataTypeImpl::ToString(ml_data_type));
    const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
    ORT_ENFORCE(nullptr != tensor_type_base);
    MLDataType elt_type = tensor_type_base->GetElementType();
    const auto byte_count_per_element = elt_type->Size();

    if (!saving_str.empty()) {
      saving_str += " + ";
    }
    // Append (not assign): assigning here would throw away the terms and the " + "
    // separator accumulated for earlier outputs, leaving only the last output's term.
    saving_str += "(" + GetActivationOutputDimParamString(output_index) + " * " +
                  std::to_string(byte_count_per_element) + " * " +
                  std::to_string(GetSaveRatio()) + ")";
  }

  if (saving_str.empty()) {
    return saving_str;
  }
  return "(" + saving_str + ")";
}
Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer,
const OrtValueNameIdxMap& ortvalue_name_to_idx_map,
const SequentialExecutionPlan& p_seq_exec_plan) {

View file

@ -83,7 +83,7 @@ class NodeOptimizationPlanBase {
/**
* Get a symbolic string to represent the memory saving for this optimization plan.
*/
std::string GetMemorySavingSymbolicString() const;
virtual std::string GetMemorySavingSymbolicString() const = 0;
std::string GetActivationOutputDimParamString(size_t index) const {
ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(),

View file

@ -72,12 +72,14 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
{"Add", AllowedRecomputeNodeConfig{{0, 1}}},
{"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
{"Div", AllowedRecomputeNodeConfig{{0, 1}}},
{"Equal", AllowedRecomputeNodeConfig{{0, 1}}},
{"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
{"Sub", AllowedRecomputeNodeConfig{{0, 1}}},
// Data layout
/// The shape input is trivial whether it exists or not in backward.
{"Reshape", AllowedRecomputeNodeConfig{{0}}},
{"Shape", AllowedRecomputeNodeConfig{{0}}},
{"Squeeze", AllowedRecomputeNodeConfig{{0}}},
{"Transpose", AllowedRecomputeNodeConfig{{0}}},
{"Unsqueeze", AllowedRecomputeNodeConfig{{0}}},
@ -92,6 +94,7 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
{"Expand", AllowedRecomputeNodeConfig{{0}}},
{"FastGelu", AllowedRecomputeNodeConfig{{0}}},
{"Gelu", AllowedRecomputeNodeConfig{{0}}},
{"QuickGelu", AllowedRecomputeNodeConfig{{0}}},
// Ternary elementwise
{"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},

View file

@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {
std::string GetNodesInTopoOrderStr() const;
/**
 * Build a symbolic expression string describing the memory saved by this recompute plan.
 *
 * One term is produced per index returned by GetActivationOutputIndices(), and the terms
 * are joined with " + ":
 *  - "(<dim-param string> * <bytes per element> * <save ratio>)" when the output has no
 *    entry in reuse_buffers, or when the node it reuses a buffer from is contained in
 *    nodes_in_topological_order_;
 *  - "(0)" when the output reuses a buffer from a node outside nodes_in_topological_order_.
 *
 * Enforces (ORT_ENFORCE) that each contributing output is a tensor type and that at least
 * one term was produced.
 *
 * @return The joined terms wrapped in an outer pair of parentheses.
 */
std::string GetMemorySavingSymbolicString() const override {
  std::string saving_str;
  for (auto output_index : GetActivationOutputIndices()) {
    std::string cur_output_saving_str;

    const auto reuse_it = reuse_buffers.find(output_index);
    const bool is_reused = reuse_it != reuse_buffers.end();
    bool is_src_node_in_cur_node_subgraph = false;
    if (is_reused) {
      // Here we assume the src_node is the real owner of the buffer, so we don't need trace further.
      const auto* src_node = reuse_it->second.first;
      is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(),
                                                   nodes_in_topological_order_.end(),
                                                   src_node) != nodes_in_topological_order_.end();
    }

    if (!is_reused || is_src_node_in_cur_node_subgraph) {
      // When is_src_node_in_cur_node_subgraph is true, still use the output to calculate the
      // saving, because the reused buffer is the same size.
      const auto& output_def = node->OutputDefs()[output_index];
      MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
      ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
                  DataTypeImpl::ToString(ml_data_type));
      const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
      ORT_ENFORCE(nullptr != tensor_type_base);
      MLDataType elt_type = tensor_type_base->GetElementType();
      const auto byte_count_per_element = elt_type->Size();
      cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " +
                              std::to_string(byte_count_per_element) + " * " +
                              std::to_string(GetSaveRatio());
    } else {
      // The buffer is owned by a node outside this plan's subgraph, so this output saves nothing.
      cur_output_saving_str = "0";
    }

    if (!saving_str.empty()) {
      saving_str += " + ";
    }
    // Append (not assign): assigning here would throw away the terms and the " + "
    // separator accumulated for earlier outputs, leaving only the last output's term.
    saving_str += "(" + cur_output_saving_str + ")";
  }

  ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name());

  return "(" + saving_str + ")";
}
private:
bool compromise_recompute_;
InlinedVector<const Node*> nodes_in_topological_order_;

View file

@ -243,7 +243,7 @@ class GraphExecutionManager(GraphExecutionInterface):
# requires PRIORITY_BASED order to work properly. So we use PRIORITY_BASED order when recompute is enabled.
session_options.execution_order = (
onnxruntime.ExecutionOrder.PRIORITY_BASED
if self._runtime_options.memory_optimizer_config != ""
if self._runtime_options.memory_optimizer_is_enabled()
else onnxruntime.ExecutionOrder.DEFAULT
)
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.

View file

@ -399,3 +399,12 @@ class _RuntimeOptions:
if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
def memory_optimizer_is_enabled(self) -> bool:
    """Report whether the memory optimizer is turned on for these options.

    Returns:
        For the USER_SPECIFIED level, True only when ``memory_optimizer_config``
        is non-empty. For the TRANSFORMER_LAYERWISE_RECOMPUTE level, always True.
        For any other level, False.
    """
    level = self.memory_optimization_level

    if level == _MemoryOptimizationLevel.USER_SPECIFIED:
        # User-specified mode is only active when a config was actually given.
        return len(self.memory_optimizer_config) > 0

    return bool(level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE)