From bf32dbbd9b61a5f8f2df85c3804165baee8a2e4b Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 14 Apr 2023 22:41:07 +0800 Subject: [PATCH] Share more constant initializers (#15461) ### Share more constant initializers. `ConstantSharing` transformer originally only handle single value initializer (scalar or 1D). This PR tried to share more cases to make common subexpression elimination transformer to remove more duplicated nodes. Originally, we used a single vector> to store different scalar values. In this PR, we create a unordered map with its key being data_type + rank + element count, and its value is a vector of `InitializerValue`. For one specific initializer, if it fulfils the condition, then finally will find the corresponding vector of `InitializerValue` by its , then search from the vector whether the constant tensor already exist or not. After that, a value id is returned, which will be combined together with to form the pattern key to decide which tensor to reuse (legacy code). ### Motivation and Context One example we see here is: ```mermaid stateDiagram [*] --> LayerNorm(b,s,64) LayerNorm(b,s,64) --> Reshape1 Shape1_Const[b*s,64] --> Reshape1 LayerNorm(b,s,64) --> Reshape2 Shape2_Const[b*s,64] --> Reshape2 Reshape1 --> AttentionSubGraph Reshape2 --> Add AttentionSubGraph--> Add Add --> [*] ``` Ideally CommonSubexpressionElimination can remove one of `Reshape1` and `Reshape2`, while since `Shape1_Const` and `Shape2_Const` are different NodeArg*, so it did not remove the duplication. This is an example: removing the duplication will bring more opportunities to apply graph transformations. --- .../core/optimizer/constant_sharing.cc | 136 ++++++--- .../core/optimizer/graph_transformer_utils.cc | 18 +- onnxruntime/core/optimizer/utils.cc | 14 +- .../python/tools/quantization/quant_utils.py | 5 +- .../test/framework/inference_session_test.cc | 33 ++- .../test/optimizer/graph_transform_test.cc | 276 ++++++++++++++++++ 6 files changed, 415 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index fa9a309098..c06349ec9b 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -29,20 +29,42 @@ bool IsSupportedDataType(int32_t data_type) { using SupportedTypeList = boost::mp11::mp_list; -bool IsValidSingleValueShape(const ONNX_NAMESPACE::TensorShapeProto* input_shape) { +// A threshold is defined here to restrict the graph transformation only applied to small tensors. +// Be note: having a bigger threshold means more overhead when we do the graph transformations. +// `8` is chosen to cover common constant use cases in some Reshape/Gather/Concat's inputs. +// TODO(pengwa): we can gradually increase this threshold if we see more benefits (memory saving +// or more CSE optimizations triggered). Should be careful to cover test cases that assume initializer +// name did not change after transformation then. +static constexpr int64_t TENSOR_ELEM_COUNT_THRESHOLD = 8; +static constexpr char SHARED_INITIALIZER_PREFIX[] = "ortshared_"; + +bool IsAllowedToShare(const ONNX_NAMESPACE::TensorShapeProto* input_shape, + int64_t& num_elements) { if (input_shape == nullptr) return false; size_t dim_size = static_cast(input_shape->dim_size()); - return dim_size == 0 || - (dim_size == 1 && utils::HasDimValue(input_shape->dim(0)) && input_shape->dim(0).dim_value() == 1); + num_elements = 1; + for (size_t i = 0; i < dim_size; ++i) { + auto dim = input_shape->dim(static_cast(i)); + if (!utils::HasDimValue(dim)) { + return false; + } + + int64_t dim_value = dim.dim_value(); + num_elements *= dim_value; + if (num_elements > TENSOR_ELEM_COUNT_THRESHOLD) { + return false; + } + } + + if (num_elements > 0 && num_elements <= TENSOR_ELEM_COUNT_THRESHOLD) { + return true; + } + + return false; } -static constexpr char SHARED_INITIALIZER_PREFIX[] = "ortshared_"; -bool IsSharedInitializer(std::string_view initializer_name) { - return initializer_name.rfind(SHARED_INITIALIZER_PREFIX, 0) == 0; -} - -// Return true when initializer node arg is consumed by any node conaining sub graphs; +// Return true when initializer node arg is consumed by any node containing sub graphs; // Otherwise, return false. bool PrepareInputPortsToReplace(Graph& graph, const NodeArg* origin_initializer_node_arg, InlinedHashMap>& consumer_node_to_input_ports_map) { @@ -57,7 +79,7 @@ bool PrepareInputPortsToReplace(Graph& graph, const NodeArg* origin_initializer_ } // Iterate all input defs to replace those that are equal to origin_initializer_node_arg, - // Then it would be safe to remove the consumer node aferwards. + // Then it would be safe to remove the consumer node afterwards. for (int i = 0; i < static_cast(const_node->InputDefs().size()); ++i) { if (const_node->InputDefs()[i] == origin_initializer_node_arg) { consumer_node_to_input_ports_map[const_node].push_back(i); @@ -98,40 +120,64 @@ void ReplaceInputsToUseSharedInitializer(Graph& graph, } /** - * @brief Get value unique id from constant store. + * @brief Initializer value representation, which is used to store and compare initializer values. * - * @tparam T Type of value to parse value from initializer. + * Two instances of InitializerValue are equal when: + * 1. data type match. + * 2. data rank match. + * 3. shape match. + * 4. value exactly match. + */ +struct InitializerValue { + InitializerValue(const ONNX_NAMESPACE::TensorProto* tensor_proto, Graph& graph) + : initializer{*tensor_proto, graph.ModelPath()} { + } + + bool operator==(const InitializerValue& other) const { + if (initializer.data_type() == other.initializer.data_type() && // data type + initializer.dims().size() == other.initializer.dims().size() && // rank + SpanEq(initializer.dims(), other.initializer.dims())) { // shape + return SpanEq(initializer.DataAsByteSpan(), other.initializer.DataAsByteSpan()); + } + + return false; + } + + bool operator!=(const InitializerValue& other) const { + return !(*this == other); + } + + Initializer initializer; +}; + +/** + * @brief Get value unique id from constant store. * * If the value parsed from initializer exists in constant store, then return the index in the container; * Otherwise, insert the value into container, return the last index. */ -template -struct GetOrAddValueInConstantStoreDispatcher { - size_t operator()(const onnxruntime::Initializer& initializer, - InlinedVector>& - const_value_store) const { - std::variant value; - if (std::is_same::value) { - value = math::halfToFloat(initializer.data()->val); - } else { - value = *initializer.data(); - } +size_t GetOrAddValueInConstantStore( + std::unique_ptr initializer, + InlinedHashMap>>& const_value_store, + const std::string& data_store_key) { + auto IsInitializerValueEqual = [&initializer](const std::unique_ptr& v) -> bool { + return *v == *initializer; + }; - auto it = std::find(const_value_store.begin(), const_value_store.end(), value); - if (it == const_value_store.end()) { - const_value_store.push_back(value); - return const_value_store.size() - 1; - } - return it - const_value_store.begin(); + auto& data_store = const_value_store[data_store_key]; + auto it = std::find_if(data_store.begin(), data_store.end(), IsInitializerValueEqual); + if (it == data_store.end()) { + data_store.emplace_back(std::move(initializer)); + return data_store.size() - 1; } -}; + return it - data_store.begin(); +} } // namespace Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { int shared_count = 0; - // Accumulated map from type/value/rank to initializer: // > The key is a string representation of initializer's data type, value and rank. // > The value is newly created initializer NodeArg* to be shared. @@ -140,25 +186,25 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve InlinedVector original_initializer_names; original_initializer_names.reserve(initialized_tensor_set.size()); for (const auto& entry : initialized_tensor_set) { - // Ignore if the initializer exists in graph output, already handled, + // Ignore if the initializer exists in graph output, // or not a constant initializer (implicitly excludes the graph input). - if (IsSharedInitializer(entry.first) || - !graph_utils::IsConstantInitializer(graph, entry.first) || + if (!graph_utils::IsConstantInitializer(graph, entry.first) || graph.IsOutput(graph.GetNodeArg(entry.first)) || excluded_initializers_.find(entry.first) != excluded_initializers_.end()) { continue; } - original_initializer_names.push_back(entry.first); } // Avoid using the scalar value directly in pattern_key because the value for example INT_MAX can be super big // and it will be hard to read. Instead, a constant value store is maintained, then the value index is used as the // value unique id when construct pattern key. - InlinedVector> const_value_store; + InlinedHashMap>> const_value_store; for (const auto& initializer_name : original_initializer_names) { NodeArg* origin_initializer_node_arg = graph.GetNodeArg(initializer_name); - if (origin_initializer_node_arg == nullptr || !IsValidSingleValueShape(origin_initializer_node_arg->Shape())) { + int64_t num_elements = 1; + if (origin_initializer_node_arg == nullptr || + !IsAllowedToShare(origin_initializer_node_arg->Shape(), num_elements)) { continue; } @@ -168,7 +214,6 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve if (!tensor_proto || !IsSupportedDataType(tensor_proto->data_type())) { continue; } - // A map used to collect those consumers who have inputs use origin_initializer_node_arg. // > The key is consumer Node pointer. // > The value is a list of indices for the consumer Nodes' input (that used origin_initializer_node_arg). @@ -178,14 +223,18 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve if (found_subgraph_usage || consumer_node_to_input_ports_map.size() == 0) { continue; } + const std::string data_store_key = MakeString(tensor_proto->data_type(), + "_", origin_initializer_node_arg->Shape()->dim_size(), + "_", num_elements); - onnxruntime::Initializer initializer{*tensor_proto, graph.ModelPath()}; - utils::MLTypeCallDispatcherFromTypeList t_disp(tensor_proto->data_type()); - size_t value_id = t_disp.InvokeRet(initializer, const_value_store); + std::unique_ptr init_value = std::make_unique(tensor_proto, graph); + // The constant value store contains multiple buckets, indexed by data_store_key. + // For each initializer, we will check which bucket it belongs to, + // then add the value into the bucket if it does not exits; or get the index within the bucket if it already exists. + size_t value_id = GetOrAddValueInConstantStore(std::move(init_value), const_value_store, data_store_key); // Construct a string by data type, value, and rank. Used as a key in pattern_key_to_shared_arg_map. - const std::string pattern_key = MakeString(SHARED_INITIALIZER_PREFIX, value_id, "_", tensor_proto->data_type(), "_", - origin_initializer_node_arg->Shape()->dim_size()); + const std::string pattern_key = MakeString(SHARED_INITIALIZER_PREFIX, data_store_key, "_", value_id); // If there is no such existing scalar pattern, add a new one. if (pattern_key_to_shared_arg_map.find(pattern_key) == pattern_key_to_shared_arg_map.end()) { @@ -206,7 +255,6 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve } LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; - return Status::OK(); } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index d1452eb47b..36b16ba27a 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -195,17 +195,23 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::move(rule_transformer)); } - // We need to remove the duplicated QDQ Pairs before all other GraphTransformation. - // no filtering on execution provider for L1 optimizations as they only use official ONNX operators + if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0") { + // We need to remove the duplicated QDQ Pairs before all other GraphTransformation. + transformers.emplace_back(std::make_unique()); + } + // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. - if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0") { - transformers.emplace_back(std::make_unique()); + InlinedHashSet excluded_initializers; + excluded_initializers.reserve(session_options.initializers_to_share_map.size()); + for (const auto& p : session_options.initializers_to_share_map) { + excluded_initializers.insert(p.first); } - transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique(cpu_ep, excluded_initializers)); + transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq)); transformers.emplace_back(std::make_unique()); @@ -319,7 +325,7 @@ InlinedVector> GenerateTransformers( } #endif -#endif // !defined(DISABLE_CONTRIB_OPS) +#endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their // fusions might be prevented if this one removes a Q/DQ node too early. transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 69d7a32090..c8e9779ff3 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -271,12 +271,24 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) { // so we have to assume that they are not deterministic, to be on the safe side. // We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain), // as long as we verify which of their operations are non-deterministic and add them in the map below. -constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"}; +constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike", + "RandomNormalLike", "Multinomial"}; + +#ifdef ENABLE_TRAINING_OPS +constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather"}; +#endif + bool IsOperationDeterministic(const std::string& domain, const std::string& op) { if (domain.compare(kOnnxDomain) == 0) { auto iter = std::find(kOnnxDomainNonDeterministicOps.begin(), kOnnxDomainNonDeterministicOps.end(), op); return iter == kOnnxDomainNonDeterministicOps.end(); } +#ifdef ENABLE_TRAINING_OPS + if (domain.compare(kMSDomain) == 0) { + auto iter = std::find(kMSDomainDeterministicOps.begin(), kMSDomainDeterministicOps.end(), op); + return iter != kMSDomainDeterministicOps.end(); + } +#endif // Unknown domain. Assume the op is not deterministic. return false; } diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 02d62dd41b..99460dd1de 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -505,7 +505,10 @@ def optimize_model(model_path: Path, opt_model_path: Path): sess_option = SessionOptions() sess_option.optimized_model_filepath = opt_model_path.as_posix() sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"]) + kwargs = {} + # This will rename constant initializer names, disable it to make test pass. + kwargs["disabled_optimizers"] = ["ConstantSharing"] + _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs) def add_pre_process_metadata(model): diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 704c4082fb..3083c0ccfd 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -716,11 +716,11 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { ASSERT_TRUE(lines[i].find(s) != string::npos); #ifdef USE_CUDA has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("cudaLaunch") != string::npos; + lines[i].find("cudaLaunch") != string::npos; #endif #ifdef USE_ROCM has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("hipLaunch") != string::npos; + lines[i].find("hipLaunch") != string::npos; #endif } } @@ -732,7 +732,6 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #endif } - TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { SessionOptions so; @@ -877,12 +876,12 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { ASSERT_TRUE(have_log_entry_with_vlog_session_msg); - //bool have_log_entry_with_vlog_run_msg = - // (std::find_if(msgs.begin(), msgs.end(), - // [&](std::string msg) { return msg.find("Size of execution plan vector") != string::npos; }) != - // msgs.end()); + // bool have_log_entry_with_vlog_run_msg = + // (std::find_if(msgs.begin(), msgs.end(), + // [&](std::string msg) { return msg.find("Size of execution plan vector") != string::npos; }) != + // msgs.end()); - //ASSERT_TRUE(have_log_entry_with_vlog_run_msg); + // ASSERT_TRUE(have_log_entry_with_vlog_run_msg); bool has_num_streams_msg = (std::find_if(msgs.begin(), msgs.end(), [&](std::string msg) { return msg.find("Number of streams") != string::npos; }) != msgs.end()); @@ -2778,15 +2777,19 @@ TEST(InferenceSessionTests, InitializerSharing_EnsureSessionsUseUserAddedInitial ASSERT_EQ(so1_init_buffer, so2_init_buffer); int so3_idx; - ASSERT_STATUS_OK(sess3.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so3_idx)); - const auto* so3_init_buffer = sess3.GetSessionState().GetInitializedTensors().at(so3_idx).Get().Data(); + // If the original initializer name got changed by graph transformers, then we don't need check + // the data ptr reuse or not with other session. + if (sess3.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so3_idx).IsOK()) { + const auto* so3_init_buffer = + sess3.GetSessionState().GetInitializedTensors().at(so3_idx).Get().Data(); - // Ensure session 3 doesn't share the same data ptr as any other session - ASSERT_NE(so3_init_buffer, so1_init_buffer); - ASSERT_NE(so3_init_buffer, so2_init_buffer); + // Ensure session 3 doesn't share the same data ptr as any other session + ASSERT_NE(so3_init_buffer, so1_init_buffer); + ASSERT_NE(so3_init_buffer, so2_init_buffer); - // Ensure session 3 doesn't share the same data ptr as the one supplied by the user for any of the other sessions - ASSERT_NE(so3_init_buffer, val_to_share.Get().Data()); + // Ensure session 3 doesn't share the same data ptr as the one supplied by the user for any of the other sessions + ASSERT_NE(so3_init_buffer, val_to_share.Get().Data()); + } } void RunModelWithDenormalAsZero(InferenceSession& session_object, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 0c4a685371..96ce3d5491 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6202,6 +6202,130 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatOrHalfTypedInitialize } } +template +void BuildConstantSharingDivMulGraphFor2DInitializer(ModelTestBuilder& builder) { + auto* input0_arg = builder.MakeInput({{1, 1, 256, 8}}); + auto* input1_arg = builder.MakeInput({{1, 1, 256, 8}}); + auto* div_out = builder.MakeIntermediate(); + builder.AddNode("Div", {input0_arg, input1_arg}, {div_out}); + + std::vector values{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector values_float16; + values_float16.reserve(values.size()); + if (std::is_same::value) { + for (auto v : values) { + values_float16.push_back(MLFloat16(math::floatToHalf(v))); + } + } + + for (size_t i = 0; i < 12; ++i) { + NodeArg* mul_initializer = nullptr; + if (std::is_same::value) { + mul_initializer = builder.MakeInitializer({1, 8}, values_float16); + } else if (std::is_same::value) { + mul_initializer = builder.MakeInitializer({1, 8}, values); + } else { + ASSERT_TRUE(false); + } + auto* mul_out = builder.MakeOutput(); + builder.AddNode("Mul", {div_out, mul_initializer}, {mul_out}); + } +} + +/* +Test graph include multiple equivalent subgraphs as below. + graph input [1, 1, 256, 8] (float|MLFloat16) + | + Div + / | \ + / | \ + / ... | / ... \ + Mul Mul Mul + | | | + graph out [1, 1, 256, 8] (float|MLFloat16) + +Be noted: + the Mul's input initializer is a 2D float/MLFloat16. +*/ +TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatOrHalfTypedInitializer) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 2U); + TEST_RETURN_IF_NOT(op_count_pre["Div"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Mul"] == 12); + TEST_RETURN_IF_NOT(graph.GetAllInitializedTensors().size() == 12U); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); + TEST_RETURN_IF_NOT(initialized_tensor_set.size() == 1U); + const NodeArg* mul_initializer = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Mul") == 0) { + if (!mul_initializer) { + mul_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim_size() == 2); + } else { + TEST_RETURN_IF_NOT(mul_initializer == node.InputDefs()[1]); + } + } + } + TEST_RETURN_IF(mul_initializer == nullptr); + for (const auto& entry : initialized_tensor_set) { + if (entry.first.compare(mul_initializer->Name()) == 0) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; + int32_t data_type = tensor_proto->data_type(); + onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(float_const.size() == 8); + for (int i = 0; i < 8; ++i) { + float float_const_value; + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + float_const_value = math::halfToFloat((float_const.data() + i)->val); + } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + float_const_value = *(float_const.data() + i); + } else { + return Status(common::ONNXRUNTIME, common::FAIL, "unexpected type"); + } + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } + } + + auto op_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count.size() == 2U); + TEST_RETURN_IF_NOT(op_count["Div"] == 1); + TEST_RETURN_IF_NOT(op_count["Mul"] == 12); + return Status::OK(); + }; + + const std::vector opsets{12, 13, 14}; // Clip support int64_t since opset 12 + + // Float data type tests. + auto build_test_case_float = [&](ModelTestBuilder& builder) { + BuildConstantSharingDivMulGraphFor2DInitializer(builder); + }; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_float, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // MLFloat16 data type tests. + auto build_test_case_mlfloat16 = [&](ModelTestBuilder& builder) { + BuildConstantSharingDivMulGraphFor2DInitializer(builder); + }; + + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_mlfloat16, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + /* Test graph include multiple equivalent subgraphs as below. graph input [1, 1, 256, 256] (float) @@ -6314,6 +6438,158 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ } } +/* +Test graph include multiple equivalent subgraphs as below. + graph input [1, 1, 8, 8] (float) + | + Div ______________________________ + / | \_______ | | + / | float | | | half | half + / ... | / ... | | | / ... | / ... + Mul Mul Sub Sub Add Add + | | | | \ / + graph out [1, 1, 8, 8](float) graph out [1, 1, 8, 8](MLFloat16) + +Be noted: + the Mul's input initializer is a 2D float tensor. + the Add's input initializer is a 2D MLFloat16 tensor. +*/ +TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitializer) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 5U); + TEST_RETURN_IF_NOT(op_count_pre["Div"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Mul"] == 3); + TEST_RETURN_IF_NOT(op_count_pre["Sub"] == 3); + TEST_RETURN_IF_NOT(op_count_pre["Add"] == 3); + TEST_RETURN_IF_NOT(graph.GetAllInitializedTensors().size() == 9U); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); + TEST_RETURN_IF_NOT(initialized_tensor_set.size() == 3U); + const NodeArg* mul_initializer = nullptr; + const NodeArg* sub_initializer = nullptr; + const NodeArg* add_initializer = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Mul") == 0) { + if (!mul_initializer) { + mul_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim(1).dim_value() == 8); + } else { + TEST_RETURN_IF_NOT(mul_initializer == node.InputDefs()[1]); + } + } else if (node.OpType().compare("Sub") == 0) { + if (!sub_initializer) { + sub_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(sub_initializer == nullptr); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim(0).dim_value() == 8); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim(1).dim_value() == 1); + } else { + TEST_RETURN_IF_NOT(sub_initializer == node.InputDefs()[1]); + } + } else if (node.OpType().compare("Add") == 0) { + if (!add_initializer) { + add_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(add_initializer == nullptr); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim(1).dim_value() == 8); + } else { + TEST_RETURN_IF_NOT(add_initializer == node.InputDefs()[1]); + } + } + } + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF(sub_initializer == nullptr); + TEST_RETURN_IF(add_initializer == nullptr); + for (const auto& entry : initialized_tensor_set) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; + int32_t data_type = tensor_proto->data_type(); + onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(float_const.size() == 8); + if (entry.first.compare(mul_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (int i = 0; i < 8; ++i) { + float float_const_value = *(float_const.data() + i); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } else if (entry.first.compare(sub_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (int i = 0; i < 8; ++i) { + float float_const_value = *(float_const.data() + i); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } else if (entry.first.compare(add_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + for (int i = 0; i < 8; ++i) { + float float_const_value = math::halfToFloat((float_const.data() + i)->val); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } + } + + auto op_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count.size() == 5U); + TEST_RETURN_IF_NOT(op_count["Div"] == 1); + TEST_RETURN_IF_NOT(op_count["Mul"] == 3); + TEST_RETURN_IF_NOT(op_count["Sub"] == 3); + TEST_RETURN_IF_NOT(op_count["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count["Add"] == 3); + return Status::OK(); + }; + + const std::vector opsets{12, 13, 14}; + + std::vector values{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector values_float16; + values_float16.reserve(values.size()); + for (auto v : values) { + values_float16.push_back(MLFloat16(math::floatToHalf(v))); + } + + auto build_test_case_float = [&values, &values_float16](ModelTestBuilder& builder) { + auto* input0_arg = builder.MakeInput({{1, 1, 8, 8}}); + auto* input1_arg = builder.MakeInput({{1, 1, 8, 8}}); + auto* div_out = builder.MakeIntermediate(); + builder.AddNode("Div", {input0_arg, input1_arg}, {div_out}); + + for (size_t i = 0; i < 3; ++i) { + NodeArg* mul_initializer = builder.MakeInitializer({1, 8}, values); + auto* mul_out = builder.MakeOutput(); + builder.AddNode("Mul", {div_out, mul_initializer}, {mul_out}); + } + + for (size_t i = 0; i < 3; ++i) { + NodeArg* sub_initializer = builder.MakeInitializer({8, 1}, values); + auto* sub_out = builder.MakeOutput(); + builder.AddNode("Sub", {div_out, sub_initializer}, {sub_out}); + } + + auto* cast_out = builder.MakeIntermediate(); + builder.AddNode("Cast", {div_out}, {cast_out}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + for (size_t i = 0; i < 3; ++i) { + NodeArg* add_initializer = builder.MakeInitializer({1, 8}, values_float16); + auto* add_out = builder.MakeOutput(); + builder.AddNode("Add", {cast_out, add_initializer}, {add_out}); + } + }; + + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_float, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + /* Test graph include multiple equivalent subgraphs as below. graph input [1, 1, 256, 256] (float)