Trade subgraph recompute for memory (#12852)

**Description**: Subgraph-level recompute

This PR adds an optional capability trading additional re-computation
for better memory efficiency. Specifically, a pre-defined operator list
used to iterate the Graph to find some subgraphs for recompute, to
reduce some stashed activations whose lifetime across forward and
backward pass.

When training with ORTModule, by default, the graph transformer will
scan the execution graph to find all eligible subgraph to recompute,
along with sizes that can save. An example looks like below.
If we want to enable some of them to recompute, we can define env
variable this way:
`export
ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"`
```

[1,0]<stderr>:2,022-10-12 14:47:39.302,954,530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary]
[1,0]<stderr>:MemoryAlleviation Summary:
[1,0]<stderr>:  User config:
[1,0]<stderr>:  Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1
[1,0]<stderr>:  =================================
[1,0]<stderr>:  Subgraph: BitmaskDropout+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 1,024 x   Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: BiasGelu+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x  Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Reshape[1,0]<stderr>:+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:labels_dim0 x      Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x    Frequency:23
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x    Frequency:1
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+Add+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x         Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+
[1,0]<stderr>:          AlleviationType: Disabled
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x     Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Mul+Sub+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x         Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: Cast+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:1,024 x 1,024 x    Frequency:97
[1,0]<stderr>:                  PatternShape:3 x 1,024 x        Frequency:1
[1,0]<stderr>:                  PatternShape:8 x 64 x   Frequency:24
[1,0]<stderr>:                  PatternShape:1,024 x 4,096 x    Frequency:24
[1,0]<stderr>:                  PatternShape:4,096 x    Frequency:24
[1,0]<stderr>:                  PatternShape:4,096 x 1,024 x    Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  Subgraph: FusedMatMul+
[1,0]<stderr>:          AlleviationType: Recompute
[1,0]<stderr>:          Patterns:
[1,0]<stderr>:                  PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x  Frequency:24
[1,0]<stderr>:  --------------------------------
[1,0]<stderr>:  =================================
```


"Type config:" whether recompute is enabled by users. 0 - disable, 1-
enable.
"Subgraph" means what kind of subgraph will be recomputed, in this case,
it is a single node "Gelu", and it will be "Recompute".
"Shape && Frequency" means, for this recompute, one tensor of size
(batch size, 500) will be saved because it will be recomputed.

**Baseline**

On a 1P model (DEBERTA V2), sequence length 256, training with 16 A100
GPUs. With latest main branch, we can run batch size 16, and the maximum
batch size < 32. So 16 is usually chosen by data scientists. 65% of 40GB
memory is used during training. The SamplesPerSec=479.2543353561354.


![image](https://user-images.githubusercontent.com/10530022/188320941-13dde5e7-c32b-4399-a64b-6803fbb9dcda.png)

**With this PR**

Gelu is recomputed for saving memory peak, batch size 32 can be run. The
97% of 40GB A100 is used, the SamplesPerSec=562.041593991271 (**1.17X**
of baseline).


![image](https://user-images.githubusercontent.com/10530022/188321081-f64811bf-9637-4873-8095-349de8d498cc.png)


**Motivation and Context**
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
This commit is contained in:
pengwa 2022-11-03 13:49:41 +08:00 committed by GitHub
parent 77be22f379
commit a3e7da60e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1485 additions and 5 deletions

90
docs/Memory_Optimizer.md Normal file
View file

@ -0,0 +1,90 @@
# Memory Optimizer for ONNX Runtime Training
## Introduction
ONNX Runtime Training provides a capability trading node/subgraph recomputations for better memory efficiency.
Specifically, a list of recomputable operators is pre-defined, with which memory optimizer graph transformer will iterate the graph to find all recomputable subgraph candidates.
When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraphs to recompute, along with sizes that can save. Users can pick up some of the subgraphs to enable them by environment variables.
## When memory optimizer can help?
Classical scenarios include:
- ORTModule run a model with batch size B (for example 2^N), the memory bandwidth and compute are not fully saturated, while it hits OOM to run a bigger batch size (for example 2^(N+1)).
- For big models, ORTModule fails to run the minimum allowed batch size, so performance can be compromised for a successful run.
Not all models and recipes need this optimizer technique. Imagine if your training recipe is using a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6.
## Quick trial
1. Make sure ONNX Runtime training wheel is installed and correctly configured.
2. Integrate models using ORTModule, be noted log_level should be equal or lower than INFO.
> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
3. Run the training as usual and redirect all outputs into log file; then stop it after training few steps.
4. Check the logging file, search "Summary", you could possibly find something like this:
```
MemoryOptimizer Summary:
User config:
=================================
########Recompute########
Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
--------------------------------
Subgraph: FastGelu+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
=================================
########RecomputeWithCompromise########
Subgraph: Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
--------------------------------
=================================
```
5. As shown above, 'Subgraph' shows 1) a string representative for a recomputable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case.
6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In this sample, 12 FastGelu related subgraphs are allowed to recompute.
`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
```
export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12"
```
7. Then run the training again, you will see logs like this:
```
MemoryOptimizer Summary:
User config:
**FastGelu+:1:12**
=================================
########Recompute########
Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23
--------------------------------
Subgraph: FastGelu+
OptimizationType: **Recompute (requested_count=12, actual applied_count=12)**
Patterns:
PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24
=================================
########RecomputeWithCompromise########
Subgraph: Cast+Where+Softmax+
OptimizationType: Disabled
Patterns:
PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24
--------------------------------
=================================
```
8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.
## Compromised Recompute
If you check the above logs, there is a separate section called "RecomputeWithCompromise". Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it.
## Notes
The feature is in experimental stage, we will tune and refine it according to real use cases.

View file

@ -61,6 +61,22 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
// <subgraph string : optimization strategy : number of subgraph to apply>.
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
// the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer";
// Specifies the level for detecting subgraphs for memory footprint reduction.
// The value should be an integer. The default value is 0.
static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level";
#endif
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
// Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
@ -81,9 +97,9 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses
/// <summary>
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
/// Requires `session.use_ort_model_bytes_directly` to be true.
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
/// Requires `session.use_ort_model_bytes_directly` to be true.
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
/// duration of the InferenceSession.
/// </summary>
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =

View file

@ -112,7 +112,7 @@ void MemoryInfo::RecordActivationAllocInfo(const OrtValueIndex idx, const OrtVal
else if (map[MapType::Initializer].Contain(reuse_buffer))
map_type = MapType::Initializer;
else
std::cout << "Find no map type for reuse_buffer: " << reuse_buffer << ", so skipping" << std::endl;
LOGS_DEFAULT(VERBOSE) << "Find no map type for reuse_buffer: " << reuse_buffer << ", so skipping";
RecordTensorDeviceAllocInfo(idx, value, map_type);
}
@ -365,7 +365,7 @@ void MemoryProfiler::CreateEvents(const std::string& p_name,
void MemoryProfiler::GenerateMemoryProfile() {
// Write memory profile .json
std::stringstream ss;
ss << "memory_profile_" << GetMemoryInfo().GetLocalRank() << "_" << profiler_id_ << ".json";
ss << "memory_profile_" << GetMemoryInfo().GetLocalRank() << "_" << Env::Default().GetSelfPid() << "_" << profiler_id_ << ".json";
std::ofstream memory_profile(ss.str(), std::ios::trunc);
memory_profile << "[" << std::endl;
for (size_t i = 0; i < GetEvents().size(); i++) {

View file

@ -68,6 +68,7 @@
#ifdef ENABLE_TRAINING
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
#include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h"
#endif
@ -297,6 +298,19 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
// The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their
// fusions might be prevented if this one removes a Q/DQ node too early.
transformers.emplace_back(std::make_unique<QDQFinalCleanupTransformer>(enable_quant_qdq_cleanup));
#ifdef ENABLE_TRAINING
// Put memory optimization transformer at last (which is done after most of fusions are done) by intention.
// Known issue: after mmeory optimization is completed, if some fusion happens, it is possible that the
// node priority got changed. This may disorder the execution order of nodes to recompute.
// TODO(pengwa): need to fix this issue.
const std::string enable_memory_optimizer =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, "");
const std::string probe_level =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0");
transformers.emplace_back(std::make_unique<MemoryOptimizer>(enable_memory_optimizer, probe_level));
#endif
} break;
case TransformerLevel::Level3: {
@ -315,6 +329,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
// while we can fuse more activation.
transformers.emplace_back(std::make_unique<ConvAddActivationFusion>(cpu_ep));
#endif
} break;
default:

Binary file not shown.

View file

@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""This file is used to generate test data for MemoryOptimizer tests in
onnxruntime/test/optimizer/memory_optimizer_test.cc.
Be noticed, after run this script, manually rename recompute_XXXX_execution_model_training.onnx to
recompute_XXXX.onnx
"""
import torch
from onnxruntime.training.ortmodule import DebugOptions, ORTModule
class LinearGeluLinearTest(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
def forward(self, input1):
out = self.fc1(input1)
out = torch.nn.functional.gelu(out)
out = self.fc2(out)
return out
DEVICE = "cuda"
def generate_gelu_test_case():
batch_size, dimension_in, hidden_size, dimension_out = 64, 784, 500, 10
model = LinearGeluLinearTest(dimension_in, hidden_size, dimension_out).to(DEVICE)
ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="recompute_gelu"))
input = torch.randn(batch_size, dimension_in, device=DEVICE)
# Make sure model runs without any exception
prediction = ort_model(input)
assert prediction is not None
prediction = prediction.sum()
prediction.backward()
class TileTransposeLinearTest(torch.nn.Module):
def __init__(self, head):
super().__init__()
self._head = head
# input1 - float16[24,512,64]
# repeat - float16[4]
# query_layer - float16[24*labels_dim0,512,64]
def forward(self, input1, query_layer):
# Tile to [24*labels_dim0,512,64]
output = input1.repeat(query_layer.size(0) // self._head, 1, 1)
# Transpose to [24*labels_dim0,64,512]
output = output.permute(0, 2, 1).contiguous()
return torch.matmul(query_layer, output)
def generate_tile_test_case():
batch_size = 16
head = 24
seq_length = 512
model = TileTransposeLinearTest(head).to(DEVICE)
model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="recompute_tile"))
input1 = torch.randn(head, seq_length, 64, device=DEVICE).requires_grad_(True)
query_layer = torch.randn(batch_size * head, seq_length, 64, device=DEVICE).requires_grad_(True)
# Make sure model runs without any exception
prediction = model(input1, query_layer)
assert prediction is not None
prediction = prediction.sum()
prediction.backward()
def main():
"""Main entry."""
generate_gelu_test_case()
generate_tile_test_case()
if __name__ == "__main__":
main()

Binary file not shown.

View file

@ -0,0 +1,785 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/random_seed.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
namespace onnxruntime {
namespace {
constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15;
std::string TensorShapeProtoToString(const ONNX_NAMESPACE::TensorShapeProto* shape) {
std::ostringstream shape_oss;
if (shape != nullptr) {
for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) {
auto dim = shape->dim(dim_index);
if (utils::HasDimValue(dim)) {
shape_oss << dim.dim_value() << " x ";
} else {
shape_oss << dim.dim_param() << " x ";
}
}
} else {
shape_oss << "unknown";
}
return shape_oss.str();
}
int ParseIntValueFromString(std::string_view str) {
int int_value = 0;
auto result = std::from_chars(str.data(), str.data() + str.size(), int_value);
ORT_ENFORCE(result.ec != std::errc::invalid_argument, "Fail to convert to int from string: ", str);
return int_value;
}
bool IsForwardPassOperator(int64_t op_order_in_topological_sort, int64_t boundary_op_order_in_topological_sort) {
return op_order_in_topological_sort <= boundary_op_order_in_topological_sort;
}
static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) {
const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type);
MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto);
const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
ORT_ENFORCE(nullptr != tensor_type_base);
MLDataType elt_type = tensor_type_base->GetElementType();
return elt_type->Size();
}
// TODO(pengwa): extend this function to be more general.
float InputOutputSizeRatio(const Node* node) {
if (node->OpType().compare("Cast") == 0) {
const NodeArg* input = node->InputDefs()[0];
const NodeArg* output = node->OutputDefs()[0];
if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING ||
output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
return 1.0f;
}
const auto& ptype1 = input->Type();
const auto& ptype2 = output->Type();
float ratio = float(GetElementSize(ptype1)) / (float)GetElementSize(ptype2);
return ratio;
}
return 1.0f;
}
} // namespace
Status MemoryOptimizer::ParseConfigFromString(const std::string& enable_memory_optimizer,
const std::string& level) {
optimizer_config_ = enable_memory_optimizer;
if (!enable_memory_optimizer.empty()) {
const auto user_config_strs = utils::SplitString(enable_memory_optimizer, ",");
for (const auto& user_config_str : user_config_strs) {
const auto user_config = utils::SplitString(user_config_str, ":");
ORT_RETURN_IF_NOT(user_config.size() == 3,
"User config should be in format of SubgraphStr:OptimizationType:RequestApplyCount.");
const std::string subgraph_string_representation(user_config[0]);
int optimization_type_int = ParseIntValueFromString(user_config[1]);
int requested_apply_count = ParseIntValueFromString(user_config[2]);
ORT_RETURN_IF_NOT(optimization_type_int < static_cast<int>(OptimizationType::TypeMax) &&
optimization_type_int >= 0,
"Invalid optimization type specified for subgraph: ",
subgraph_string_representation);
ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0,
"Invalid requested_apply_count specified for subgraph: ", requested_apply_count);
// At this point, subgraph_string_representation is a pattern graph string representation.
pattern_subgraph_to_user_optimizer_config_map_[subgraph_string_representation] =
UserConfig{static_cast<OptimizationType>(optimization_type_int), requested_apply_count};
}
}
int probe_level = ParseIntValueFromString(level);
ORT_RETURN_IF_NOT(probe_level < static_cast<int>(ProbeLevel::LevelMax) && probe_level >= 0,
"Invalid probe level specified: ", level);
recompute_probe_level_ = static_cast<ProbeLevel>(probe_level);
return Status::OK();
}
int64_t MemoryOptimizer::PrepareForTransformation(const Graph& graph,
ActivationUsedMap& fw_op_output_arg_used_map,
InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map) const {
fw_op_output_arg_used_map.clear();
GraphViewer graph_viewer(graph);
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
// Find boundary ops between forward and backward pass, currently, it's limited to YieldOp.
int64_t yield_op_order_in_topological_sort = -1;
for (size_t i = 0; i < node_ids.size(); ++i) {
const Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr) { /* skip removed nodes*/
continue;
}
if (p_node->OpType() == "YieldOp") {
yield_op_order_in_topological_sort = static_cast<int64_t>(i);
}
node_index_to_its_order_in_topological_sort_map[p_node->Index()] = i;
}
// If boundary op found, create forward op output arg used map.
if (yield_op_order_in_topological_sort >= 0) {
for (size_t i = 0; i < node_ids.size(); ++i) {
const Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr /* skip removed nodes*/) {
continue;
}
const Node& node = *p_node;
bool is_forward_op = IsForwardPassOperator(static_cast<int64_t>(i), yield_op_order_in_topological_sort);
if (!is_forward_op) {
continue;
}
for (auto& output_arg : node.OutputDefs()) {
bool used_in_fw = false;
bool used_in_bw = false;
for (auto& consumer_node : graph.GetConsumerNodes(output_arg->Name())) {
auto consumer_node_index_in_topological_order =
node_index_to_its_order_in_topological_sort_map.at(consumer_node->Index());
if (IsForwardPassOperator(static_cast<int64_t>(consumer_node_index_in_topological_order),
yield_op_order_in_topological_sort)) {
used_in_fw = true;
} else {
used_in_bw = true;
}
}
fw_op_output_arg_used_map.insert({{output_arg->Name(), std::make_pair(used_in_fw, used_in_bw)}});
}
}
}
// Return whether boundary op is found or not.
return yield_op_order_in_topological_sort;
}
Status MemoryOptimizer::GetStashedActivationCandidates(const Graph& graph,
const InlinedHashMap<std::string, std::pair<bool, bool>>&
fw_op_output_arg_used_map,
InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const logging::Logger& logger) const {
for (auto& kv : fw_op_output_arg_used_map) {
// used by fw and bw, then it is a candidates.
if (kv.second.first && kv.second.second) {
const Node* n = graph.GetProducerNode(kv.first);
ORT_ENFORCE(n, "Activation should have a producer node");
size_t k = 0;
for (k = 0; k < n->OutputDefs().size(); ++k) {
if (n->OutputDefs()[k]->Name().compare(kv.first) == 0) {
break;
}
}
candidate_output_args_map[n].push_back(k);
LOGS(logger, VERBOSE) << "Find candidate output named [" << kv.first << "] of Node " << n->Name() << "("
<< n->OpType() << ")";
}
}
return Status::OK();
}
bool MemoryOptimizer::ModifyGraph(Graph& graph,
const InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
const logging::Logger& logger,
int64_t boundary_op_order_in_topological_sort,
SubGraphStores& subgraph_stores,
Node* node) const {
bool graph_is_modified = false;
if (subgraph_stores.SubGraphDescCount() == 0) {
return graph_is_modified;
}
SubGraphStores::GraphInstanceInfo& sub_graph_instance_info =
subgraph_stores.GetSubGraphInstance(node);
SubGraphDesc& subgraph_desc = subgraph_stores.GetSubGraphDesc(sub_graph_instance_info.second);
UserConfig user_config = subgraph_desc.user_optimizer_config;
int skip_count = (user_config.requested_count == -1)
? 0
: std::max(0, subgraph_desc.total_frequency - user_config.requested_count);
subgraph_desc.skip_count += 1;
if (user_config.type != OptimizationType::None && subgraph_desc.skip_count > skip_count) {
subgraph_desc.applied_count += 1;
Node* replacement_node_ptr = nullptr;
LOGS(logger, WARNING) << "[Modify Graph] Node " << node->Name() << "(" << node->OpType() << ") is "
<< UserConfigToString(user_config);
if (user_config.type == OptimizationType::Recompute) {
ORT_ENFORCE(CreateRecomputeGraph(graph, sub_graph_instance_info.first, replacement_node_ptr).IsOK());
} else {
ORT_THROW("unsupported optimization type found: " + UserConfigToString(user_config));
}
ORT_ENFORCE(replacement_node_ptr);
graph_is_modified = true;
for (size_t output_index : candidate_output_args_map.at(node)) {
// Collect output edges (connecting to backward ops), to remove.
std::vector<graph_utils::GraphEdge> output_edges;
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
size_t src_output_idx = static_cast<size_t>(it->GetSrcArgIndex());
if (src_output_idx != output_index) {
continue;
}
auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index());
// It is possible the consumer node is newly added as the recompute node, so we need a check here.
// For those kind of ops, we can treat them as backward ops.
if (tid == node_index_to_its_order_in_topological_sort_map.end() ||
!IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first),
boundary_op_order_in_topological_sort)) {
// Remove the edge only connecting to backward op.
output_edges.push_back(graph_utils::GraphEdge::CreateGraphEdge(*node, *it, false));
}
}
if (!output_edges.empty()) {
// Remove the output edges of the node first
graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges);
// Create connections between the replacement node and the outgoing nodes.
for (const auto& output_edge : output_edges) {
graph.RemoveConsumerNode(node->MutableOutputDefs()[output_index]->Name(), node);
// Add new edge connecting the input with the output nodes directly.
// This also updates the destination node's input node args
graph.AddEdge(replacement_node_ptr->Index(), output_edge.dst_node, static_cast<int>(output_index),
output_edge.dst_arg_index);
}
}
}
}
return graph_is_modified;
}
Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger)
const {
LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: "
<< static_cast<int>(recompute_probe_level_);
InlinedHashMap<std::string, std::pair<bool, bool>> fw_op_output_arg_used_map;
InlinedHashMap<NodeIndex, size_t> node_index_to_its_order_in_topological_sort_map;
int64_t boundary_op_order_in_topological_sort =
PrepareForTransformation(graph, fw_op_output_arg_used_map,
node_index_to_its_order_in_topological_sort_map);
if (boundary_op_order_in_topological_sort < 0) {
LOGS(logger, VERBOSE) << "No boundary op found. Skip memory optimization.";
return Status::OK();
}
InlinedHashMap<const Node*, InlinedVector<size_t>> candidate_output_args_map;
ORT_RETURN_IF_ERROR(GetStashedActivationCandidates(graph, fw_op_output_arg_used_map, candidate_output_args_map,
logger));
SubGraphStores recompute_subgraph_stores;
SubGraphStores recompute_with_compromise_subgraph_stores;
GraphViewer graph_viewer(graph);
const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder();
// The first pass - find the candidate subgraphs.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr) {
continue;
}
if (candidate_output_args_map.find(p_node) == candidate_output_args_map.end()) {
continue;
}
bool can_compromise_stashed_activation = false;
CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map,
node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map,
recompute_subgraph_stores, logger, false,
can_compromise_stashed_activation);
if (can_compromise_stashed_activation) {
LOGS(logger, VERBOSE) << "Searching Node " << p_node->Name() << "(" << p_node->OpType()
<< ") for compromised recompute";
// If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist
// during backward pass, then we can try to compromise the assumption.
CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map,
recompute_with_compromise_subgraph_stores, logger, true,
can_compromise_stashed_activation);
}
}
// The second pass - apply the transformation.
// Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated.
// The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended
// earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier
// layers.
for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
Node* p_node = graph.GetNode(node_ids[i]);
if (p_node == nullptr) {
continue;
}
bool has_been_modified = false;
if (recompute_subgraph_stores.ContainsSubGraphInstance(p_node)) {
has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map, logger,
boundary_op_order_in_topological_sort,
recompute_subgraph_stores, p_node);
}
// If there are other recompute plan for this node, we skip them because the graph is already modified.
if (!has_been_modified && recompute_with_compromise_subgraph_stores.ContainsSubGraphInstance(p_node)) {
has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map,
candidate_output_args_map, logger,
boundary_op_order_in_topological_sort,
recompute_with_compromise_subgraph_stores, p_node);
}
modified = modified || has_been_modified;
}
PrintSummary(recompute_subgraph_stores, recompute_with_compromise_subgraph_stores, logger);
return Status::OK();
}
void MemoryOptimizer::NodesInTopoOrderToString(const InlinedVector<const Node*>& nodes_in_topological_order,
std::string& subgraph_string_representation,
std::string& log_info) const {
std::ostringstream oss;
std::ostringstream subgraph_string_representation_oss;
size_t node_count = nodes_in_topological_order.size();
for (size_t i = 0; i < node_count; ++i) {
if (i < node_count - 1) { // Ignore the last node.
oss << "(name:" << nodes_in_topological_order[i]->Name() << ", type:" << nodes_in_topological_order[i]->OpType()
<< "),";
}
subgraph_string_representation_oss << nodes_in_topological_order[i]->OpType() << "+";
}
subgraph_string_representation = subgraph_string_representation_oss.str();
log_info = oss.str();
if (log_info.size() > 0) {
log_info = " with its precedent nodes: " + log_info;
}
}
std::string MemoryOptimizer::UserConfigToString(const UserConfig& config) const {
std::string type_str;
switch (config.type) {
case OptimizationType::None: {
type_str = "Disabled";
} break;
case OptimizationType::Recompute: {
type_str = "Recomputed";
} break;
default: {
type_str = "Unknown";
} break;
}
return type_str;
}
void MemoryOptimizer::PrintSummary(const SubGraphStores& recompute_stores,
const SubGraphStores& recompute_with_compromise_stores,
const logging::Logger& logger) const {
if (recompute_stores.SubGraphDescCount() == 0 && recompute_with_compromise_stores.SubGraphDescCount() == 0) {
return;
}
std::ostringstream summary;
summary << "\nMemoryOptimizer Summary:\n";
summary << "\tUser config:\n\t" << optimizer_config_ << "\n";
summary << "\t=================================\n";
auto print_info_from_stores = [&summary, this](std::string store_name, const SubGraphStores& stores) {
summary << "\t########" << store_name << "########\n";
for (auto subgraph_it = stores.subgraph_descs.begin(); subgraph_it != stores.subgraph_descs.end();
++subgraph_it) {
std::string freq_info;
if (subgraph_it->second.user_optimizer_config.type != OptimizationType::None)
freq_info = " (requested_count=" + std::to_string(subgraph_it->second.user_optimizer_config.requested_count) +
", actual applied_count=" +
std::to_string(subgraph_it->second.applied_count) + ")";
summary << "\tSubgraph: " << subgraph_it->first << "\n"
<< "\t\tOptimizationType: "
<< UserConfigToString(subgraph_it->second.user_optimizer_config) << freq_info << "\n"
<< "\t\tPatterns: \n";
for (auto shape_stat_it = subgraph_it->second.shape_str_frequency.begin();
shape_stat_it != subgraph_it->second.shape_str_frequency.end();
++shape_stat_it) {
summary << "\t\t\tPatternShape:" << shape_stat_it->first << "\tFrequency:" << shape_stat_it->second << "\n";
}
summary << "\t--------------------------------\n";
}
summary << "\t=================================\n";
};
print_info_from_stores("Recompute", recompute_stores);
print_info_from_stores("RecomputeWithCompromise", recompute_with_compromise_stores);
LOGS(logger, INFO) << summary.str() << "\n";
}
/******************************************************
** Recompute related function implementation starts **
******************************************************/
void MemoryOptimizer::RegisterAllowedRecomputeOps() {
if (static_cast<int>(recompute_probe_level_) >= static_cast<int>(ProbeLevel::Basic)) {
recomputable_op_type_to_input_arg_index_map_.insert({
// Binary elementwise
{"Add", AllowedRecomputeNodeConfig{{0, 1}}},
{"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
{"Div", AllowedRecomputeNodeConfig{{0, 1}}},
{"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
{"Sub", AllowedRecomputeNodeConfig{{0, 1}}},
// Data layout
/// The shape input is trivial whether it exists or not in backward.
{"Reshape", AllowedRecomputeNodeConfig{{0}}},
{"Squeeze", AllowedRecomputeNodeConfig{{0}}},
{"Unsqueeze", AllowedRecomputeNodeConfig{{0}}},
// Unary elementwise
/// The ratio and mode input are trivial whether they exist or not in backward
{"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}},
/// The axis input is trivial whether it exists or not in backward
{"CumSum", AllowedRecomputeNodeConfig{{0}}},
{"Dropout", AllowedRecomputeNodeConfig{{0}}},
{"Gelu", AllowedRecomputeNodeConfig{{0}}},
{"FastGelu", AllowedRecomputeNodeConfig{{0}}},
// Ternary elementwise
{"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},
// Data copy
{"Tile", AllowedRecomputeNodeConfig{{0}}},
{"Cast", AllowedRecomputeNodeConfig{{0}}},
});
}
if (static_cast<int>(recompute_probe_level_) >= static_cast<int>(ProbeLevel::Advanced)) {
recomputable_op_type_to_input_arg_index_map_.insert({
{"MatMul", AllowedRecomputeNodeConfig{{0, 1}}},
{"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}},
{"Softmax", AllowedRecomputeNodeConfig{{0}}},
{"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}},
{"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}},
});
}
}
Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node,
const InlinedVector<size_t>& node_output_index_candidates,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map,
InlinedVector<const Node*>& nodes,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) const {
can_compromise_stashed_activation = false;
LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << node.Name() << "(" << node.OpType() << ")";
nodes.clear();
std::deque<NodeOutputPort> q;
for (auto output_index : node_output_index_candidates) {
q.push_back(NodeOutputPort(&node, static_cast<int>(output_index)));
}
bool early_stop = false;
std::set<NodeOutputPort> visited_output_arg_set;
std::set<const Node*> visited_node_set;
// For the initial activations in queue, they are stashed ones, so we do differently when scan the queue for them.
bool is_first_queue_scan = true;
while (nodes.size() < MAXIMUM_RECOMPUTE_NODE_COUNT && !q.empty() && !early_stop) {
// Loop all candidate NodeOutputPort, and find the next layer of input nodes.
size_t current_queue_size = q.size();
for (size_t i = 0; i < current_queue_size; ++i) {
NodeOutputPort p = q.front();
q.pop_front();
const Node* curr_node = p.first;
// Skip if the node output is already visited.
if (std::find(visited_output_arg_set.begin(), visited_output_arg_set.end(), p) !=
visited_output_arg_set.end()) {
continue;
}
visited_output_arg_set.insert({p});
// If the node already visited by from it's other output index, skip it.
if (visited_node_set.find(curr_node) != visited_node_set.end()) {
continue;
}
visited_node_set.insert(curr_node);
// Bottom-up search rules.
// If current op is entry output node (that generates stashed activations):
// 1. If the op is not in recomputable_op_type_to_input_arg_index_map_, skip it.
// Otherwise:
// If current op is in allowed list, check its input args, and append the producers' NodeOutputPorts to next_q.
// If current op is NOT in allowed list:
// 1). the output does not exist in backward, we cannot find a good solution for so, search terminates.
// 2). the output is used in backward, we don't need trace back further, continue searching.
auto op_recompute_config_it = recomputable_op_type_to_input_arg_index_map_.find(curr_node->OpType());
auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name();
if (is_first_queue_scan) {
// We handle the entry node outputs differently because, we don't want this case falls into and succeed one of
// the checks in the other branch
// 1. "op is not in recompute op list, but its output is used in backward"
// 2. "op is in recompute op list, but its output is used in backward"
// (either of the above checks is true for entry node outputs)
if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) {
early_stop = true;
LOGS(logger, VERBOSE) << "Entry Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** "
<< "in recompute op list, search terminates.";
break;
}
} else {
if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) {
if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in "
<< "recompute op list, but its output [" << cur_output_arg_name
<< "] is used in backward, we don't need trace bottom-up further";
continue;
} else {
early_stop = true;
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in "
<< "recompute op list, and its output [" << cur_output_arg_name
<< "] does not exist in backward, search terminates.";
break;
}
}
if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") "
<< "is in recompute op list, while its output [" << cur_output_arg_name
<< "] is used in backward, we don't need trace bottom-up further";
continue;
}
}
// Append node to the selected graph.
if (std::find(nodes.begin(), nodes.end(), curr_node) == nodes.end()) {
nodes.push_back(curr_node);
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType()
<< ") is added in selected subgraph ";
}
// This check is not matured now, subject to be changed.
float ratio = InputOutputSizeRatio(curr_node);
float is_current_node_compromisable = (ratio < 1.f);
can_compromise_stashed_activation = can_compromise_stashed_activation || is_current_node_compromisable;
if (is_current_node_compromisable) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType()
<< ") has input/output size " << ratio << " < 1.f, can compromise stashed activation";
}
if (is_current_node_compromisable && compromise_stashed_activation) {
LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is in "
<< "recompute op list, and its output [" << cur_output_arg_name
<< "] does not exist in backward, while it meet compromised check, we don't need trace "
<< "bottom-up further.";
continue;
}
// Iterate all input nodes according to allowed input arg index of the entry node.
const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices;
for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) {
const Node::EdgeEnd& input_edge = *it;
const auto& parent_node = input_edge.GetNode();
const auto parent_node_output_index = input_edge.GetSrcArgIndex();
const auto current_node_input_index = input_edge.GetDstArgIndex();
if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) !=
input_arg_indices.end()) {
NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index);
LOGS(logger, VERBOSE) << "Node " << parent_node.Name() << "(" << parent_node.OpType() << ")'s "
<< parent_node_output_index
<< "th output [" << parent_node.OutputDefs()[parent_node_output_index]->Name()
<< "] is added in recompute search list ";
q.push_back(next_p);
}
}
}
// After handle all entry node outputs, we set the flag to false.
is_first_queue_scan = false;
}
// If input args are not found in bw, but op count exceed MAXIMUM_RECOMPUTE_NODE_COUNT, skip recompute.
if (!q.empty() || early_stop) {
LOGS(logger, VERBOSE) << "Fail to find a solution for recompute: current node count is " << nodes.size()
<< ", queue size: " << q.size() << ", early stop: " << early_stop;
nodes.clear();
} else {
// Re-order the nodes in topological order.
std::sort(nodes.begin(), nodes.end(),
[&node_index_to_its_order_in_topological_sort_map](const Node*& lhs, const Node*& rhs) {
return node_index_to_its_order_in_topological_sort_map.at(lhs->Index()) <
node_index_to_its_order_in_topological_sort_map.at(rhs->Index());
});
}
return Status::OK();
}
void MemoryOptimizer::CheckNodeForRecompute(const Node& node,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
SubGraphStores& subgraph_stores,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) const {
if (recomputable_op_type_to_input_arg_index_map_.find(node.OpType()) ==
recomputable_op_type_to_input_arg_index_map_.end()) {
return;
}
InlinedVector<const Node*> nodes_in_topological_order;
ORT_ENFORCE(SelectRecomputeSubgraph(node, candidate_output_args_map.at(&node),
fw_op_output_arg_used_map,
node_index_to_its_order_in_topological_sort_map,
nodes_in_topological_order, logger,
compromise_stashed_activation,
can_compromise_stashed_activation)
.IsOK());
if (nodes_in_topological_order.size() == 0) {
return;
}
std::string subgraph_str_representation, log_info;
NodesInTopoOrderToString(nodes_in_topological_order, subgraph_str_representation, log_info);
LOGS(logger, VERBOSE) << "Node " << node.Name() << "(" << node.OpType() << ") can be recomputed" << log_info;
// Update the subgraph optimization config map - key is the subgraph string representation, value is user config.
UserConfig user_config{OptimizationType::None, 0};
if (pattern_subgraph_to_user_optimizer_config_map_.find(subgraph_str_representation) !=
pattern_subgraph_to_user_optimizer_config_map_.end()) {
user_config = pattern_subgraph_to_user_optimizer_config_map_.at(subgraph_str_representation);
}
SubGraphDesc& subgraph_desc =
subgraph_stores.Contains(subgraph_str_representation)
? subgraph_stores.GetSubGraphDesc(subgraph_str_representation)
: subgraph_stores.CreateSubGraphDesc(subgraph_str_representation, user_config);
subgraph_desc.total_frequency += 1;
// Update the subgraph frequency map - key is the subgraph string representation, value is number of appearances.
for (size_t output_index : candidate_output_args_map.at(&node)) {
auto shape_str = TensorShapeProtoToString(node.OutputDefs()[output_index]->Shape());
subgraph_desc.shape_str_frequency[shape_str]++;
}
subgraph_stores.AddSubGraphInstance(&node, nodes_in_topological_order, subgraph_desc);
return;
}
Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph,
const InlinedVector<const Node*>& nodes_in_topological_order,
Node*& new_output_node_ptr) const {
InlinedHashMap<NodeArg*, NodeArg*> self_contained_outputs_map;
for (size_t i = 0; i < nodes_in_topological_order.size(); ++i) {
Node* node_to_duplicate = graph.GetNode(nodes_in_topological_order[i]->Index());
// Check whether the node has been recomputed/offloaded or not. Simply check the existence of the first output
// of the node has its corresponding recompute name or not.
// TODO: if there is more optimization types like offload added, we will add corresponding check whether the outputs
// already be offloaded or not.
if (graph.GetNodeArg(graph_utils::RecomputeName(node_to_duplicate->MutableOutputDefs()[0]->Name())) != nullptr) {
continue;
}
InlinedVector<NodeArg*> new_input_args;
new_input_args.reserve(node_to_duplicate->MutableInputDefs().size());
for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) {
if (self_contained_outputs_map.find(input_arg) == self_contained_outputs_map.end()) {
NodeArg* recompute_input_arg = graph.GetNodeArg(graph_utils::RecomputeName(input_arg->Name()));
new_input_args.push_back(recompute_input_arg ? recompute_input_arg : input_arg);
} else {
new_input_args.push_back(self_contained_outputs_map[input_arg]);
}
}
InlinedVector<NodeArg*> new_output_args;
new_output_args.reserve(node_to_duplicate->MutableOutputDefs().size());
for (size_t k = 0; k < node_to_duplicate->MutableOutputDefs().size(); ++k) {
const auto& output = node_to_duplicate->MutableOutputDefs()[k];
new_output_args.push_back(&graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto()));
self_contained_outputs_map[output] = new_output_args.back();
}
Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + "_recompute",
node_to_duplicate->OpType(),
"Recompute of " + node_to_duplicate->Name(),
new_input_args,
new_output_args,
&node_to_duplicate->GetAttributes(),
node_to_duplicate->Domain());
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
recompute_node.SetExecutionProviderType(node_to_duplicate->GetExecutionProviderType());
ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(recompute_node),
"Failed to set op schema for added recompute node.");
new_output_node_ptr = &recompute_node;
for (size_t j = 0; j < recompute_node.MutableOutputDefs().size(); ++j) {
graph.UpdateProducerNode(recompute_node.MutableOutputDefs()[j]->Name(), recompute_node.Index());
}
// Add the edges from the recompute node to the original node.
for (size_t j = 0; j < recompute_node.MutableInputDefs().size(); ++j) {
NodeArg* input_arg = recompute_node.MutableInputDefs()[j];
const Node* producer_node = graph.GetProducerNode(input_arg->Name());
if (producer_node == nullptr) {
// Skip when it is graph input or initializer.
continue;
}
int producer_output_index = optimizer_utils::IndexOfNodeOutput(*producer_node, *input_arg);
graph.AddEdge(producer_node->Index(), recompute_node.Index(), static_cast<int>(producer_output_index),
static_cast<int>(j));
graph.AddConsumerNode(input_arg->Name(), &recompute_node);
}
}
return Status::OK();
}
/******************************************************
** Recompute related function implementation ends **
******************************************************/
} // namespace onnxruntime

View file

@ -0,0 +1,334 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <charconv>
#include "core/common/inlined_containers.h"
#include "core/common/string_utils.h"
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class MemoryOptimizer
Find recomputable subgraphs and enable according to user configs.
*/
class MemoryOptimizer : public GraphTransformer {
private:
using NodeOutputPort = std::pair<const Node*, int>;
using ActivationUsedMap = InlinedHashMap<std::string, std::pair<bool, bool>>;
/**
* @brief Level to control allowed operations during subgraph detecting.
* Level 0: only allow cheap-to-compute operations.
* Level 1: allow more expensive operations.
*/
enum class ProbeLevel {
Basic = 0,
Advanced = 1,
LevelMax = 2,
};
/**
* @brief Type of memory reduction techniques.
*/
enum class OptimizationType {
None = 0, // Disabled.
Recompute = 1,
TypeMax = 2,
};
/**
* @brief Type of user config.
* type: type of memory reduction techniques.
* requested_count: the number of occurrences of a subgraph pattern for alleviation. -1 means apply all.
* One example: if a subgraph pattern is found 3 times, and requested_count is set 2, then the 1st and 2nd subgraph
* in topological order will be applied for alleviation. This is useful to avoid alleviating more memory than
* needed.
*/
struct UserConfig {
OptimizationType type;
int requested_count;
};
/**
* @brief Struct to store properties of a specific subgraph.
*/
struct SubGraphDesc {
SubGraphDesc() = default;
// A string to represent the subgraph, used as a unique "ID" for a unique subgraph.
std::string subgraph_representative_str;
InlinedHashMap<std::string, int> shape_str_frequency; // shape string to frequency
UserConfig user_optimizer_config;
int total_frequency{0}; // The occurrence of this subgraph pattern in the graph.
int applied_count{0}; // The number of times this subgraph pattern has been really applied in this transformer.
int skip_count{0}; // The number of times this subgraph instances will skipped in reversed topological order.
float saving_ratio{1.0f};
};
/**
* @brief A struct to maintain the information of target subgraphs to optimize.
* Imagine we loop all nodes finding recomputable/offload-able subgraphs, we want to store them first.
* Afterwards, we optionally pick up some of them to apply optimization according to user configs.
*
* subgraph_descs is a map from subgraph string representation to its subgraph related configurations.
*
* _optimization_target_graphs_ is a map from activation producer node pointers to its target optimization subgraph
* nodes. For example, if a subgraph Cast+Gelu can be recomputed, we may have a map like:
* key: node pointer of stashed activation producer Gelu; value: node vector {Cast, Gelu,}.
*
* When we AddSubGraphInstance, we must provider its corresponding subgraph desc in the parameter.
* Then we can know for each subgraph instance, what's the subgraph str representation, and what's the optimization
* config.
*/
struct SubGraphStores {
/**********************************
** subgraph desc section starts **
**********************************/
size_t SubGraphDescCount() const {
return subgraph_descs.size();
}
bool Contains(std::string_view subgraph_str) const {
return subgraph_descs.find(subgraph_str) != subgraph_descs.end();
}
SubGraphDesc& GetSubGraphDesc(std::string_view subgraph_string) {
ORT_ENFORCE(Contains(subgraph_string), "Subgraph string not found.", subgraph_string);
return subgraph_descs.at(subgraph_string);
}
SubGraphDesc& CreateSubGraphDesc(const std::string& subgraph_string,
UserConfig& config) {
ORT_ENFORCE(!Contains(subgraph_string), "Subgraph string already exists.", subgraph_string);
subgraph_descs[subgraph_string].user_optimizer_config = config;
subgraph_descs[subgraph_string].subgraph_representative_str = subgraph_string;
return subgraph_descs[subgraph_string];
}
/**********************************************************************
** subgraph desc section ends, and subgraph instance section starts. **
***********************************************************************/
// Pair of <nodes in topological order, a string to represent the subgraph>.
using GraphInstanceInfo = std::pair<InlinedVector<const Node*>, std::string>;
void AddSubGraphInstance(const Node* node,
const InlinedVector<const Node*>& nodes_in_topological_order,
const SubGraphDesc& subgraph_desc) {
ORT_ENFORCE(_optimization_target_graphs_.find(node) == _optimization_target_graphs_.end());
_optimization_target_graphs_[node] = std::make_pair(nodes_in_topological_order,
subgraph_desc.subgraph_representative_str);
}
bool ContainsSubGraphInstance(const Node* node) const {
return _optimization_target_graphs_.find(node) != _optimization_target_graphs_.end();
}
GraphInstanceInfo& GetSubGraphInstance(const Node* node) {
ORT_ENFORCE(_optimization_target_graphs_.find(node) != _optimization_target_graphs_.end());
return _optimization_target_graphs_[node];
}
/***********************************
** subgraph instance section ends **
***********************************/
InlinedHashMap<std::string /*subgraph_representative_str*/, SubGraphDesc> subgraph_descs;
InlinedHashMap<const Node*, GraphInstanceInfo> _optimization_target_graphs_;
};
/**
* @brief Used to define per-op recompute config.
*
*/
struct AllowedRecomputeNodeConfig {
InlinedVector<int> input_arg_indices; // input index to iterate further (bottom up)
};
public:
MemoryOptimizer(const std::string& enable_memory_optimizer, const std::string& level)
: GraphTransformer("MemoryOptimizer") {
// Parse user defined configs.
ORT_ENFORCE(ParseConfigFromString(enable_memory_optimizer, level).IsOK());
RegisterAllowedRecomputeOps();
}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool ShouldOnlyApplyOnce() const override { return true; }
private:
Status ParseConfigFromString(const std::string& enable_memory_optimizer, const std::string& level);
/**
* @brief Prepare info including activation usage, node usage in fw and bw.
*
* @param graph Graph to iterate.
* @param fw_op_output_arg_used_map Collected activation usage mapping.
* - key: node arg name
* - value: a pair of bool, representing whether the activation is used by forward nodes or by backward nodes.
* @return int64_t value The boundary op (for example YieldOp) order in topological order. If no boundary op found,
* return -1;
*/
int64_t PrepareForTransformation(const Graph& graph,
ActivationUsedMap& fw_op_output_arg_used_map,
InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map) const;
/**
* @brief Find all stashed activations, e.g. activations used by forward operators and backward operators.
*
* @param graph Graph to iterate.
* @param fw_op_output_arg_used_map Activation usage mapping.
* @param candidate_output_args_map Candidate activations, which are consumed by both fw and bw ops.
* @return Status
*/
Status GetStashedActivationCandidates(
const Graph& graph,
const InlinedHashMap<std::string, std::pair<bool, bool>>& fw_op_output_arg_used_map,
InlinedHashMap<const Node*, InlinedVector<size_t>>& candidate_output_args_map,
const logging::Logger& logger) const;
/**
* @brief Apply graph modifications based on user configs.
*
* @param graph Graph to iterate and modify.
* @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort.
* Used to re-order the collected subgraph nodes.
* @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and
* bw ops.
* @param logger Logger.
* @param boundary_op_order_in_topological_sort index of the boundary op between fw and bw.
* @param subgraph_stores A store to maintain all found subgraphs.
* @param node The node we used to look for corresponding optimization graphs.
* @return true
* @return false
*/
bool ModifyGraph(Graph& graph,
const InlinedHashMap<NodeIndex, size_t>& node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>& candidate_output_args_map,
const logging::Logger& logger,
int64_t boundary_op_order_in_topological_sort,
SubGraphStores& subgraph_stores,
Node* node) const;
/**
* @brief Convert the recompute subgraph to its string representation.
*
* @param nodes_in_topological_order The subgraph nodes in topological order.
* @param subgraph_string_representation Returns subgraph string representation.
* @param log_info Returns log info for users.
*/
void NodesInTopoOrderToString(const InlinedVector<const Node*>& nodes_in_topological_order,
std::string& subgraph_string_representation,
std::string& log_info) const;
/**
* @brief Convert optimization type to string.
*/
std::string UserConfigToString(const UserConfig& config) const;
/**
* @brief Summarize transformation details.
*
* @param stashed_activation_statistics statistics around stashed activation memory saving.
* @return void
*/
void PrintSummary(const SubGraphStores& recompute_stores,
const SubGraphStores& recompute_with_compromise_stores,
const logging::Logger& logger) const;
/**************************************************
** Recompute related function definition starts **
*************************************************/
void RegisterAllowedRecomputeOps();
/**
* @brief Find recomputable subgraphs (has at least one nodes, at most MAXIMUM_RECOMPUTE_NODE_COUNT nodes).
*
* @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs.
* @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops.
* @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping.
* @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort.
* Used to re-order the collected subgraph nodes.
* @param nodes_in_topological_order Collected vector of nodes of found subgraph, in the order of the topological
* sorted.
* @param logger Logger.
* @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a
* recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the
* size of stashed activation.
* @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a
* compromised subgraph.
* @return Status
*/
Status SelectRecomputeSubgraph(const Node& node,
const InlinedVector<size_t>& node_output_index_candidates,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map,
InlinedVector<const Node*>& nodes_in_topological_order,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) const;
/**
* @brief For the node producing stashed activation, check whether a recomputable subgraph can be found or not.
*
* @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs.
* @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping.
* @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort.
* Used to re-order the collected subgraph nodes.
* @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and
* bw ops.
* @param subgraph_stores A store to maintain all found subgraphs.
* @param logger Logger.
* @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a
* recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the
* size of stashed activation.
* @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a
* compromised subgraph.
*/
void CheckNodeForRecompute(const Node& node,
const ActivationUsedMap& fw_op_output_arg_used_map,
const InlinedHashMap<NodeIndex, size_t>&
node_index_to_its_order_in_topological_sort_map,
const InlinedHashMap<const Node*, InlinedVector<size_t>>&
candidate_output_args_map,
SubGraphStores& subgraph_stores,
const logging::Logger& logger,
bool compromise_stashed_activation,
bool& can_compromise_stashed_activation) const;
/**
* @brief Duplicate nodes to create a recompute subgraph.
*
* @param graph Graph to iterate.
* @param nodes_in_topological_order Subgraph nodes to recompute.
* @param recompute_subgraph_output_node The final node of the subgraph.
* @return Status
*/
Status CreateRecomputeGraph(Graph& graph,
const InlinedVector<const Node*>& nodes_in_topological_order,
Node*& recompute_subgraph_output_node) const;
/**************************************************
** Recompute related function definition ends **
*************************************************/
// The op types that are supported predefined.
InlinedHashMap<std::string, AllowedRecomputeNodeConfig> recomputable_op_type_to_input_arg_index_map_;
// User enabled map of the subgraph string representation to the alleviation type.
InlinedHashMap<std::string, UserConfig> pattern_subgraph_to_user_optimizer_config_map_;
std::string optimizer_config_;
ProbeLevel recompute_probe_level_;
};
} // namespace onnxruntime

View file

@ -296,6 +296,11 @@ class GraphExecutionManager(GraphExecutionInterface):
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = int(self._debug_options.logging.log_level)
# Disable memory alleviation by default. Allow user to enable it via environment variable.
alleviation_config = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
probe_level = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1", warn=True)
session_options.add_session_config_entry("optimization.enable_memory_optimizer", alleviation_config)
session_options.add_session_config_entry("optimization.enable_memory_probe_recompute_level", probe_level)
if self._debug_options.save_onnx_models.save:
session_options.optimized_model_filepath = os.path.join(

View file

@ -0,0 +1,147 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4244)
#endif
#include <random>
#include "core/graph/onnx_protobuf.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "asserts.h"
#include "core/common/span_utils.h"
#include "core/framework/data_types.h"
#include "core/framework/ort_value.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/utils.h"
#include "core/platform/env.h"
#include "core/session/inference_session.h"
#include "core/util/math.h"
#include "test/framework/test_utils.h"
#include "test/capturing_sink.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#include "orttraining/core/optimizer/memory_optimizer.h"
using namespace std;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
#define MODEL_FOLDER ORT_TSTR("testdata/transform/recompute/")
TEST(MemoryOptimizerTests, GeluRecompute) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto model_uri = MODEL_FOLDER "recompute_gelu.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Gemm"] == 5);
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1);
ASSERT_TRUE(op_to_count["ReduceSum"] == 2);
ASSERT_TRUE(op_to_count["com.microsoft.GeluGrad"] == 1);
std::string gelu_node_name;
for (auto& node : graph.Nodes()) {
if (node.OpType().compare("Gelu") == 0) {
gelu_node_name = node.Name();
break;
}
}
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const std::string alleviation_config("Gelu+:1:-1");
const std::string alleviation_level("1");
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<MemoryOptimizer>(alleviation_config, alleviation_level), TransformerLevel::Level3));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Gemm"] == 5);
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 2);
ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1);
ASSERT_TRUE(op_to_count["ReduceSum"] == 2);
ASSERT_TRUE(op_to_count["com.microsoft.GeluGrad"] == 1);
Node* recompute_gelu_node{nullptr};
Node* original_gelu_node{nullptr};
for (auto& node : graph.Nodes()) {
if (node.OpType().compare("Gelu") == 0) {
if (node.Name() != gelu_node_name) {
recompute_gelu_node = &node;
} else {
original_gelu_node = &node;
}
}
}
ASSERT_EQ(recompute_gelu_node->MutableInputDefs()[0]->Name(), original_gelu_node->MutableInputDefs()[0]->Name());
ASSERT_EQ(recompute_gelu_node->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
ASSERT_EQ(original_gelu_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
}
TEST(MemoryOptimizerTests, TileRecompute) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto model_uri = MODEL_FOLDER "recompute_tile.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Tile"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 3);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const std::string alleviation_config("Tile+:1:-1");
const std::string alleviation_level("1");
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<MemoryOptimizer>(alleviation_config, alleviation_level), TransformerLevel::Level3));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Tile"] == 2);
ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 3);
Node* recompute_tile_node{nullptr};
Node* original_tile_node{nullptr};
for (auto& node : graph.Nodes()) {
if (node.Priority() == static_cast<int>(ExecutionPriority::LOCAL_LOW)) {
if (node.OpType().compare("Tile") == 0) {
recompute_tile_node = &node;
}
} else if (node.Priority() == static_cast<int>(ExecutionPriority::DEFAULT)) {
if (node.OpType().compare("Tile") == 0) {
original_tile_node = &node;
}
}
}
const Node* query_layer_grad_node = graph.GetProducerNode("query_layer_grad");
ASSERT_TRUE(recompute_tile_node);
ASSERT_TRUE(original_tile_node);
ASSERT_TRUE(query_layer_grad_node);
ASSERT_EQ(recompute_tile_node->MutableInputDefs()[0]->Name(), original_tile_node->MutableInputDefs()[0]->Name());
ASSERT_EQ(query_layer_grad_node->InputDefs()[1]->Name(), recompute_tile_node->MutableOutputDefs()[0]->Name());
ASSERT_EQ(recompute_tile_node->Priority(), static_cast<int>(ExecutionPriority::LOCAL_LOW));
ASSERT_EQ(original_tile_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
ASSERT_EQ(query_layer_grad_node->Priority(), static_cast<int>(ExecutionPriority::DEFAULT));
}
} // namespace test
} // namespace onnxruntime