diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index fa9a309098..c06349ec9b 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -29,20 +29,42 @@ bool IsSupportedDataType(int32_t data_type) { using SupportedTypeList = boost::mp11::mp_list; -bool IsValidSingleValueShape(const ONNX_NAMESPACE::TensorShapeProto* input_shape) { +// A threshold is defined here to restrict the graph transformation only applied to small tensors. +// Be note: having a bigger threshold means more overhead when we do the graph transformations. +// `8` is chosen to cover common constant use cases in some Reshape/Gather/Concat's inputs. +// TODO(pengwa): we can gradually increase this threshold if we see more benefits (memory saving +// or more CSE optimizations triggered). Should be careful to cover test cases that assume initializer +// name did not change after transformation then. +static constexpr int64_t TENSOR_ELEM_COUNT_THRESHOLD = 8; +static constexpr char SHARED_INITIALIZER_PREFIX[] = "ortshared_"; + +bool IsAllowedToShare(const ONNX_NAMESPACE::TensorShapeProto* input_shape, + int64_t& num_elements) { if (input_shape == nullptr) return false; size_t dim_size = static_cast(input_shape->dim_size()); - return dim_size == 0 || - (dim_size == 1 && utils::HasDimValue(input_shape->dim(0)) && input_shape->dim(0).dim_value() == 1); + num_elements = 1; + for (size_t i = 0; i < dim_size; ++i) { + auto dim = input_shape->dim(static_cast(i)); + if (!utils::HasDimValue(dim)) { + return false; + } + + int64_t dim_value = dim.dim_value(); + num_elements *= dim_value; + if (num_elements > TENSOR_ELEM_COUNT_THRESHOLD) { + return false; + } + } + + if (num_elements > 0 && num_elements <= TENSOR_ELEM_COUNT_THRESHOLD) { + return true; + } + + return false; } -static constexpr char SHARED_INITIALIZER_PREFIX[] = "ortshared_"; -bool IsSharedInitializer(std::string_view initializer_name) { - return initializer_name.rfind(SHARED_INITIALIZER_PREFIX, 0) == 0; -} - -// Return true when initializer node arg is consumed by any node conaining sub graphs; +// Return true when initializer node arg is consumed by any node containing sub graphs; // Otherwise, return false. bool PrepareInputPortsToReplace(Graph& graph, const NodeArg* origin_initializer_node_arg, InlinedHashMap>& consumer_node_to_input_ports_map) { @@ -57,7 +79,7 @@ bool PrepareInputPortsToReplace(Graph& graph, const NodeArg* origin_initializer_ } // Iterate all input defs to replace those that are equal to origin_initializer_node_arg, - // Then it would be safe to remove the consumer node aferwards. + // Then it would be safe to remove the consumer node afterwards. for (int i = 0; i < static_cast(const_node->InputDefs().size()); ++i) { if (const_node->InputDefs()[i] == origin_initializer_node_arg) { consumer_node_to_input_ports_map[const_node].push_back(i); @@ -98,40 +120,64 @@ void ReplaceInputsToUseSharedInitializer(Graph& graph, } /** - * @brief Get value unique id from constant store. + * @brief Initializer value representation, which is used to store and compare initializer values. * - * @tparam T Type of value to parse value from initializer. + * Two instances of InitializerValue are equal when: + * 1. data type match. + * 2. data rank match. + * 3. shape match. + * 4. value exactly match. + */ +struct InitializerValue { + InitializerValue(const ONNX_NAMESPACE::TensorProto* tensor_proto, Graph& graph) + : initializer{*tensor_proto, graph.ModelPath()} { + } + + bool operator==(const InitializerValue& other) const { + if (initializer.data_type() == other.initializer.data_type() && // data type + initializer.dims().size() == other.initializer.dims().size() && // rank + SpanEq(initializer.dims(), other.initializer.dims())) { // shape + return SpanEq(initializer.DataAsByteSpan(), other.initializer.DataAsByteSpan()); + } + + return false; + } + + bool operator!=(const InitializerValue& other) const { + return !(*this == other); + } + + Initializer initializer; +}; + +/** + * @brief Get value unique id from constant store. * * If the value parsed from initializer exists in constant store, then return the index in the container; * Otherwise, insert the value into container, return the last index. */ -template -struct GetOrAddValueInConstantStoreDispatcher { - size_t operator()(const onnxruntime::Initializer& initializer, - InlinedVector>& - const_value_store) const { - std::variant value; - if (std::is_same::value) { - value = math::halfToFloat(initializer.data()->val); - } else { - value = *initializer.data(); - } +size_t GetOrAddValueInConstantStore( + std::unique_ptr initializer, + InlinedHashMap>>& const_value_store, + const std::string& data_store_key) { + auto IsInitializerValueEqual = [&initializer](const std::unique_ptr& v) -> bool { + return *v == *initializer; + }; - auto it = std::find(const_value_store.begin(), const_value_store.end(), value); - if (it == const_value_store.end()) { - const_value_store.push_back(value); - return const_value_store.size() - 1; - } - return it - const_value_store.begin(); + auto& data_store = const_value_store[data_store_key]; + auto it = std::find_if(data_store.begin(), data_store.end(), IsInitializerValueEqual); + if (it == data_store.end()) { + data_store.emplace_back(std::move(initializer)); + return data_store.size() - 1; } -}; + return it - data_store.begin(); +} } // namespace Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { int shared_count = 0; - // Accumulated map from type/value/rank to initializer: // > The key is a string representation of initializer's data type, value and rank. // > The value is newly created initializer NodeArg* to be shared. @@ -140,25 +186,25 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve InlinedVector original_initializer_names; original_initializer_names.reserve(initialized_tensor_set.size()); for (const auto& entry : initialized_tensor_set) { - // Ignore if the initializer exists in graph output, already handled, + // Ignore if the initializer exists in graph output, // or not a constant initializer (implicitly excludes the graph input). - if (IsSharedInitializer(entry.first) || - !graph_utils::IsConstantInitializer(graph, entry.first) || + if (!graph_utils::IsConstantInitializer(graph, entry.first) || graph.IsOutput(graph.GetNodeArg(entry.first)) || excluded_initializers_.find(entry.first) != excluded_initializers_.end()) { continue; } - original_initializer_names.push_back(entry.first); } // Avoid using the scalar value directly in pattern_key because the value for example INT_MAX can be super big // and it will be hard to read. Instead, a constant value store is maintained, then the value index is used as the // value unique id when construct pattern key. - InlinedVector> const_value_store; + InlinedHashMap>> const_value_store; for (const auto& initializer_name : original_initializer_names) { NodeArg* origin_initializer_node_arg = graph.GetNodeArg(initializer_name); - if (origin_initializer_node_arg == nullptr || !IsValidSingleValueShape(origin_initializer_node_arg->Shape())) { + int64_t num_elements = 1; + if (origin_initializer_node_arg == nullptr || + !IsAllowedToShare(origin_initializer_node_arg->Shape(), num_elements)) { continue; } @@ -168,7 +214,6 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve if (!tensor_proto || !IsSupportedDataType(tensor_proto->data_type())) { continue; } - // A map used to collect those consumers who have inputs use origin_initializer_node_arg. // > The key is consumer Node pointer. // > The value is a list of indices for the consumer Nodes' input (that used origin_initializer_node_arg). @@ -178,14 +223,18 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve if (found_subgraph_usage || consumer_node_to_input_ports_map.size() == 0) { continue; } + const std::string data_store_key = MakeString(tensor_proto->data_type(), + "_", origin_initializer_node_arg->Shape()->dim_size(), + "_", num_elements); - onnxruntime::Initializer initializer{*tensor_proto, graph.ModelPath()}; - utils::MLTypeCallDispatcherFromTypeList t_disp(tensor_proto->data_type()); - size_t value_id = t_disp.InvokeRet(initializer, const_value_store); + std::unique_ptr init_value = std::make_unique(tensor_proto, graph); + // The constant value store contains multiple buckets, indexed by data_store_key. + // For each initializer, we will check which bucket it belongs to, + // then add the value into the bucket if it does not exits; or get the index within the bucket if it already exists. + size_t value_id = GetOrAddValueInConstantStore(std::move(init_value), const_value_store, data_store_key); // Construct a string by data type, value, and rank. Used as a key in pattern_key_to_shared_arg_map. - const std::string pattern_key = MakeString(SHARED_INITIALIZER_PREFIX, value_id, "_", tensor_proto->data_type(), "_", - origin_initializer_node_arg->Shape()->dim_size()); + const std::string pattern_key = MakeString(SHARED_INITIALIZER_PREFIX, data_store_key, "_", value_id); // If there is no such existing scalar pattern, add a new one. if (pattern_key_to_shared_arg_map.find(pattern_key) == pattern_key_to_shared_arg_map.end()) { @@ -206,7 +255,6 @@ Status ConstantSharing::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve } LOGS(logger, INFO) << "Total shared scalar initializer count: " << shared_count; - return Status::OK(); } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index d1452eb47b..36b16ba27a 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -195,17 +195,23 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::move(rule_transformer)); } - // We need to remove the duplicated QDQ Pairs before all other GraphTransformation. - // no filtering on execution provider for L1 optimizations as they only use official ONNX operators + if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0") { + // We need to remove the duplicated QDQ Pairs before all other GraphTransformation. + transformers.emplace_back(std::make_unique()); + } + // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. - if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0") { - transformers.emplace_back(std::make_unique()); + InlinedHashSet excluded_initializers; + excluded_initializers.reserve(session_options.initializers_to_share_map.size()); + for (const auto& p : session_options.initializers_to_share_map) { + excluded_initializers.insert(p.first); } - transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique(cpu_ep, excluded_initializers)); + transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq)); transformers.emplace_back(std::make_unique()); @@ -319,7 +325,7 @@ InlinedVector> GenerateTransformers( } #endif -#endif // !defined(DISABLE_CONTRIB_OPS) +#endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their // fusions might be prevented if this one removes a Q/DQ node too early. transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 69d7a32090..c8e9779ff3 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -271,12 +271,24 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) { // so we have to assume that they are not deterministic, to be on the safe side. // We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain), // as long as we verify which of their operations are non-deterministic and add them in the map below. -constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"}; +constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike", + "RandomNormalLike", "Multinomial"}; + +#ifdef ENABLE_TRAINING_OPS +constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather"}; +#endif + bool IsOperationDeterministic(const std::string& domain, const std::string& op) { if (domain.compare(kOnnxDomain) == 0) { auto iter = std::find(kOnnxDomainNonDeterministicOps.begin(), kOnnxDomainNonDeterministicOps.end(), op); return iter == kOnnxDomainNonDeterministicOps.end(); } +#ifdef ENABLE_TRAINING_OPS + if (domain.compare(kMSDomain) == 0) { + auto iter = std::find(kMSDomainDeterministicOps.begin(), kMSDomainDeterministicOps.end(), op); + return iter != kMSDomainDeterministicOps.end(); + } +#endif // Unknown domain. Assume the op is not deterministic. return false; } diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 02d62dd41b..99460dd1de 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -505,7 +505,10 @@ def optimize_model(model_path: Path, opt_model_path: Path): sess_option = SessionOptions() sess_option.optimized_model_filepath = opt_model_path.as_posix() sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"]) + kwargs = {} + # This will rename constant initializer names, disable it to make test pass. + kwargs["disabled_optimizers"] = ["ConstantSharing"] + _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs) def add_pre_process_metadata(model): diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 704c4082fb..3083c0ccfd 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -716,11 +716,11 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { ASSERT_TRUE(lines[i].find(s) != string::npos); #ifdef USE_CUDA has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("cudaLaunch") != string::npos; + lines[i].find("cudaLaunch") != string::npos; #endif #ifdef USE_ROCM has_api_info = has_api_info || lines[i].find("Api") != string::npos && - lines[i].find("hipLaunch") != string::npos; + lines[i].find("hipLaunch") != string::npos; #endif } } @@ -732,7 +732,6 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #endif } - TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { SessionOptions so; @@ -877,12 +876,12 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { ASSERT_TRUE(have_log_entry_with_vlog_session_msg); - //bool have_log_entry_with_vlog_run_msg = - // (std::find_if(msgs.begin(), msgs.end(), - // [&](std::string msg) { return msg.find("Size of execution plan vector") != string::npos; }) != - // msgs.end()); + // bool have_log_entry_with_vlog_run_msg = + // (std::find_if(msgs.begin(), msgs.end(), + // [&](std::string msg) { return msg.find("Size of execution plan vector") != string::npos; }) != + // msgs.end()); - //ASSERT_TRUE(have_log_entry_with_vlog_run_msg); + // ASSERT_TRUE(have_log_entry_with_vlog_run_msg); bool has_num_streams_msg = (std::find_if(msgs.begin(), msgs.end(), [&](std::string msg) { return msg.find("Number of streams") != string::npos; }) != msgs.end()); @@ -2778,15 +2777,19 @@ TEST(InferenceSessionTests, InitializerSharing_EnsureSessionsUseUserAddedInitial ASSERT_EQ(so1_init_buffer, so2_init_buffer); int so3_idx; - ASSERT_STATUS_OK(sess3.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so3_idx)); - const auto* so3_init_buffer = sess3.GetSessionState().GetInitializedTensors().at(so3_idx).Get().Data(); + // If the original initializer name got changed by graph transformers, then we don't need check + // the data ptr reuse or not with other session. + if (sess3.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so3_idx).IsOK()) { + const auto* so3_init_buffer = + sess3.GetSessionState().GetInitializedTensors().at(so3_idx).Get().Data(); - // Ensure session 3 doesn't share the same data ptr as any other session - ASSERT_NE(so3_init_buffer, so1_init_buffer); - ASSERT_NE(so3_init_buffer, so2_init_buffer); + // Ensure session 3 doesn't share the same data ptr as any other session + ASSERT_NE(so3_init_buffer, so1_init_buffer); + ASSERT_NE(so3_init_buffer, so2_init_buffer); - // Ensure session 3 doesn't share the same data ptr as the one supplied by the user for any of the other sessions - ASSERT_NE(so3_init_buffer, val_to_share.Get().Data()); + // Ensure session 3 doesn't share the same data ptr as the one supplied by the user for any of the other sessions + ASSERT_NE(so3_init_buffer, val_to_share.Get().Data()); + } } void RunModelWithDenormalAsZero(InferenceSession& session_object, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 0c4a685371..96ce3d5491 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6202,6 +6202,130 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatOrHalfTypedInitialize } } +template +void BuildConstantSharingDivMulGraphFor2DInitializer(ModelTestBuilder& builder) { + auto* input0_arg = builder.MakeInput({{1, 1, 256, 8}}); + auto* input1_arg = builder.MakeInput({{1, 1, 256, 8}}); + auto* div_out = builder.MakeIntermediate(); + builder.AddNode("Div", {input0_arg, input1_arg}, {div_out}); + + std::vector values{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector values_float16; + values_float16.reserve(values.size()); + if (std::is_same::value) { + for (auto v : values) { + values_float16.push_back(MLFloat16(math::floatToHalf(v))); + } + } + + for (size_t i = 0; i < 12; ++i) { + NodeArg* mul_initializer = nullptr; + if (std::is_same::value) { + mul_initializer = builder.MakeInitializer({1, 8}, values_float16); + } else if (std::is_same::value) { + mul_initializer = builder.MakeInitializer({1, 8}, values); + } else { + ASSERT_TRUE(false); + } + auto* mul_out = builder.MakeOutput(); + builder.AddNode("Mul", {div_out, mul_initializer}, {mul_out}); + } +} + +/* +Test graph include multiple equivalent subgraphs as below. + graph input [1, 1, 256, 8] (float|MLFloat16) + | + Div + / | \ + / | \ + / ... | / ... \ + Mul Mul Mul + | | | + graph out [1, 1, 256, 8] (float|MLFloat16) + +Be noted: + the Mul's input initializer is a 2D float/MLFloat16. +*/ +TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatOrHalfTypedInitializer) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 2U); + TEST_RETURN_IF_NOT(op_count_pre["Div"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Mul"] == 12); + TEST_RETURN_IF_NOT(graph.GetAllInitializedTensors().size() == 12U); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); + TEST_RETURN_IF_NOT(initialized_tensor_set.size() == 1U); + const NodeArg* mul_initializer = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Mul") == 0) { + if (!mul_initializer) { + mul_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim_size() == 2); + } else { + TEST_RETURN_IF_NOT(mul_initializer == node.InputDefs()[1]); + } + } + } + TEST_RETURN_IF(mul_initializer == nullptr); + for (const auto& entry : initialized_tensor_set) { + if (entry.first.compare(mul_initializer->Name()) == 0) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; + int32_t data_type = tensor_proto->data_type(); + onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(float_const.size() == 8); + for (int i = 0; i < 8; ++i) { + float float_const_value; + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + float_const_value = math::halfToFloat((float_const.data() + i)->val); + } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + float_const_value = *(float_const.data() + i); + } else { + return Status(common::ONNXRUNTIME, common::FAIL, "unexpected type"); + } + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } + } + + auto op_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count.size() == 2U); + TEST_RETURN_IF_NOT(op_count["Div"] == 1); + TEST_RETURN_IF_NOT(op_count["Mul"] == 12); + return Status::OK(); + }; + + const std::vector opsets{12, 13, 14}; // Clip support int64_t since opset 12 + + // Float data type tests. + auto build_test_case_float = [&](ModelTestBuilder& builder) { + BuildConstantSharingDivMulGraphFor2DInitializer(builder); + }; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_float, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // MLFloat16 data type tests. + auto build_test_case_mlfloat16 = [&](ModelTestBuilder& builder) { + BuildConstantSharingDivMulGraphFor2DInitializer(builder); + }; + + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_mlfloat16, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + /* Test graph include multiple equivalent subgraphs as below. graph input [1, 1, 256, 256] (float) @@ -6314,6 +6438,158 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ } } +/* +Test graph include multiple equivalent subgraphs as below. + graph input [1, 1, 8, 8] (float) + | + Div ______________________________ + / | \_______ | | + / | float | | | half | half + / ... | / ... | | | / ... | / ... + Mul Mul Sub Sub Add Add + | | | | \ / + graph out [1, 1, 8, 8](float) graph out [1, 1, 8, 8](MLFloat16) + +Be noted: + the Mul's input initializer is a 2D float tensor. + the Add's input initializer is a 2D MLFloat16 tensor. +*/ +TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitializer) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 5U); + TEST_RETURN_IF_NOT(op_count_pre["Div"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Mul"] == 3); + TEST_RETURN_IF_NOT(op_count_pre["Sub"] == 3); + TEST_RETURN_IF_NOT(op_count_pre["Add"] == 3); + TEST_RETURN_IF_NOT(graph.GetAllInitializedTensors().size() == 9U); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); + TEST_RETURN_IF_NOT(initialized_tensor_set.size() == 3U); + const NodeArg* mul_initializer = nullptr; + const NodeArg* sub_initializer = nullptr; + const NodeArg* add_initializer = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType().compare("Mul") == 0) { + if (!mul_initializer) { + mul_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(mul_initializer->Shape()->dim(1).dim_value() == 8); + } else { + TEST_RETURN_IF_NOT(mul_initializer == node.InputDefs()[1]); + } + } else if (node.OpType().compare("Sub") == 0) { + if (!sub_initializer) { + sub_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(sub_initializer == nullptr); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim(0).dim_value() == 8); + TEST_RETURN_IF_NOT(sub_initializer->Shape()->dim(1).dim_value() == 1); + } else { + TEST_RETURN_IF_NOT(sub_initializer == node.InputDefs()[1]); + } + } else if (node.OpType().compare("Add") == 0) { + if (!add_initializer) { + add_initializer = node.InputDefs()[1]; + TEST_RETURN_IF(add_initializer == nullptr); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(add_initializer->Shape()->dim(1).dim_value() == 8); + } else { + TEST_RETURN_IF_NOT(add_initializer == node.InputDefs()[1]); + } + } + } + TEST_RETURN_IF(mul_initializer == nullptr); + TEST_RETURN_IF(sub_initializer == nullptr); + TEST_RETURN_IF(add_initializer == nullptr); + for (const auto& entry : initialized_tensor_set) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; + int32_t data_type = tensor_proto->data_type(); + onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(float_const.size() == 8); + if (entry.first.compare(mul_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (int i = 0; i < 8; ++i) { + float float_const_value = *(float_const.data() + i); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } else if (entry.first.compare(sub_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (int i = 0; i < 8; ++i) { + float float_const_value = *(float_const.data() + i); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } else if (entry.first.compare(add_initializer->Name()) == 0) { + TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + for (int i = 0; i < 8; ++i) { + float float_const_value = math::halfToFloat((float_const.data() + i)->val); + TEST_RETURN_IF_NOT(float_const_value == i * 1.0f); + } + } + } + + auto op_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count.size() == 5U); + TEST_RETURN_IF_NOT(op_count["Div"] == 1); + TEST_RETURN_IF_NOT(op_count["Mul"] == 3); + TEST_RETURN_IF_NOT(op_count["Sub"] == 3); + TEST_RETURN_IF_NOT(op_count["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count["Add"] == 3); + return Status::OK(); + }; + + const std::vector opsets{12, 13, 14}; + + std::vector values{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector values_float16; + values_float16.reserve(values.size()); + for (auto v : values) { + values_float16.push_back(MLFloat16(math::floatToHalf(v))); + } + + auto build_test_case_float = [&values, &values_float16](ModelTestBuilder& builder) { + auto* input0_arg = builder.MakeInput({{1, 1, 8, 8}}); + auto* input1_arg = builder.MakeInput({{1, 1, 8, 8}}); + auto* div_out = builder.MakeIntermediate(); + builder.AddNode("Div", {input0_arg, input1_arg}, {div_out}); + + for (size_t i = 0; i < 3; ++i) { + NodeArg* mul_initializer = builder.MakeInitializer({1, 8}, values); + auto* mul_out = builder.MakeOutput(); + builder.AddNode("Mul", {div_out, mul_initializer}, {mul_out}); + } + + for (size_t i = 0; i < 3; ++i) { + NodeArg* sub_initializer = builder.MakeInitializer({8, 1}, values); + auto* sub_out = builder.MakeOutput(); + builder.AddNode("Sub", {div_out, sub_initializer}, {sub_out}); + } + + auto* cast_out = builder.MakeIntermediate(); + builder.AddNode("Cast", {div_out}, {cast_out}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + for (size_t i = 0; i < 3; ++i) { + NodeArg* add_initializer = builder.MakeInitializer({1, 8}, values_float16); + auto* add_out = builder.MakeOutput(); + builder.AddNode("Add", {cast_out, add_initializer}, {add_out}); + } + }; + + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case_float, opset_version, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + /* Test graph include multiple equivalent subgraphs as below. graph input [1, 1, 256, 256] (float)