diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 63612c47f9..b67670a0e9 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -45,6 +45,7 @@ #include "core/optimizer/identical_children_consolidation.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/layer_norm_fusion.h" +#include "core/optimizer/label_encoder_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" @@ -133,6 +134,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: diff --git a/onnxruntime/core/optimizer/label_encoder_fusion.cc b/onnxruntime/core/optimizer/label_encoder_fusion.cc new file mode 100644 index 0000000000..043cd31b88 --- /dev/null +++ b/onnxruntime/core/optimizer/label_encoder_fusion.cc @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/optimizer/label_encoder_fusion.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +#define KEYS_ATTR_NAME(T) ("keys_" + GetTypename() + "s") +#define VALUES_ATTR_NAME(T) ("values_" + GetTypename() + "s") +#define DEFAULT_VALUE_ATTR_NAME(T) ("default_" + GetTypename()) + +// May be needed somewhere else +// Think about moving into utils +template +[[maybe_unused]] constexpr bool false_for_T = false; + +template +std::string GetTypename() { + if constexpr (std::is_same()) { + return "int64"; + } else if constexpr (std::is_same()) { + return "string"; + } else if constexpr (std::is_same()) { + return "float"; + } else { + static_assert(false_for_T, "Unsupported type"); + } +} + +template +bool LabelEncoderFusion::IsValidForFusion(const Node& node, const Node& next_node) const { + return (node.GetAttributes().find(KEYS_ATTR_NAME(T1)) != node.GetAttributes().end() && + node.GetAttributes().find(VALUES_ATTR_NAME(T2)) != node.GetAttributes().end() && + next_node.GetAttributes().find(KEYS_ATTR_NAME(T2)) != next_node.GetAttributes().end() && + next_node.GetAttributes().find(VALUES_ATTR_NAME(T3)) != next_node.GetAttributes().end()); +} + +/** +Transform that fuses two consecutive LabelEncoder nodes +into one LabelEncoder node. + */ +bool LabelEncoderFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain( + node, "LabelEncoder", {2, 4}, "ai.onnx.ml") || + node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto& next_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LabelEncoder", {4}, "ai.onnx.ml") || + // Make sure the two nodes do not span execution providers. + next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // Is one of the supported operations + return IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node) || + IsValidForFusion(node, next_node); +} + +/** +Since we need to be polymorphic on the datatype +we will dispatch to this method from the main Apply +*/ +template +Status LabelEncoderFusion::ApplyHelper( + Graph& graph, + Node& node, + Node& next_node, + RewriteRuleEffect& rule_effect) const { + ProtoHelperNodeContext node_helper_ctx(node); + OpNodeProtoHelper node_helper(&node_helper_ctx); + + ProtoHelperNodeContext next_node_helper_ctx(next_node); + OpNodeProtoHelper next_node_helper(&next_node_helper_ctx); + + const std::vector node_keys = + node_helper.GetAttrsOrDefault(KEYS_ATTR_NAME(T1)); + const std::vector node_values = + node_helper.GetAttrsOrDefault(VALUES_ATTR_NAME(T2)); + const T2 node_default = + node_helper.GetAttr(DEFAULT_VALUE_ATTR_NAME(T2)); + + const std::vector next_node_keys = + next_node_helper.GetAttrsOrDefault(KEYS_ATTR_NAME(T2)); + const std::vector next_node_values = + next_node_helper.GetAttrsOrDefault(VALUES_ATTR_NAME(T3)); + const T3 next_node_default = + next_node_helper.GetAttr(DEFAULT_VALUE_ATTR_NAME(T3)); + + const auto getFromMapDefault = [](const auto& mp, const auto key, const auto def) { + return (mp.find(key) == mp.end()) ? def : mp.at(key); + }; + + // Perform value propagation through the second label encoder + std::unordered_map mapping = {}; + for (size_t i = 0; i < next_node_keys.size(); i++) { + mapping[next_node_keys[i]] = next_node_values[i]; + } + + std::vector new_node_values = {}; + const auto new_node_default = getFromMapDefault(mapping, node_default, next_node_default); + + for (const T2& node_value : node_values) { + new_node_values.push_back(getFromMapDefault(mapping, node_value, next_node_default)); + } + + // Remove old attributes: + // The keys attribute is correct, we just reroute + // the values + node.ClearAttribute(VALUES_ATTR_NAME(T2)); + node.ClearAttribute(DEFAULT_VALUE_ATTR_NAME(T2)); + + node.AddAttribute(VALUES_ATTR_NAME(T3), new_node_values); + node.AddAttribute(DEFAULT_VALUE_ATTR_NAME(T3), new_node_default); + + graph_utils::FinalizeNodeFusion(graph, node, next_node); + + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; + + return Status::OK(); +} + +#define FUSE_IF_VALID(T1, T2, T3) \ + if (IsValidForFusion(node, next_node)) { \ + return ApplyHelper( \ + graph, node, next_node, rule_effect); \ + } + +Status LabelEncoderFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { + auto& next_node = *graph.GetNode(node.OutputNodesBegin()->Index()); + + FUSE_IF_VALID(std::string, std::string, std::string); + FUSE_IF_VALID(std::string, std::string, int64_t); + FUSE_IF_VALID(std::string, int64_t, std::string); + FUSE_IF_VALID(std::string, int64_t, int64_t); + FUSE_IF_VALID(int64_t, std::string, std::string); + FUSE_IF_VALID(int64_t, std::string, int64_t); + FUSE_IF_VALID(int64_t, int64_t, std::string); + FUSE_IF_VALID(int64_t, int64_t, int64_t); + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/label_encoder_fusion.h b/onnxruntime/core/optimizer/label_encoder_fusion.h new file mode 100644 index 0000000000..30d69f0dcf --- /dev/null +++ b/onnxruntime/core/optimizer/label_encoder_fusion.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/** +@Class LabelEncoderFusion + +Rewrite rule that fuses two LabelEncoder -> LabelEncoder nodes to a single +LabelEncoder node. + +*/ +class LabelEncoderFusion : public RewriteRule { + public: + LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"LabelEncoder"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; + + template + Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const; + + template + bool IsValidForFusion(const Node& node, const Node& next) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 0d1f213618..d5e6ed7c10 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -69,6 +69,7 @@ #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/utils.h" +#include "core/optimizer/label_encoder_fusion.h" #include "core/platform/env.h" #include "core/session/inference_session.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -1901,6 +1902,68 @@ TEST_F(GraphTransformationTests, DivMulFusion) { ASSERT_TRUE(op_to_count["Mul"] == 2); } +TEST_F(GraphTransformationTests, LabelEncoderFusion) { + using common::INVALID_GRAPH; + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/label_encoder.onnx"; + + NameMLValMap feeds; + + constexpr size_t ALPH = 26; + OrtValue mlvalue_a; + std::vector dims_a = {ALPH}; + std::vector values_a = {}; + for (char letter = 'a'; letter <= 'z'; letter++) { + values_a.emplace_back(1, letter); + } + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_a, + values_a, &mlvalue_a); + feeds.insert(std::make_pair("A", mlvalue_a)); + + bool is_implemented = true; + + auto run_model_test = [&](TransformerLevel level, std::vector& fetches, const int requiredLabelEncoderCount) { + SessionOptions session_options; + session_options.graph_optimization_level = level; + session_options.session_logid = "OptimizerTests"; + InferenceSessionWrapper session{session_options, GetEnvironment()}; + + // If we did not initialize the session correctly, the operator is missing. + if (!session.Load(model_uri).IsOK() || !session.Initialize().IsOK()) { + is_implemented = false; + return; + } + + // Count if the number of LabelEncoders is as expected + std::map op_to_count = CountOpsInGraph(session.GetGraph()); + ASSERT_TRUE(op_to_count["ai.onnx.ml.LabelEncoder"] == requiredLabelEncoderCount); + + std::vector output_names = {}; + for (const auto& output : session.GetGraph().GetOutputs()) { + output_names.push_back(output->Name()); + } + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches)); + }; + + // run model with and w/o optimizations and compare the results + std::vector unoptimized_fetches; + run_model_test(TransformerLevel::Default, unoptimized_fetches, 11); + + std::vector optimized_fetches; + run_model_test(TransformerLevel::MaxLevel, optimized_fetches, 7); + + // If there was a problem loading the model, do not compare the 2 results + if (!is_implemented) { + GTEST_SKIP(); + return; + } + + // Compare results + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], 0.0, 0.0, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; +} + TEST_F(GraphTransformationTests, NotWhereFusion) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/not_where.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/fusion/label_encoder.onnx b/onnxruntime/test/testdata/transform/fusion/label_encoder.onnx new file mode 100644 index 0000000000..ffc28ed256 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/label_encoder.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/label_encoder.py b/onnxruntime/test/testdata/transform/fusion/label_encoder.py new file mode 100644 index 0000000000..96c203547a --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/label_encoder.py @@ -0,0 +1,163 @@ +from onnx import OperatorSetIdProto, TensorProto, helper, save + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 19 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" +opsets.append(msdomain) + +ai_ml_domain = OperatorSetIdProto() +ai_ml_domain.version = 4 +ai_ml_domain.domain = "ai.onnx.ml" +opsets.append(ai_ml_domain) + +kwargs = {} +kwargs["opset_imports"] = opsets + + +def generate_model(model_name): + # Create models with consecutive label encoders + nodes = [ # subgraph + # string -> int -> string + helper.make_node( + "LabelEncoder", + ["A"], + ["le_1_int_1"], + "le_1_int_1", + domain="ai.onnx.ml", + keys_strings=["a", "b", "c"], + values_int64s=[0, 1, 2], + ), + helper.make_node( + "LabelEncoder", + ["le_1_int_1"], + ["le_1_string_2"], + "le_1_string_2", + domain="ai.onnx.ml", + keys_int64s=[2, 1, 0], + values_strings=["a", "b", "c"], + default_string="default", + ), + # string -> string -> string + helper.make_node( + "LabelEncoder", + ["A"], + ["le_2_string_1"], + "le_2_string_1", + domain="ai.onnx.ml", + keys_strings=["a", "b", "c"], + values_strings=["C", "B", "A"], + default_string="D", + ), + helper.make_node( + "LabelEncoder", + ["le_2_string_1"], + ["le_2_string_2"], + "le_2_string_2", + domain="ai.onnx.ml", + keys_strings=["A", "B", "C", "D"], + values_strings=["a", "b", "c", "d"], + default_string="default", + ), + # string -> string -> int -> string + helper.make_node( + "LabelEncoder", + ["A"], + ["le_3_string_1"], + "le_3_string_1", + domain="ai.onnx.ml", + keys_strings=["a", "b", "c"], + values_strings=["C", "B", "A"], + ), + helper.make_node( + "LabelEncoder", + ["le_3_string_1"], + ["le_3_int_2"], + "le_3_int_2", + domain="ai.onnx.ml", + keys_strings=["A", "B", "C"], + values_int64s=[1, 2, 3], + default_int64=-1, + ), + helper.make_node( + "LabelEncoder", + ["le_3_int_2"], + ["le_3_string_3"], + "le_3_string_3", + domain="ai.onnx.ml", + keys_int64s=[1, 2, 3], + values_strings=["a", "b", "c"], + default_string="d", + ), + # middle encoder is graph output + helper.make_node( + "LabelEncoder", + ["A"], + ["le_4_int_1"], + "le_4_int_1", + domain="ai.onnx.ml", + keys_strings=["a", "b", "c"], + values_int64s=[0, 1, 2], + ), + helper.make_node( + "LabelEncoder", + ["le_4_int_1"], + ["le_4_string_2"], + "le_4_string_2", + domain="ai.onnx.ml", + keys_int64s=[0, 1, 2], + values_strings=["a", "b", "c"], + ), + helper.make_node("Identity", ["le_4_int_1"], ["Y"], "output"), + # middle encoder is consumed twice + helper.make_node( + "LabelEncoder", + ["A"], + ["le_5_int_1"], + "le_5_int_1", + domain="ai.onnx.ml", + keys_strings=["a", "b", "c"], + values_int64s=[0, 1, 2], + ), + helper.make_node( + "LabelEncoder", + ["le_5_int_1"], + ["le_5_string_2"], + "le_5_string_2", + domain="ai.onnx.ml", + keys_int64s=[0, 1, 2], + values_strings=["a", "b", "c"], + ), + helper.make_node("Mul", ["le_5_int_1", "le_5_int_1"], ["mul_5"], "mul_5"), + ] + + inputs = [ # inputs + helper.make_tensor_value_info("A", TensorProto.STRING, ["N"]), + ] + + graph = helper.make_graph( + nodes, + "LabelEncoder", # name + inputs, + [ # outputs + helper.make_tensor_value_info("le_1_string_2", TensorProto.STRING, ["N"]), + helper.make_tensor_value_info("le_2_string_2", TensorProto.STRING, ["N"]), + helper.make_tensor_value_info("le_3_string_3", TensorProto.STRING, ["N"]), + helper.make_tensor_value_info("le_4_string_2", TensorProto.STRING, ["N"]), + helper.make_tensor_value_info("Y", TensorProto.INT64, ["N"]), + helper.make_tensor_value_info("mul_5", TensorProto.INT64, ["N"]), + ], + [], + ) + + model = helper.make_model(graph, **kwargs) + save(model, model_name) + + +if __name__ == "__main__": + generate_model("label_encoder.onnx")