From a3e7da60e7a12a3fd0f29bc180840f0bcd7f6a85 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 3 Nov 2022 13:49:41 +0800 Subject: [PATCH] Trade subgraph recompute for memory (#12852) **Description**: Subgraph-level recompute This PR adds an optional capability trading additional re-computation for better memory efficiency. Specifically, a pre-defined operator list used to iterate the Graph to find some subgraphs for recompute, to reduce some stashed activations whose lifetime across forward and backward pass. When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraph to recompute, along with sizes that can save. An example looks like below. If we want to enable some of them to recompute, we can define env variable this way: `export ORTMODULE_ENABLE_MEMORY_ALLEVIATION="Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1:-1,BiasGelu+:1:-1,BitmaskDropout+Cast+:1:-1,FusedMatMul+:1:-1,Cast+:1:-1,Mul+Add+:1:-1,Mul+Sub+:1:-1"` ``` [1,0]:2,022-10-12 14:47:39.302,954,530 [W:onnxruntime:, memory_alleviation.cc:595 PrintSummary] [1,0]:MemoryAlleviation Summary: [1,0]: User config: [1,0]: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+:1,BiasGelu+:1,BitmaskDropout+Cast+:1,FusedMatMul+:1,Cast+:1,Mul+Add+:1,Mul+Sub+:1 [1,0]: ================================= [1,0]: Subgraph: BitmaskDropout+ [1,0]: AlleviationType: Disabled [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 1,024 x Frequency:1 [1,0]: -------------------------------- [1,0]: Subgraph: BiasGelu+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]: -------------------------------- [1,0]: Subgraph: Reshape[1,0]:+ [1,0]: AlleviationType: Disabled [1,0]: Patterns: [1,0]: PatternShape:labels_dim0 x Frequency:1 [1,0]: -------------------------------- [1,0]: Subgraph: Unsqueeze+Unsqueeze+Cast+Sub+Mul+Mul+FusedMatMul+Cast+Add+BiasSoftmaxDropout+Cast+ [1,0]: AlleviationType: Disabled [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 [1,0]: -------------------------------- [1,0]: Subgraph: Mul+FusedMatMul+Cast+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Add+BiasSoftmaxDropout+Cast+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:1 [1,0]: -------------------------------- [1,0]: Subgraph: Mul+Add+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]: -------------------------------- [1,0]: Subgraph: FusedMatMul+Cast+Add+Reshape+Cast+ [1,0]: AlleviationType: Disabled [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 2 x 4 x Frequency:24 [1,0]: -------------------------------- [1,0]: Subgraph: Mul+Sub+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x 1 x Frequency:24 [1,0]: -------------------------------- [1,0]: Subgraph: Cast+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:1,024 x 1,024 x Frequency:97 [1,0]: PatternShape:3 x 1,024 x Frequency:1 [1,0]: PatternShape:8 x 64 x Frequency:24 [1,0]: PatternShape:1,024 x 4,096 x Frequency:24 [1,0]: PatternShape:4,096 x Frequency:24 [1,0]: PatternShape:4,096 x 1,024 x Frequency:24 [1,0]: -------------------------------- [1,0]: Subgraph: FusedMatMul+ [1,0]: AlleviationType: Recompute [1,0]: Patterns: [1,0]: PatternShape:input_ids_dim0 x input_ids_dim1 x 4,096 x Frequency:24 [1,0]: -------------------------------- [1,0]: ================================= ``` "Type config:" whether recompute is enabled by users. 0 - disable, 1- enable. "Subgraph" means what kind of subgraph will be recomputed, in this case, it is a single node "Gelu", and it will be "Recompute". "Shape && Frequency" means, for this recompute, one tensor of size (batch size, 500) will be saved because it will be recomputed. **Baseline** On a 1P model (DEBERTA V2), sequence length 256, training with 16 A100 GPUs. With latest main branch, we can run batch size 16, and the maximum batch size < 32. So 16 is usually chosen by data scientists. 65% of 40GB memory is used during training. The SamplesPerSec=479.2543353561354. ![image](https://user-images.githubusercontent.com/10530022/188320941-13dde5e7-c32b-4399-a64b-6803fbb9dcda.png) **With this PR** Gelu is recomputed for saving memory peak, batch size 32 can be run. The 97% of 40GB A100 is used, the SamplesPerSec=562.041593991271 (**1.17X** of baseline). ![image](https://user-images.githubusercontent.com/10530022/188321081-f64811bf-9637-4873-8095-349de8d498cc.png) **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --- docs/Memory_Optimizer.md | 90 ++ .../onnxruntime_session_options_config_keys.h | 22 +- onnxruntime/core/framework/memory_info.cc | 4 +- .../core/optimizer/graph_transformer_utils.cc | 15 + .../transform/recompute/recompute_gelu.onnx | Bin 0 -> 2644 bytes .../recompute_test_graph_generator.py | 88 ++ .../transform/recompute/recompute_tile.onnx | Bin 0 -> 4413 bytes .../core/optimizer/memory_optimizer.cc | 785 ++++++++++++++++++ .../core/optimizer/memory_optimizer.h | 334 ++++++++ .../ortmodule/_graph_execution_manager.py | 5 + .../test/optimizer/memory_optimizer_test.cc | 147 ++++ 11 files changed, 1485 insertions(+), 5 deletions(-) create mode 100644 docs/Memory_Optimizer.md create mode 100644 onnxruntime/test/testdata/transform/recompute/recompute_gelu.onnx create mode 100644 onnxruntime/test/testdata/transform/recompute/recompute_test_graph_generator.py create mode 100644 onnxruntime/test/testdata/transform/recompute/recompute_tile.onnx create mode 100644 orttraining/orttraining/core/optimizer/memory_optimizer.cc create mode 100644 orttraining/orttraining/core/optimizer/memory_optimizer.h create mode 100644 orttraining/orttraining/test/optimizer/memory_optimizer_test.cc diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md new file mode 100644 index 0000000000..92ddac6508 --- /dev/null +++ b/docs/Memory_Optimizer.md @@ -0,0 +1,90 @@ +# Memory Optimizer for ONNX Runtime Training + +## Introduction + +ONNX Runtime Training provides a capability trading node/subgraph recomputations for better memory efficiency. +Specifically, a list of recomputable operators is pre-defined, with which memory optimizer graph transformer will iterate the graph to find all recomputable subgraph candidates. + +When training with ORTModule, by default, the graph transformer will scan the execution graph to find all eligible subgraphs to recompute, along with sizes that can save. Users can pick up some of the subgraphs to enable them by environment variables. + +## When memory optimizer can help? + +Classical scenarios include: + +- ORTModule run a model with batch size B (for example 2^N), the memory bandwidth and compute are not fully saturated, while it hits OOM to run a bigger batch size (for example 2^(N+1)). + +- For big models, ORTModule fails to run the minimum allowed batch size, so performance can be compromised for a successful run. + +Not all models and recipes need this optimizer technique. Imagine if your training recipe is using a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6. + +## Quick trial + +1. Make sure ONNX Runtime training wheel is installed and correctly configured. +2. Integrate models using ORTModule, be noted log_level should be equal or lower than INFO. + > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO)) +3. Run the training as usual and redirect all outputs into log file; then stop it after training few steps. +4. Check the logging file, search "Summary", you could possibly find something like this: + ``` + MemoryOptimizer Summary: + User config: + + ================================= + ########Recompute######## + Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ + OptimizationType: Disabled + Patterns: + PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 + -------------------------------- + Subgraph: FastGelu+ + OptimizationType: Disabled + Patterns: + PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 + ================================= + ########RecomputeWithCompromise######## + Subgraph: Cast+Where+Softmax+ + OptimizationType: Disabled + Patterns: + PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 + -------------------------------- + ================================= + ``` +5. As shown above, 'Subgraph' shows 1) a string representative for a recomputable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case. +6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In this sample, 12 FastGelu related subgraphs are allowed to recompute. +`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed. + ``` + export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12" + ``` +7. Then run the training again, you will see logs like this: + ``` + MemoryOptimizer Summary: + User config: + **FastGelu+:1:12** + ================================= + ########Recompute######## + Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ + OptimizationType: Disabled + Patterns: + PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 + -------------------------------- + Subgraph: FastGelu+ + OptimizationType: **Recompute (requested_count=12, actual applied_count=12)** + Patterns: + PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 + ================================= + ########RecomputeWithCompromise######## + Subgraph: Cast+Where+Softmax+ + OptimizationType: Disabled + Patterns: + PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 + -------------------------------- + ================================= + ``` +8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well. + +## Compromised Recompute + +If you check the above logs, there is a separate section called "RecomputeWithCompromise". Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it. + +## Notes + +The feature is in experimental stage, we will tune and refine it according to real use cases. diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index d97f3608cd..5951757106 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -61,6 +61,22 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +#ifdef ENABLE_TRAINING +// Specifies a list of op types for memory footprint reduction. +// The value should be a ","-delimited list of pair of +// . +// For example, "Gelu+Cast+:1:0,Dropout+:1:1". +// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. +// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. +// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" +// the memory. +static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer"; + +// Specifies the level for detecting subgraphs for memory footprint reduction. +// The value should be an integer. The default value is 0. +static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level"; +#endif + // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". // Using device allocators means the memory allocation is made using malloc/new. static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; @@ -81,9 +97,9 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses /// /// Key for using the ORT format model flatbuffer bytes directly for initializers. -/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. -/// Requires `session.use_ort_model_bytes_directly` to be true. -/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire +/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. +/// Requires `session.use_ort_model_bytes_directly` to be true. +/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire /// duration of the InferenceSession. /// static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = diff --git a/onnxruntime/core/framework/memory_info.cc b/onnxruntime/core/framework/memory_info.cc index ebdef0e2fe..cdd0818ddf 100644 --- a/onnxruntime/core/framework/memory_info.cc +++ b/onnxruntime/core/framework/memory_info.cc @@ -112,7 +112,7 @@ void MemoryInfo::RecordActivationAllocInfo(const OrtValueIndex idx, const OrtVal else if (map[MapType::Initializer].Contain(reuse_buffer)) map_type = MapType::Initializer; else - std::cout << "Find no map type for reuse_buffer: " << reuse_buffer << ", so skipping" << std::endl; + LOGS_DEFAULT(VERBOSE) << "Find no map type for reuse_buffer: " << reuse_buffer << ", so skipping"; RecordTensorDeviceAllocInfo(idx, value, map_type); } @@ -365,7 +365,7 @@ void MemoryProfiler::CreateEvents(const std::string& p_name, void MemoryProfiler::GenerateMemoryProfile() { // Write memory profile .json std::stringstream ss; - ss << "memory_profile_" << GetMemoryInfo().GetLocalRank() << "_" << profiler_id_ << ".json"; + ss << "memory_profile_" << GetMemoryInfo().GetLocalRank() << "_" << Env::Default().GetSelfPid() << "_" << profiler_id_ << ".json"; std::ofstream memory_profile(ss.str(), std::ios::trunc); memory_profile << "[" << std::endl; for (size_t i = 0; i < GetEvents().size(); i++) { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 21be292a61..3093d8d43b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -68,6 +68,7 @@ #ifdef ENABLE_TRAINING #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" +#include "orttraining/core/optimizer/memory_optimizer.h" #include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" #endif @@ -297,6 +298,19 @@ InlinedVector> GenerateTransformers( // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their // fusions might be prevented if this one removes a Q/DQ node too early. transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); + +#ifdef ENABLE_TRAINING + // Put memory optimization transformer at last (which is done after most of fusions are done) by intention. + // Known issue: after mmeory optimization is completed, if some fusion happens, it is possible that the + // node priority got changed. This may disorder the execution order of nodes to recompute. + // TODO(pengwa): need to fix this issue. + const std::string enable_memory_optimizer = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); + const std::string probe_level = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0"); + transformers.emplace_back(std::make_unique(enable_memory_optimizer, probe_level)); +#endif + } break; case TransformerLevel::Level3: { @@ -315,6 +329,7 @@ InlinedVector> GenerateTransformers( // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); #endif + } break; default: diff --git a/onnxruntime/test/testdata/transform/recompute/recompute_gelu.onnx b/onnxruntime/test/testdata/transform/recompute/recompute_gelu.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7ac42056b949f02ff5e4f2b57a079b32c1b8ea9e GIT binary patch literal 2644 zcmb_eOHbQC5RM@v&d`8aLJ6TlH0n!HF^OrbDwh`bz!8a_N)Ic?-Xsow$&UtZJ@wpc zZyYN1)<4ld&>z!NPu;a+uw!FXiA$2r&dz+_H#0lS^Qh`v__nK#Ta}*Jlf-WSFM{tJ zRCLP;{2pv$+to+9r5M<>wZI_aQ0<9RqL0QtY~^kabj3p(s%?s##X8feP;r9_jr>JfR zK26467mr*;YgIq!#L$kNcKoBW1!x!;2KL5^LvUp6y`_>>B|bFXJ$B}a|KK3)w~#qn zOG+ByK%u@8KBA!ZQ+d5kV~K*`vU!dZnaL|lvs9NA^=zuR+ODH`9(>^-8~vu#w^Hei z>n;L(KPrS6un{oiv>18zKl?owd0w)g92B^ZTL$gsc16@ON#(jRkVHq!D|amvtN5FR zl*0y@!wyEwU}y~A#}tD&7-ZmaYAr&UoLYg9)`cv0c5(w<5OQvlfZ;7|A(E%h^AEQ{q5%?|kDY*@i$15gvyEcE z+;*KlVYsuSUC5AodM%y7A4=lHlsUIxq;8Hmj7{0iT~znD@1l@kn>xpAs9hAi2R^YJ z%Q@C_bIA`zLohc_EH-S?Q|6PVJ7GM)yGwm{+#Xwf-woU&2Dd6L@aG zQ(#-Gid@*z%i#e2Ql$7N-pg)5Vr>a3g0G?QG-@rw0Tl+K z@*_m%@{mL>`*;H%`(&V1f3T=Mc;M;Dpmi0{(ZsfK5GAy(07=*PpQLe>n(j4mf}10;h~&JJ#-~b8hn@g^1AC=$HHyeibR-& z6~>ZD`cF@YKWKU+2;dh5rb&f0N$Mo6WO7Oaq)9|$fHUSaBwAY4*vOQ4-jFl-u-)0( zWt+KnpL&$AK%1s|bB5|)^FBD22Q!X|)}>?iBUGaUJg~+AKG0SrY7A0cy&^^_)P$9n z_r>oD%#;luO0v0T4#%Q^#+EY)T!$h(&%z-$G49AFj8R9sEb%;~Ld2F{jfbr|)QD}n zC(o(xMv#%W^zR&Q$z_RF#3c}yov|aza6OrCi$u;q7w*Va_!1-)LEAY@MKG33rvqQ^ z6nS_#gF?H}6>*vr^f->kxRP1wi8GzAZ0NOg!r`IJFQxHbX>V?i8(Ym;0>AN8oT3F@ zA)-i53_p%23;T``f3fRL33~4}g8rwHOTZ#M(9R{-VA7|zIC9`4*T?Z0ML!goh89Kh zBs?mY@k1JJel!{59le%LxHHwKlzL&kiTi8~Rvg#$aCl>5I`-HZ4za2zb5Sko-|)y{ zt@V1Ll|)D(Lnp&2xJDDv0jq+xP>@643=3buZg`ZaqIyfHY8nu)ygmXp6a zv)R6$dl?%U>dTd6bTVXR)W~h#OcP{HE{kqv+x+MrsL5mHH67N;1=pObFPj$uh}8}+ z(DvA_&zmbyPv$Is*VYxY0V~lXHE3H7qrOAzwlh5G*Dsmt&`9__pFXqbNjtOe63CBR uF6w>Ir1x%{p7I{X8dHZ6o71nrdotbWhnkeoGS@$D$A6WydH?CumHz?5e?Wl% literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer.cc new file mode 100644 index 0000000000..3a6c1bef98 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.cc @@ -0,0 +1,785 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/random_seed.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" +#include "orttraining/core/graph/recompute_graph_utils.h" +#include "orttraining/core/optimizer/memory_optimizer.h" + +namespace onnxruntime { + +namespace { + +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; + +std::string TensorShapeProtoToString(const ONNX_NAMESPACE::TensorShapeProto* shape) { + std::ostringstream shape_oss; + if (shape != nullptr) { + for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) { + auto dim = shape->dim(dim_index); + if (utils::HasDimValue(dim)) { + shape_oss << dim.dim_value() << " x "; + } else { + shape_oss << dim.dim_param() << " x "; + } + } + } else { + shape_oss << "unknown"; + } + + return shape_oss.str(); +} + +int ParseIntValueFromString(std::string_view str) { + int int_value = 0; + auto result = std::from_chars(str.data(), str.data() + str.size(), int_value); + ORT_ENFORCE(result.ec != std::errc::invalid_argument, "Fail to convert to int from string: ", str); + return int_value; +} + +bool IsForwardPassOperator(int64_t op_order_in_topological_sort, int64_t boundary_op_order_in_topological_sort) { + return op_order_in_topological_sort <= boundary_op_order_in_topological_sort; +} + +static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { + const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + return elt_type->Size(); +} + +// TODO(pengwa): extend this function to be more general. +float InputOutputSizeRatio(const Node* node) { + if (node->OpType().compare("Cast") == 0) { + const NodeArg* input = node->InputDefs()[0]; + const NodeArg* output = node->OutputDefs()[0]; + if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING || + output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + return 1.0f; + } + const auto& ptype1 = input->Type(); + const auto& ptype2 = output->Type(); + float ratio = float(GetElementSize(ptype1)) / (float)GetElementSize(ptype2); + return ratio; + } + + return 1.0f; +} + +} // namespace + +Status MemoryOptimizer::ParseConfigFromString(const std::string& enable_memory_optimizer, + const std::string& level) { + optimizer_config_ = enable_memory_optimizer; + if (!enable_memory_optimizer.empty()) { + const auto user_config_strs = utils::SplitString(enable_memory_optimizer, ","); + for (const auto& user_config_str : user_config_strs) { + const auto user_config = utils::SplitString(user_config_str, ":"); + ORT_RETURN_IF_NOT(user_config.size() == 3, + "User config should be in format of SubgraphStr:OptimizationType:RequestApplyCount."); + + const std::string subgraph_string_representation(user_config[0]); + int optimization_type_int = ParseIntValueFromString(user_config[1]); + int requested_apply_count = ParseIntValueFromString(user_config[2]); + ORT_RETURN_IF_NOT(optimization_type_int < static_cast(OptimizationType::TypeMax) && + optimization_type_int >= 0, + "Invalid optimization type specified for subgraph: ", + subgraph_string_representation); + + ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0, + "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); + + // At this point, subgraph_string_representation is a pattern graph string representation. + pattern_subgraph_to_user_optimizer_config_map_[subgraph_string_representation] = + UserConfig{static_cast(optimization_type_int), requested_apply_count}; + } + } + + int probe_level = ParseIntValueFromString(level); + ORT_RETURN_IF_NOT(probe_level < static_cast(ProbeLevel::LevelMax) && probe_level >= 0, + "Invalid probe level specified: ", level); + recompute_probe_level_ = static_cast(probe_level); + + return Status::OK(); +} + +int64_t MemoryOptimizer::PrepareForTransformation(const Graph& graph, + ActivationUsedMap& fw_op_output_arg_used_map, + InlinedHashMap& + node_index_to_its_order_in_topological_sort_map) const { + fw_op_output_arg_used_map.clear(); + + GraphViewer graph_viewer(graph); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + + // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. + int64_t yield_op_order_in_topological_sort = -1; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + if (p_node->OpType() == "YieldOp") { + yield_op_order_in_topological_sort = static_cast(i); + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = i; + } + + // If boundary op found, create forward op output arg used map. + if (yield_op_order_in_topological_sort >= 0) { + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph.GetNode(node_ids[i]); + if (p_node == nullptr /* skip removed nodes*/) { + continue; + } + + const Node& node = *p_node; + bool is_forward_op = IsForwardPassOperator(static_cast(i), yield_op_order_in_topological_sort); + if (!is_forward_op) { + continue; + } + + for (auto& output_arg : node.OutputDefs()) { + bool used_in_fw = false; + bool used_in_bw = false; + for (auto& consumer_node : graph.GetConsumerNodes(output_arg->Name())) { + auto consumer_node_index_in_topological_order = + node_index_to_its_order_in_topological_sort_map.at(consumer_node->Index()); + if (IsForwardPassOperator(static_cast(consumer_node_index_in_topological_order), + yield_op_order_in_topological_sort)) { + used_in_fw = true; + } else { + used_in_bw = true; + } + } + fw_op_output_arg_used_map.insert({{output_arg->Name(), std::make_pair(used_in_fw, used_in_bw)}}); + } + } + } + + // Return whether boundary op is found or not. + return yield_op_order_in_topological_sort; +} + +Status MemoryOptimizer::GetStashedActivationCandidates(const Graph& graph, + const InlinedHashMap>& + fw_op_output_arg_used_map, + InlinedHashMap>& + candidate_output_args_map, + const logging::Logger& logger) const { + for (auto& kv : fw_op_output_arg_used_map) { + // used by fw and bw, then it is a candidates. + if (kv.second.first && kv.second.second) { + const Node* n = graph.GetProducerNode(kv.first); + ORT_ENFORCE(n, "Activation should have a producer node"); + size_t k = 0; + for (k = 0; k < n->OutputDefs().size(); ++k) { + if (n->OutputDefs()[k]->Name().compare(kv.first) == 0) { + break; + } + } + + candidate_output_args_map[n].push_back(k); + LOGS(logger, VERBOSE) << "Find candidate output named [" << kv.first << "] of Node " << n->Name() << "(" + << n->OpType() << ")"; + } + } + + return Status::OK(); +} + +bool MemoryOptimizer::ModifyGraph(Graph& graph, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, + const logging::Logger& logger, + int64_t boundary_op_order_in_topological_sort, + SubGraphStores& subgraph_stores, + Node* node) const { + bool graph_is_modified = false; + if (subgraph_stores.SubGraphDescCount() == 0) { + return graph_is_modified; + } + + SubGraphStores::GraphInstanceInfo& sub_graph_instance_info = + subgraph_stores.GetSubGraphInstance(node); + + SubGraphDesc& subgraph_desc = subgraph_stores.GetSubGraphDesc(sub_graph_instance_info.second); + UserConfig user_config = subgraph_desc.user_optimizer_config; + int skip_count = (user_config.requested_count == -1) + ? 0 + : std::max(0, subgraph_desc.total_frequency - user_config.requested_count); + + subgraph_desc.skip_count += 1; + + if (user_config.type != OptimizationType::None && subgraph_desc.skip_count > skip_count) { + subgraph_desc.applied_count += 1; + Node* replacement_node_ptr = nullptr; + LOGS(logger, WARNING) << "[Modify Graph] Node " << node->Name() << "(" << node->OpType() << ") is " + << UserConfigToString(user_config); + if (user_config.type == OptimizationType::Recompute) { + ORT_ENFORCE(CreateRecomputeGraph(graph, sub_graph_instance_info.first, replacement_node_ptr).IsOK()); + } else { + ORT_THROW("unsupported optimization type found: " + UserConfigToString(user_config)); + } + ORT_ENFORCE(replacement_node_ptr); + + graph_is_modified = true; + + for (size_t output_index : candidate_output_args_map.at(node)) { + // Collect output edges (connecting to backward ops), to remove. + std::vector output_edges; + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + size_t src_output_idx = static_cast(it->GetSrcArgIndex()); + if (src_output_idx != output_index) { + continue; + } + + auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index()); + // It is possible the consumer node is newly added as the recompute node, so we need a check here. + // For those kind of ops, we can treat them as backward ops. + if (tid == node_index_to_its_order_in_topological_sort_map.end() || + !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first), + boundary_op_order_in_topological_sort)) { + // Remove the edge only connecting to backward op. + output_edges.push_back(graph_utils::GraphEdge::CreateGraphEdge(*node, *it, false)); + } + } + + if (!output_edges.empty()) { + // Remove the output edges of the node first + graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges); + + // Create connections between the replacement node and the outgoing nodes. + for (const auto& output_edge : output_edges) { + graph.RemoveConsumerNode(node->MutableOutputDefs()[output_index]->Name(), node); + + // Add new edge connecting the input with the output nodes directly. + // This also updates the destination node's input node args + graph.AddEdge(replacement_node_ptr->Index(), output_edge.dst_node, static_cast(output_index), + output_edge.dst_arg_index); + } + } + } + } + + return graph_is_modified; +} + +Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) + const { + LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " + << static_cast(recompute_probe_level_); + + InlinedHashMap> fw_op_output_arg_used_map; + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + int64_t boundary_op_order_in_topological_sort = + PrepareForTransformation(graph, fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map); + if (boundary_op_order_in_topological_sort < 0) { + LOGS(logger, VERBOSE) << "No boundary op found. Skip memory optimization."; + return Status::OK(); + } + + InlinedHashMap> candidate_output_args_map; + ORT_RETURN_IF_ERROR(GetStashedActivationCandidates(graph, fw_op_output_arg_used_map, candidate_output_args_map, + logger)); + + SubGraphStores recompute_subgraph_stores; + SubGraphStores recompute_with_compromise_subgraph_stores; + GraphViewer graph_viewer(graph); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + + // The first pass - find the candidate subgraphs. + for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { + Node* p_node = graph.GetNode(node_ids[i]); + if (p_node == nullptr) { + continue; + } + + if (candidate_output_args_map.find(p_node) == candidate_output_args_map.end()) { + continue; + } + + bool can_compromise_stashed_activation = false; + CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, + recompute_subgraph_stores, logger, false, + can_compromise_stashed_activation); + + if (can_compromise_stashed_activation) { + LOGS(logger, VERBOSE) << "Searching Node " << p_node->Name() << "(" << p_node->OpType() + << ") for compromised recompute"; + // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist + // during backward pass, then we can try to compromise the assumption. + CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, + recompute_with_compromise_subgraph_stores, logger, true, + can_compromise_stashed_activation); + } + } + + // The second pass - apply the transformation. + // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended + // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier + // layers. + for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { + Node* p_node = graph.GetNode(node_ids[i]); + if (p_node == nullptr) { + continue; + } + + bool has_been_modified = false; + if (recompute_subgraph_stores.ContainsSubGraphInstance(p_node)) { + has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, logger, + boundary_op_order_in_topological_sort, + recompute_subgraph_stores, p_node); + } + + // If there are other recompute plan for this node, we skip them because the graph is already modified. + if (!has_been_modified && recompute_with_compromise_subgraph_stores.ContainsSubGraphInstance(p_node)) { + has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, logger, + boundary_op_order_in_topological_sort, + recompute_with_compromise_subgraph_stores, p_node); + } + + modified = modified || has_been_modified; + } + + PrintSummary(recompute_subgraph_stores, recompute_with_compromise_subgraph_stores, logger); + + return Status::OK(); +} + +void MemoryOptimizer::NodesInTopoOrderToString(const InlinedVector& nodes_in_topological_order, + std::string& subgraph_string_representation, + std::string& log_info) const { + std::ostringstream oss; + std::ostringstream subgraph_string_representation_oss; + size_t node_count = nodes_in_topological_order.size(); + for (size_t i = 0; i < node_count; ++i) { + if (i < node_count - 1) { // Ignore the last node. + oss << "(name:" << nodes_in_topological_order[i]->Name() << ", type:" << nodes_in_topological_order[i]->OpType() + << "),"; + } + + subgraph_string_representation_oss << nodes_in_topological_order[i]->OpType() << "+"; + } + + subgraph_string_representation = subgraph_string_representation_oss.str(); + log_info = oss.str(); + if (log_info.size() > 0) { + log_info = " with its precedent nodes: " + log_info; + } +} + +std::string MemoryOptimizer::UserConfigToString(const UserConfig& config) const { + std::string type_str; + switch (config.type) { + case OptimizationType::None: { + type_str = "Disabled"; + } break; + case OptimizationType::Recompute: { + type_str = "Recomputed"; + } break; + default: { + type_str = "Unknown"; + } break; + } + return type_str; +} + +void MemoryOptimizer::PrintSummary(const SubGraphStores& recompute_stores, + const SubGraphStores& recompute_with_compromise_stores, + const logging::Logger& logger) const { + if (recompute_stores.SubGraphDescCount() == 0 && recompute_with_compromise_stores.SubGraphDescCount() == 0) { + return; + } + + std::ostringstream summary; + summary << "\nMemoryOptimizer Summary:\n"; + summary << "\tUser config:\n\t" << optimizer_config_ << "\n"; + summary << "\t=================================\n"; + + auto print_info_from_stores = [&summary, this](std::string store_name, const SubGraphStores& stores) { + summary << "\t########" << store_name << "########\n"; + for (auto subgraph_it = stores.subgraph_descs.begin(); subgraph_it != stores.subgraph_descs.end(); + ++subgraph_it) { + std::string freq_info; + if (subgraph_it->second.user_optimizer_config.type != OptimizationType::None) + freq_info = " (requested_count=" + std::to_string(subgraph_it->second.user_optimizer_config.requested_count) + + ", actual applied_count=" + + std::to_string(subgraph_it->second.applied_count) + ")"; + summary << "\tSubgraph: " << subgraph_it->first << "\n" + << "\t\tOptimizationType: " + << UserConfigToString(subgraph_it->second.user_optimizer_config) << freq_info << "\n" + << "\t\tPatterns: \n"; + for (auto shape_stat_it = subgraph_it->second.shape_str_frequency.begin(); + shape_stat_it != subgraph_it->second.shape_str_frequency.end(); + ++shape_stat_it) { + summary << "\t\t\tPatternShape:" << shape_stat_it->first << "\tFrequency:" << shape_stat_it->second << "\n"; + } + summary << "\t--------------------------------\n"; + } + summary << "\t=================================\n"; + }; + + print_info_from_stores("Recompute", recompute_stores); + print_info_from_stores("RecomputeWithCompromise", recompute_with_compromise_stores); + + LOGS(logger, INFO) << summary.str() << "\n"; +} + +/****************************************************** + ** Recompute related function implementation starts ** + ******************************************************/ + +void MemoryOptimizer::RegisterAllowedRecomputeOps() { + if (static_cast(recompute_probe_level_) >= static_cast(ProbeLevel::Basic)) { + recomputable_op_type_to_input_arg_index_map_.insert({ + // Binary elementwise + {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, + {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, + + // Data layout + /// The shape input is trivial whether it exists or not in backward. + {"Reshape", AllowedRecomputeNodeConfig{{0}}}, + {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, + {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, + + // Unary elementwise + /// The ratio and mode input are trivial whether they exist or not in backward + {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, + /// The axis input is trivial whether it exists or not in backward + {"CumSum", AllowedRecomputeNodeConfig{{0}}}, + {"Dropout", AllowedRecomputeNodeConfig{{0}}}, + {"Gelu", AllowedRecomputeNodeConfig{{0}}}, + {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, + + // Ternary elementwise + {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, + + // Data copy + {"Tile", AllowedRecomputeNodeConfig{{0}}}, + {"Cast", AllowedRecomputeNodeConfig{{0}}}, + }); + } + + if (static_cast(recompute_probe_level_) >= static_cast(ProbeLevel::Advanced)) { + recomputable_op_type_to_input_arg_index_map_.insert({ + {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Softmax", AllowedRecomputeNodeConfig{{0}}}, + {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, + {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, + }); + } +} + +Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& node, + const InlinedVector& node_output_index_candidates, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + InlinedVector& nodes, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation) const { + can_compromise_stashed_activation = false; + + LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << node.Name() << "(" << node.OpType() << ")"; + nodes.clear(); + + std::deque q; + for (auto output_index : node_output_index_candidates) { + q.push_back(NodeOutputPort(&node, static_cast(output_index))); + } + + bool early_stop = false; + std::set visited_output_arg_set; + std::set visited_node_set; + + // For the initial activations in queue, they are stashed ones, so we do differently when scan the queue for them. + bool is_first_queue_scan = true; + while (nodes.size() < MAXIMUM_RECOMPUTE_NODE_COUNT && !q.empty() && !early_stop) { + // Loop all candidate NodeOutputPort, and find the next layer of input nodes. + size_t current_queue_size = q.size(); + for (size_t i = 0; i < current_queue_size; ++i) { + NodeOutputPort p = q.front(); + q.pop_front(); + const Node* curr_node = p.first; + + // Skip if the node output is already visited. + if (std::find(visited_output_arg_set.begin(), visited_output_arg_set.end(), p) != + visited_output_arg_set.end()) { + continue; + } + + visited_output_arg_set.insert({p}); + + // If the node already visited by from it's other output index, skip it. + if (visited_node_set.find(curr_node) != visited_node_set.end()) { + continue; + } + + visited_node_set.insert(curr_node); + + // Bottom-up search rules. + // If current op is entry output node (that generates stashed activations): + // 1. If the op is not in recomputable_op_type_to_input_arg_index_map_, skip it. + // Otherwise: + // If current op is in allowed list, check its input args, and append the producers' NodeOutputPorts to next_q. + // If current op is NOT in allowed list: + // 1). the output does not exist in backward, we cannot find a good solution for so, search terminates. + // 2). the output is used in backward, we don't need trace back further, continue searching. + auto op_recompute_config_it = recomputable_op_type_to_input_arg_index_map_.find(curr_node->OpType()); + auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); + if (is_first_queue_scan) { + // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of + // the checks in the other branch + // 1. "op is not in recompute op list, but its output is used in backward" + // 2. "op is in recompute op list, but its output is used in backward" + // (either of the above checks is true for entry node outputs) + if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) { + early_stop = true; + LOGS(logger, VERBOSE) << "Entry Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** " + << "in recompute op list, search terminates."; + break; + } + } else { + if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) { + if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, but its output [" << cur_output_arg_name + << "] is used in backward, we don't need trace bottom-up further"; + continue; + } else { + early_stop = true; + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, search terminates."; + break; + } + } + + if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") " + << "is in recompute op list, while its output [" << cur_output_arg_name + << "] is used in backward, we don't need trace bottom-up further"; + continue; + } + } + + // Append node to the selected graph. + if (std::find(nodes.begin(), nodes.end(), curr_node) == nodes.end()) { + nodes.push_back(curr_node); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") is added in selected subgraph "; + } + + // This check is not matured now, subject to be changed. + float ratio = InputOutputSizeRatio(curr_node); + float is_current_node_compromisable = (ratio < 1.f); + can_compromise_stashed_activation = can_compromise_stashed_activation || is_current_node_compromisable; + if (is_current_node_compromisable) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") has input/output size " << ratio << " < 1.f, can compromise stashed activation"; + } + + if (is_current_node_compromisable && compromise_stashed_activation) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, while it meet compromised check, we don't need trace " + << "bottom-up further."; + continue; + } + + // Iterate all input nodes according to allowed input arg index of the entry node. + const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; + for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { + const Node::EdgeEnd& input_edge = *it; + const auto& parent_node = input_edge.GetNode(); + const auto parent_node_output_index = input_edge.GetSrcArgIndex(); + const auto current_node_input_index = input_edge.GetDstArgIndex(); + if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != + input_arg_indices.end()) { + NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); + + LOGS(logger, VERBOSE) << "Node " << parent_node.Name() << "(" << parent_node.OpType() << ")'s " + << parent_node_output_index + << "th output [" << parent_node.OutputDefs()[parent_node_output_index]->Name() + << "] is added in recompute search list "; + + q.push_back(next_p); + } + } + } + // After handle all entry node outputs, we set the flag to false. + is_first_queue_scan = false; + } + + // If input args are not found in bw, but op count exceed MAXIMUM_RECOMPUTE_NODE_COUNT, skip recompute. + if (!q.empty() || early_stop) { + LOGS(logger, VERBOSE) << "Fail to find a solution for recompute: current node count is " << nodes.size() + << ", queue size: " << q.size() << ", early stop: " << early_stop; + nodes.clear(); + } else { + // Re-order the nodes in topological order. + std::sort(nodes.begin(), nodes.end(), + [&node_index_to_its_order_in_topological_sort_map](const Node*& lhs, const Node*& rhs) { + return node_index_to_its_order_in_topological_sort_map.at(lhs->Index()) < + node_index_to_its_order_in_topological_sort_map.at(rhs->Index()); + }); + } + return Status::OK(); +} + +void MemoryOptimizer::CheckNodeForRecompute(const Node& node, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, + SubGraphStores& subgraph_stores, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation) const { + if (recomputable_op_type_to_input_arg_index_map_.find(node.OpType()) == + recomputable_op_type_to_input_arg_index_map_.end()) { + return; + } + + InlinedVector nodes_in_topological_order; + ORT_ENFORCE(SelectRecomputeSubgraph(node, candidate_output_args_map.at(&node), + fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map, + nodes_in_topological_order, logger, + compromise_stashed_activation, + can_compromise_stashed_activation) + .IsOK()); + if (nodes_in_topological_order.size() == 0) { + return; + } + + std::string subgraph_str_representation, log_info; + NodesInTopoOrderToString(nodes_in_topological_order, subgraph_str_representation, log_info); + LOGS(logger, VERBOSE) << "Node " << node.Name() << "(" << node.OpType() << ") can be recomputed" << log_info; + + // Update the subgraph optimization config map - key is the subgraph string representation, value is user config. + UserConfig user_config{OptimizationType::None, 0}; + if (pattern_subgraph_to_user_optimizer_config_map_.find(subgraph_str_representation) != + pattern_subgraph_to_user_optimizer_config_map_.end()) { + user_config = pattern_subgraph_to_user_optimizer_config_map_.at(subgraph_str_representation); + } + + SubGraphDesc& subgraph_desc = + subgraph_stores.Contains(subgraph_str_representation) + ? subgraph_stores.GetSubGraphDesc(subgraph_str_representation) + : subgraph_stores.CreateSubGraphDesc(subgraph_str_representation, user_config); + + subgraph_desc.total_frequency += 1; + + // Update the subgraph frequency map - key is the subgraph string representation, value is number of appearances. + for (size_t output_index : candidate_output_args_map.at(&node)) { + auto shape_str = TensorShapeProtoToString(node.OutputDefs()[output_index]->Shape()); + subgraph_desc.shape_str_frequency[shape_str]++; + } + + subgraph_stores.AddSubGraphInstance(&node, nodes_in_topological_order, subgraph_desc); + + return; +} + +Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, + const InlinedVector& nodes_in_topological_order, + Node*& new_output_node_ptr) const { + InlinedHashMap self_contained_outputs_map; + for (size_t i = 0; i < nodes_in_topological_order.size(); ++i) { + Node* node_to_duplicate = graph.GetNode(nodes_in_topological_order[i]->Index()); + + // Check whether the node has been recomputed/offloaded or not. Simply check the existence of the first output + // of the node has its corresponding recompute name or not. + // TODO: if there is more optimization types like offload added, we will add corresponding check whether the outputs + // already be offloaded or not. + if (graph.GetNodeArg(graph_utils::RecomputeName(node_to_duplicate->MutableOutputDefs()[0]->Name())) != nullptr) { + continue; + } + + InlinedVector new_input_args; + new_input_args.reserve(node_to_duplicate->MutableInputDefs().size()); + for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) { + if (self_contained_outputs_map.find(input_arg) == self_contained_outputs_map.end()) { + NodeArg* recompute_input_arg = graph.GetNodeArg(graph_utils::RecomputeName(input_arg->Name())); + new_input_args.push_back(recompute_input_arg ? recompute_input_arg : input_arg); + } else { + new_input_args.push_back(self_contained_outputs_map[input_arg]); + } + } + + InlinedVector new_output_args; + new_output_args.reserve(node_to_duplicate->MutableOutputDefs().size()); + for (size_t k = 0; k < node_to_duplicate->MutableOutputDefs().size(); ++k) { + const auto& output = node_to_duplicate->MutableOutputDefs()[k]; + new_output_args.push_back(&graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto())); + + self_contained_outputs_map[output] = new_output_args.back(); + } + + Node& recompute_node = graph.AddNode(node_to_duplicate->Name() + "_recompute", + node_to_duplicate->OpType(), + "Recompute of " + node_to_duplicate->Name(), + new_input_args, + new_output_args, + &node_to_duplicate->GetAttributes(), + node_to_duplicate->Domain()); + + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + recompute_node.SetExecutionProviderType(node_to_duplicate->GetExecutionProviderType()); + ORT_RETURN_IF_NOT(graph.SetOpSchemaFromRegistryForNode(recompute_node), + "Failed to set op schema for added recompute node."); + + new_output_node_ptr = &recompute_node; + + for (size_t j = 0; j < recompute_node.MutableOutputDefs().size(); ++j) { + graph.UpdateProducerNode(recompute_node.MutableOutputDefs()[j]->Name(), recompute_node.Index()); + } + + // Add the edges from the recompute node to the original node. + for (size_t j = 0; j < recompute_node.MutableInputDefs().size(); ++j) { + NodeArg* input_arg = recompute_node.MutableInputDefs()[j]; + const Node* producer_node = graph.GetProducerNode(input_arg->Name()); + if (producer_node == nullptr) { + // Skip when it is graph input or initializer. + continue; + } + int producer_output_index = optimizer_utils::IndexOfNodeOutput(*producer_node, *input_arg); + graph.AddEdge(producer_node->Index(), recompute_node.Index(), static_cast(producer_output_index), + static_cast(j)); + + graph.AddConsumerNode(input_arg->Name(), &recompute_node); + } + } + + return Status::OK(); +} + +/****************************************************** + ** Recompute related function implementation ends ** + ******************************************************/ + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer.h new file mode 100644 index 0000000000..50efd50025 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.h @@ -0,0 +1,334 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "core/common/inlined_containers.h" +#include "core/common/string_utils.h" +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class MemoryOptimizer + +Find recomputable subgraphs and enable according to user configs. +*/ + +class MemoryOptimizer : public GraphTransformer { + private: + using NodeOutputPort = std::pair; + using ActivationUsedMap = InlinedHashMap>; + + /** + * @brief Level to control allowed operations during subgraph detecting. + * Level 0: only allow cheap-to-compute operations. + * Level 1: allow more expensive operations. + */ + enum class ProbeLevel { + Basic = 0, + Advanced = 1, + LevelMax = 2, + }; + + /** + * @brief Type of memory reduction techniques. + */ + enum class OptimizationType { + None = 0, // Disabled. + Recompute = 1, + TypeMax = 2, + }; + + /** + * @brief Type of user config. + * type: type of memory reduction techniques. + * requested_count: the number of occurrences of a subgraph pattern for alleviation. -1 means apply all. + * One example: if a subgraph pattern is found 3 times, and requested_count is set 2, then the 1st and 2nd subgraph + * in topological order will be applied for alleviation. This is useful to avoid alleviating more memory than + * needed. + */ + struct UserConfig { + OptimizationType type; + int requested_count; + }; + + /** + * @brief Struct to store properties of a specific subgraph. + */ + struct SubGraphDesc { + SubGraphDesc() = default; + + // A string to represent the subgraph, used as a unique "ID" for a unique subgraph. + std::string subgraph_representative_str; + + InlinedHashMap shape_str_frequency; // shape string to frequency + UserConfig user_optimizer_config; + int total_frequency{0}; // The occurrence of this subgraph pattern in the graph. + + int applied_count{0}; // The number of times this subgraph pattern has been really applied in this transformer. + int skip_count{0}; // The number of times this subgraph instances will skipped in reversed topological order. + float saving_ratio{1.0f}; + }; + + /** + * @brief A struct to maintain the information of target subgraphs to optimize. + * Imagine we loop all nodes finding recomputable/offload-able subgraphs, we want to store them first. + * Afterwards, we optionally pick up some of them to apply optimization according to user configs. + * + * subgraph_descs is a map from subgraph string representation to its subgraph related configurations. + * + * _optimization_target_graphs_ is a map from activation producer node pointers to its target optimization subgraph + * nodes. For example, if a subgraph Cast+Gelu can be recomputed, we may have a map like: + * key: node pointer of stashed activation producer Gelu; value: node vector {Cast, Gelu,}. + * + * When we AddSubGraphInstance, we must provider its corresponding subgraph desc in the parameter. + * Then we can know for each subgraph instance, what's the subgraph str representation, and what's the optimization + * config. + */ + struct SubGraphStores { + /********************************** + ** subgraph desc section starts ** + **********************************/ + + size_t SubGraphDescCount() const { + return subgraph_descs.size(); + } + + bool Contains(std::string_view subgraph_str) const { + return subgraph_descs.find(subgraph_str) != subgraph_descs.end(); + } + + SubGraphDesc& GetSubGraphDesc(std::string_view subgraph_string) { + ORT_ENFORCE(Contains(subgraph_string), "Subgraph string not found.", subgraph_string); + return subgraph_descs.at(subgraph_string); + } + + SubGraphDesc& CreateSubGraphDesc(const std::string& subgraph_string, + UserConfig& config) { + ORT_ENFORCE(!Contains(subgraph_string), "Subgraph string already exists.", subgraph_string); + subgraph_descs[subgraph_string].user_optimizer_config = config; + subgraph_descs[subgraph_string].subgraph_representative_str = subgraph_string; + return subgraph_descs[subgraph_string]; + } + + /********************************************************************** + ** subgraph desc section ends, and subgraph instance section starts. ** + ***********************************************************************/ + + // Pair of . + using GraphInstanceInfo = std::pair, std::string>; + + void AddSubGraphInstance(const Node* node, + const InlinedVector& nodes_in_topological_order, + const SubGraphDesc& subgraph_desc) { + ORT_ENFORCE(_optimization_target_graphs_.find(node) == _optimization_target_graphs_.end()); + _optimization_target_graphs_[node] = std::make_pair(nodes_in_topological_order, + subgraph_desc.subgraph_representative_str); + } + + bool ContainsSubGraphInstance(const Node* node) const { + return _optimization_target_graphs_.find(node) != _optimization_target_graphs_.end(); + } + + GraphInstanceInfo& GetSubGraphInstance(const Node* node) { + ORT_ENFORCE(_optimization_target_graphs_.find(node) != _optimization_target_graphs_.end()); + return _optimization_target_graphs_[node]; + } + + /*********************************** + ** subgraph instance section ends ** + ***********************************/ + + InlinedHashMap subgraph_descs; + InlinedHashMap _optimization_target_graphs_; + }; + + /** + * @brief Used to define per-op recompute config. + * + */ + struct AllowedRecomputeNodeConfig { + InlinedVector input_arg_indices; // input index to iterate further (bottom up) + }; + + public: + MemoryOptimizer(const std::string& enable_memory_optimizer, const std::string& level) + : GraphTransformer("MemoryOptimizer") { + // Parse user defined configs. + ORT_ENFORCE(ParseConfigFromString(enable_memory_optimizer, level).IsOK()); + + RegisterAllowedRecomputeOps(); + } + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + bool ShouldOnlyApplyOnce() const override { return true; } + + private: + Status ParseConfigFromString(const std::string& enable_memory_optimizer, const std::string& level); + + /** + * @brief Prepare info including activation usage, node usage in fw and bw. + * + * @param graph Graph to iterate. + * @param fw_op_output_arg_used_map Collected activation usage mapping. + * - key: node arg name + * - value: a pair of bool, representing whether the activation is used by forward nodes or by backward nodes. + * @return int64_t value The boundary op (for example YieldOp) order in topological order. If no boundary op found, + * return -1; + */ + int64_t PrepareForTransformation(const Graph& graph, + ActivationUsedMap& fw_op_output_arg_used_map, + InlinedHashMap& + node_index_to_its_order_in_topological_sort_map) const; + /** + * @brief Find all stashed activations, e.g. activations used by forward operators and backward operators. + * + * @param graph Graph to iterate. + * @param fw_op_output_arg_used_map Activation usage mapping. + * @param candidate_output_args_map Candidate activations, which are consumed by both fw and bw ops. + * @return Status + */ + Status GetStashedActivationCandidates( + const Graph& graph, + const InlinedHashMap>& fw_op_output_arg_used_map, + InlinedHashMap>& candidate_output_args_map, + const logging::Logger& logger) const; + + /** + * @brief Apply graph modifications based on user configs. + * + * @param graph Graph to iterate and modify. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * Used to re-order the collected subgraph nodes. + * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and + * bw ops. + * @param logger Logger. + * @param boundary_op_order_in_topological_sort index of the boundary op between fw and bw. + * @param subgraph_stores A store to maintain all found subgraphs. + * @param node The node we used to look for corresponding optimization graphs. + * @return true + * @return false + */ + bool ModifyGraph(Graph& graph, + const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& candidate_output_args_map, + const logging::Logger& logger, + int64_t boundary_op_order_in_topological_sort, + SubGraphStores& subgraph_stores, + Node* node) const; + + /** + * @brief Convert the recompute subgraph to its string representation. + * + * @param nodes_in_topological_order The subgraph nodes in topological order. + * @param subgraph_string_representation Returns subgraph string representation. + * @param log_info Returns log info for users. + */ + void NodesInTopoOrderToString(const InlinedVector& nodes_in_topological_order, + std::string& subgraph_string_representation, + std::string& log_info) const; + + /** + * @brief Convert optimization type to string. + */ + std::string UserConfigToString(const UserConfig& config) const; + + /** + * @brief Summarize transformation details. + * + * @param stashed_activation_statistics statistics around stashed activation memory saving. + * @return void + */ + void PrintSummary(const SubGraphStores& recompute_stores, + const SubGraphStores& recompute_with_compromise_stores, + const logging::Logger& logger) const; + + /************************************************** + ** Recompute related function definition starts ** + *************************************************/ + + void RegisterAllowedRecomputeOps(); + + /** + * @brief Find recomputable subgraphs (has at least one nodes, at most MAXIMUM_RECOMPUTE_NODE_COUNT nodes). + * + * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. + * @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops. + * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * Used to re-order the collected subgraph nodes. + * @param nodes_in_topological_order Collected vector of nodes of found subgraph, in the order of the topological + * sorted. + * @param logger Logger. + * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a + * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the + * size of stashed activation. + * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a + * compromised subgraph. + * @return Status + */ + Status SelectRecomputeSubgraph(const Node& node, + const InlinedVector& node_output_index_candidates, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + InlinedVector& nodes_in_topological_order, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation) const; + + /** + * @brief For the node producing stashed activation, check whether a recomputable subgraph can be found or not. + * + * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. + * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * Used to re-order the collected subgraph nodes. + * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and + * bw ops. + * @param subgraph_stores A store to maintain all found subgraphs. + * @param logger Logger. + * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a + * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the + * size of stashed activation. + * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a + * compromised subgraph. + */ + void CheckNodeForRecompute(const Node& node, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, + SubGraphStores& subgraph_stores, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation) const; + + /** + * @brief Duplicate nodes to create a recompute subgraph. + * + * @param graph Graph to iterate. + * @param nodes_in_topological_order Subgraph nodes to recompute. + * @param recompute_subgraph_output_node The final node of the subgraph. + * @return Status + */ + Status CreateRecomputeGraph(Graph& graph, + const InlinedVector& nodes_in_topological_order, + Node*& recompute_subgraph_output_node) const; + + /************************************************** + ** Recompute related function definition ends ** + *************************************************/ + + // The op types that are supported predefined. + InlinedHashMap recomputable_op_type_to_input_arg_index_map_; + // User enabled map of the subgraph string representation to the alleviation type. + InlinedHashMap pattern_subgraph_to_user_optimizer_config_map_; + std::string optimizer_config_; + ProbeLevel recompute_probe_level_; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index a18f2347b0..82c00bc47d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -296,6 +296,11 @@ class GraphExecutionManager(GraphExecutionInterface): session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. session_options.log_severity_level = int(self._debug_options.logging.log_level) + # Disable memory alleviation by default. Allow user to enable it via environment variable. + alleviation_config = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True) + probe_level = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_PROBE_RECOMPUTE_LEVEL", "1", warn=True) + session_options.add_session_config_entry("optimization.enable_memory_optimizer", alleviation_config) + session_options.add_session_config_entry("optimization.enable_memory_probe_recompute_level", probe_level) if self._debug_options.save_onnx_models.save: session_options.optimized_model_filepath = os.path.join( diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc new file mode 100644 index 0000000000..7a9c1a9015 --- /dev/null +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + +#include +#include "core/graph/onnx_protobuf.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "asserts.h" +#include "core/common/span_utils.h" +#include "core/framework/data_types.h" +#include "core/framework/ort_value.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/optimizer/utils.h" +#include "core/platform/env.h" +#include "core/session/inference_session.h" +#include "core/util/math.h" +#include "test/framework/test_utils.h" +#include "test/capturing_sink.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" +#include "orttraining/core/optimizer/memory_optimizer.h" + +using namespace std; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +#define MODEL_FOLDER ORT_TSTR("testdata/transform/recompute/") + +TEST(MemoryOptimizerTests, GeluRecompute) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "recompute_gelu.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Gemm"] == 5); + ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1); + ASSERT_TRUE(op_to_count["ReduceSum"] == 2); + ASSERT_TRUE(op_to_count["com.microsoft.GeluGrad"] == 1); + + std::string gelu_node_name; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Gelu") == 0) { + gelu_node_name = node.Name(); + break; + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + + const std::string alleviation_config("Gelu+:1:-1"); + const std::string alleviation_level("1"); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(alleviation_config, alleviation_level), TransformerLevel::Level3)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Gemm"] == 5); + ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 2); + ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1); + ASSERT_TRUE(op_to_count["ReduceSum"] == 2); + ASSERT_TRUE(op_to_count["com.microsoft.GeluGrad"] == 1); + + Node* recompute_gelu_node{nullptr}; + Node* original_gelu_node{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Gelu") == 0) { + if (node.Name() != gelu_node_name) { + recompute_gelu_node = &node; + } else { + original_gelu_node = &node; + } + } + } + + ASSERT_EQ(recompute_gelu_node->MutableInputDefs()[0]->Name(), original_gelu_node->MutableInputDefs()[0]->Name()); + ASSERT_EQ(recompute_gelu_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); + ASSERT_EQ(original_gelu_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); +} + +TEST(MemoryOptimizerTests, TileRecompute) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto model_uri = MODEL_FOLDER "recompute_tile.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Tile"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 3); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + + const std::string alleviation_config("Tile+:1:-1"); + const std::string alleviation_level("1"); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(alleviation_config, alleviation_level), TransformerLevel::Level3)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Tile"] == 2); + ASSERT_TRUE(op_to_count["com.microsoft.YieldOp"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 3); + + Node* recompute_tile_node{nullptr}; + Node* original_tile_node{nullptr}; + for (auto& node : graph.Nodes()) { + if (node.Priority() == static_cast(ExecutionPriority::LOCAL_LOW)) { + if (node.OpType().compare("Tile") == 0) { + recompute_tile_node = &node; + } + } else if (node.Priority() == static_cast(ExecutionPriority::DEFAULT)) { + if (node.OpType().compare("Tile") == 0) { + original_tile_node = &node; + } + } + } + + const Node* query_layer_grad_node = graph.GetProducerNode("query_layer_grad"); + + ASSERT_TRUE(recompute_tile_node); + ASSERT_TRUE(original_tile_node); + ASSERT_TRUE(query_layer_grad_node); + + ASSERT_EQ(recompute_tile_node->MutableInputDefs()[0]->Name(), original_tile_node->MutableInputDefs()[0]->Name()); + ASSERT_EQ(query_layer_grad_node->InputDefs()[1]->Name(), recompute_tile_node->MutableOutputDefs()[0]->Name()); + + ASSERT_EQ(recompute_tile_node->Priority(), static_cast(ExecutionPriority::LOCAL_LOW)); + ASSERT_EQ(original_tile_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); + ASSERT_EQ(query_layer_grad_node->Priority(), static_cast(ExecutionPriority::DEFAULT)); +} + +} // namespace test +} // namespace onnxruntime