diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 03e94cefd0..716eed1afe 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -66,7 +66,14 @@ inline std::string TrimString(std::string s) { } /** - * So use this simple hash to generate unique int by given string input. + * @brief A consistent way to construct the full qualified op name. + */ +inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return MakeString(domain, "::", op_type); +} + +/** + * Use this simple hash to generate unique int by given string input. */ inline uint32_t GetHashFromString(const std::string& str_value) { uint32_t hash = 0; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 9c98ed6d3e..1516fb37a7 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer( // 2. Whether the outputs have the same dim changes if the Gather node moves before that operator. // 3. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction as MatMul did. 
- {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Div", kOnnxDomain), + {utils::GetFullQualifiedOpName("Div", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_13_12_10_7_6_1)}, - {GetFullQualifiedOpName("Gelu", kMSDomain), + {utils::GetFullQualifiedOpName("Gelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_9_1)}, - {GetFullQualifiedOpName("Reshape", kOnnxDomain), + {utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_19_14_13_5_1)}, - {GetFullQualifiedOpName("Softmax", kOnnxDomain), + {utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_11_1)}, - {GetFullQualifiedOpName("Transpose", kOnnxDomain), + {utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_1)}, }); @@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal( const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { Node& slice_node = *info.node_ptr; - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::unordered_map propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc index f7b48de2ca..716988e933 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/tensorprotoutils.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -21,23 +22,23 @@ 
UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer( // If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function. // 2. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig( std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_13_12_10_7_6_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_13_9_1)}, }); @@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal( Graph& graph, std::deque& queue, Node& current_node, ReshapeInfo& info, const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::vector propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc index f08e37296d..4582f26a7d 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc @@ -5,6 +5,7 @@ #include #include "core/common/safeint.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -130,7 +131,7 @@ template bool UpStreamGraphTransformerBase::Upstream(Graph& graph, std::deque& queue, Node& current_node, T1& info, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); if (allowed_passthrough_ops_.count(op_type)) { auto& pass_through_config = allowed_passthrough_ops_.at(op_type); LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + 
current_node.Name() + "(" + op_type + ")"); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h index 6e22fc791a..d848a03c55 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h @@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer { const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const = 0; - /** - * @brief A consistent way to construct the full qualified op name. - */ - std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const { - return domain + "::" + op_type; - } - std::unordered_map> allowed_passthrough_ops_; private: diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 76b3325f36..b421eb2ab3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -48,75 +48,352 @@ float InputOutputSizeRatio(const Node* node) { return 1.0f; } +using IgnorableInputIndices = InlinedVector; +using OpsetToIgnorableIndicesMap = InlinedHashMap; + /** - * @brief Used to define per-op recompute config. + * @brief Get the table of ops that are allowed to be recomputed. * + * The supported op types are predefined. + * Most recently revisited for the ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md + * + * We define the supported list explicitly instead of using an exclusion list for the following reasons: + * 1. Some ops generate non-deterministic results (for example, by using a random number generator).
We need to evaluate whether + * this is a problem for recompute before adding the support, instead of fixing this after we find and try to + * fix convergence issues (which will be very hard if we have multiple non-deterministic operators supported by default). + * 2. Some op schemas change in new opsets, so we also need to check manually whether they are applicable to recompute + * or not. + * 3. Some ops are not supported in older opsets, so we need to check whether they are applicable to recompute or not. */ -struct AllowedRecomputeNodeConfig { - InlinedVector input_arg_indices; // input index to iterate further (bottom up) -}; - -// The supported op types are predefined. - -const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { - static InlinedHashMap> recomputable_op_table_map; +const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { + static InlinedHashMap> recomputable_op_table_map; if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) { return recomputable_op_table_map.at(probe_op_level); } - recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); + recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level); if (probe_op_level >= static_cast(ProbeLevel::Basic)) { recomputable_op_table.insert({ - // Binary elementwise - {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Equal", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, + { + utils::GetFullQualifiedOpName("Add", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {14, {}}, + {15, {}}, + }, + }, + { +
utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain), + { + {1, {1, 2}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Cast", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {9, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain), + { + {1, {}}, - // Data layout - /// The shape input is trivial whether it exists or not in backward. - {"Reshape", AllowedRecomputeNodeConfig{{0}}}, - {"Shape", AllowedRecomputeNodeConfig{{0}}}, - {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, - {"Transpose", AllowedRecomputeNodeConfig{{0}}}, - {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, + }, + }, + { + utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain), + { + {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor + {20, {0}}, + }, + }, + { + utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), + { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + {12, {1, 2}}, // ignore ratio and training_mode + {13, {1, 2}}, + }, + }, + { + utils::GetFullQualifiedOpName("Div", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Expand", kOnnxDomain), + { + {8, {1}}, // Ignore the shape. 
+ {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Cos", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), + { + // The axis input is trivial + {11, {1}}, + {14, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), + { + {12, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Equal", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {11, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FastGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gather", kOnnxDomain), + { + {1, {1}}, // ignore the indices + {11, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), + { + {20, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Less", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Mul", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Range", kOnnxDomain), + { + {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars. + }, + }, + { + utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), + { + {1, {}}, + {5, {1}}, // ignore the shape. 
+ {13, {1}}, + {14, {1}}, + {19, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Sin", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Slice", kOnnxDomain), + { + {1, {}}, + {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional) + {11, {1, 2, 3, 4}}, + {13, {1, 2, 3, 4}}, + }, + }, + { + utils::GetFullQualifiedOpName("Split", kOnnxDomain), + { + {1, {1}}, // ignore split (optional) + {2, {}}, + {11, {}}, + {13, {1}}, // ignore the split (optional) + {18, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Sub", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Tile", kOnnxDomain), + { + {1, {1, 2}}, + {6, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), + { + {1, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Trilu", kOnnxDomain), + { + {14, {1}}, // ignore k (optional) + }, + }, + { + utils::GetFullQualifiedOpName("QuickGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Where", kOnnxDomain), + { + {9, {}}, + {16, {}}, + }, + }, - // Unary elementwise - {"Dropout", AllowedRecomputeNodeConfig{{0}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - /// The ratio and mode input are trivial whether they exist or not in backward - {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, - /// The axis input is trivial whether it exists or not in backward - {"CumSum", AllowedRecomputeNodeConfig{{0}}}, - {"Expand", AllowedRecomputeNodeConfig{{0}}}, - {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, - {"Gelu", AllowedRecomputeNodeConfig{{0}}}, - {"QuickGelu", 
AllowedRecomputeNodeConfig{{0}}}, - - // Ternary elementwise - {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - - // Data copy - {"Tile", AllowedRecomputeNodeConfig{{0}}}, - {"Cast", AllowedRecomputeNodeConfig{{0}}}, - {"ConcatTraining", AllowedRecomputeNodeConfig{{0, 1}}}, // Input could be more than 2. But mostly 2. - {"Slice", AllowedRecomputeNodeConfig{{0}}}, - {"Split", AllowedRecomputeNodeConfig{{0}}}, - {"Gather", AllowedRecomputeNodeConfig{{0}}}, }); } if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { recomputable_op_table.insert({ - {"LayerNormalization", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Softmax", AllowedRecomputeNodeConfig{{0}}}, - {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, + { + utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain), + { + {1, {2}}, // ignore ratio (optional) + }, + }, + { + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have LayerNormalization, + // while our contrib op defined LayerNormalization in opset 1 in ONNX domain. + {1, {}}, + {17, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), + { + {1, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {}}, + }, + }, }); } @@ -127,8 +404,20 @@ const InlinedHashMap& GetAllowedRecompu * @brief Check whether a node is a recomputable node at given probe level. 
*/ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { - const auto& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); - return op_table.find(node.OpType()) != op_table.end(); + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + if (it == op_table.end()) { + return false; + } + return it->second.count(node.SinceVersion()); +} + +const InlinedVector& GetIgnorableInputIndices(const Node& node, ProbeLevel probe_level) { + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + ORT_ENFORCE(it != op_table.end(), "Cannot get ignorable indices since the node type is supported in the list."); + ORT_ENFORCE(it->second.count(node.SinceVersion()) > 0, "Cannot get ignorable indices since the opset is supported"); + return it->second.at(node.SinceVersion()); } /** @@ -163,7 +452,6 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool& can_compromise_stashed_activation, float& save_ratio) { const ProbeLevel probe_level = probe_config.probe_level; - const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast(probe_level)); can_compromise_stashed_activation = false; @@ -213,7 +501,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // If current op is NOT in allowed list: // 1). the output does not exist in backward, we cannot find a good solution for so, the search terminates. // 2). the output is used in backward, we don't need to trace back further, so continue searching. 
- auto op_recompute_config_it = recomputable_op_table.find(curr_node->OpType()); + bool is_recomputable = IsRecomputable(*curr_node, probe_level); auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); if (is_first_queue_scan) { // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of @@ -221,14 +509,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // 1. "op is not in recompute op list, but its output is used in backward" // 2. "op is in recompute op list, but its output is used in backward" // (either of the above checks is true for entry node outputs) - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { early_stop = true; MO_LOG_DEBUG_INFO(logger, "Entry Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, search terminates."); break; } } else { - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, but its output [" + @@ -283,14 +571,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, } // Iterate all input nodes according to allowed input arg index of the entry node. 
- const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; + const auto& igorable_input_arg_indices = GetIgnorableInputIndices(*curr_node, probe_level); for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { const Node::EdgeEnd& input_edge = *it; const auto& parent_node = input_edge.GetNode(); const auto parent_node_output_index = input_edge.GetSrcArgIndex(); const auto current_node_input_index = input_edge.GetDstArgIndex(); - if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != - input_arg_indices.end()) { + if (std::find(igorable_input_arg_indices.begin(), igorable_input_arg_indices.end(), current_node_input_index) == + igorable_input_arg_indices.end()) { // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); if (output_shape) {