Label encoder fusion (#19761)

### Description
Created a new `LabelEncoderFusion` pass. This is useful in models that
result from automatic conversion tools related to data science:
sometimes the produced model contains consecutive `LabelEncoder`-s.
To merge 2 `LabelEncoder`-s the optimizer propagates the outputs of the
first encoder through the second one.


### Motivation and Context
This enhances the capabilities of the `onnxruntime::optimizer` by fusing
consecutive `LabelEncoder` nodes.


### Fusion examples
```
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: _Unused
node2: (A,1) (B,2) (C,3) -> Default: -1
fused: (a,3) (b,2) (c,1) -> Default: -1
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: D
node2: (A,a) (B,b) (C,c) (D,d) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: d
Applying fusion
node1: (a,0) (b,1) (c,2) -> Default: -1
node2: (2,a) (1,b) (0,c) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: default
Applying fusion
node1: (a,3) (b,2) (c,1) -> Default: -1
node2: (1,a) (2,b) (3,c) -> Default: d
fused: (a,c) (b,b) (c,a) -> Default: d
```

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
Atanas Dimitrov 2024-04-01 17:41:10 +01:00 committed by GitHub
parent 523ef04240
commit 9d06e1bfa4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 425 additions and 0 deletions

View file

@ -45,6 +45,7 @@
#include "core/optimizer/identical_children_consolidation.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/optimizer/matmul_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/matmul_integer_to_float.h"
@ -133,6 +134,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(std::make_unique<MatmulBNFusion>());
rules.push_back(std::make_unique<ClipQuantFusion>());
rules.push_back(std::make_unique<ReluQuantFusion>());
rules.push_back(std::make_unique<LabelEncoderFusion>());
break;
case TransformerLevel::Level2:

View file

@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <unordered_map>
#include <vector>
#include <string>
#include "core/optimizer/label_encoder_fusion.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
// Builders for the typed LabelEncoder attribute names. Each macro expands to a
// std::string expression assembled from GetTypename<T>(), for example:
//   KEYS_ATTR_NAME(int64_t)             -> "keys_int64s"
//   VALUES_ATTR_NAME(std::string)       -> "values_strings"
//   DEFAULT_VALUE_ATTR_NAME(int64_t)    -> "default_int64"
// Note that the default-value attribute is singular (no trailing "s").
#define KEYS_ATTR_NAME(T) ("keys_" + GetTypename<T>() + "s")
#define VALUES_ATTR_NAME(T) ("values_" + GetTypename<T>() + "s")
#define DEFAULT_VALUE_ATTR_NAME(T) ("default_" + GetTypename<T>())
// Dependent-false helper: always `false`, but its value depends on a template
// argument so it can be used in a static_assert inside a discarded branch.
// It may be useful in other files; consider moving it into a shared utils header.
template <typename>
[[maybe_unused]] constexpr bool false_for_T = false;

// Returns the type fragment used inside LabelEncoder attribute names:
// "int64" for int64_t, "string" for std::string and "float" for float.
// Instantiating it with any other type is a compile-time error.
template <typename T>
std::string GetTypename() {
  constexpr bool kIsInt64 = std::is_same_v<T, int64_t>;
  constexpr bool kIsString = std::is_same_v<T, std::string>;
  constexpr bool kIsFloat = std::is_same_v<T, float>;
  static_assert(kIsInt64 || kIsString || kIsFloat, "Unsupported type");

  if constexpr (kIsInt64) {
    return "int64";
  } else if constexpr (kIsString) {
    return "string";
  } else {
    return "float";
  }
}
template <typename T1, typename T2, typename T3>
bool LabelEncoderFusion::IsValidForFusion(const Node& node, const Node& next_node) const {
return (node.GetAttributes().find(KEYS_ATTR_NAME(T1)) != node.GetAttributes().end() &&
node.GetAttributes().find(VALUES_ATTR_NAME(T2)) != node.GetAttributes().end() &&
next_node.GetAttributes().find(KEYS_ATTR_NAME(T2)) != next_node.GetAttributes().end() &&
next_node.GetAttributes().find(VALUES_ATTR_NAME(T3)) != next_node.GetAttributes().end());
}
/**
Decides whether `node` and its single consumer can be fused into one
LabelEncoder node.

Fusion is accepted only when all of the following hold:
  - `node` is an ai.onnx.ml LabelEncoder (opset 2 or 4) with exactly one output edge;
  - its consumer is an ai.onnx.ml LabelEncoder (opset 4) assigned to the same
    execution provider;
  - `node` does not produce a graph output (the intermediate value must not be
    observable once the nodes are merged);
  - neither node carries a `default_tensor` attribute;
  - the typed keys_* / values_* attributes of both nodes form one of the supported
    string/int64 chains checked by IsValidForFusion.

NOTE(review): the first node accepts opsets {2, 4} while the consumer accepts
only {4}; confirm whether the asymmetry is intentional.
*/
bool LabelEncoderFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(
          node, "LabelEncoder", {2, 4}, "ai.onnx.ml") ||
      node.GetOutputEdgesCount() != 1) {
    return false;
  }

  const auto& next_node = *node.OutputNodesBegin();
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LabelEncoder", {4}, "ai.onnx.ml") ||
      // Make sure the two nodes do not span execution providers.
      next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
    return false;
  }

  if (graph.NodeProducesGraphOutput(node)) {
    return false;
  }

  // Opset 4 lets a model provide the default through `default_tensor` instead of
  // the typed default_* attributes. ApplyHelper reads only the typed default_*
  // attributes, so fusing such a node would silently replace its default value.
  // Decline the fusion rather than produce a different result.
  const auto has_default_tensor = [](const Node& candidate) {
    const auto& attributes = candidate.GetAttributes();
    return attributes.find("default_tensor") != attributes.end();
  };
  if (has_default_tensor(node) || has_default_tensor(next_node)) {
    return false;
  }

  // Is one of the supported operations
  return IsValidForFusion<std::string, std::string, std::string>(node, next_node) ||
         IsValidForFusion<std::string, std::string, int64_t>(node, next_node) ||
         IsValidForFusion<std::string, int64_t, std::string>(node, next_node) ||
         IsValidForFusion<std::string, int64_t, int64_t>(node, next_node) ||
         IsValidForFusion<int64_t, std::string, std::string>(node, next_node) ||
         IsValidForFusion<int64_t, std::string, int64_t>(node, next_node) ||
         IsValidForFusion<int64_t, int64_t, std::string>(node, next_node) ||
         IsValidForFusion<int64_t, int64_t, int64_t>(node, next_node);
}
/**
Typed implementation of the fusion; Apply dispatches here once it knows the
key/value types of both encoders.

T1 is the key type of the first encoder, T2 is its value type (and the key type
of the second encoder), and T3 is the value type of the second encoder.

The first encoder's keys stay as they are. Each of its values, and its default,
is pushed through the second encoder's key -> value table; anything the second
encoder does not know maps to the second encoder's default. `node` then receives
the resulting T3 values/default and `next_node` is merged into it.

If the second encoder's keys and values have different lengths the nodes are
left untouched and Status::OK() is returned without changing `rule_effect`.
On a successful fusion `rule_effect` is set to kModifiedRestOfGraph.
*/
template <typename T1, typename T2, typename T3>
Status LabelEncoderFusion::ApplyHelper(
    Graph& graph,
    Node& node,
    Node& next_node,
    RewriteRuleEffect& rule_effect) const {
  ProtoHelperNodeContext node_helper_ctx(node);
  OpNodeProtoHelper<ProtoHelperNodeContext> node_helper(&node_helper_ctx);

  ProtoHelperNodeContext next_node_helper_ctx(next_node);
  OpNodeProtoHelper<ProtoHelperNodeContext> next_node_helper(&next_node_helper_ctx);

  // The keys of the first encoder (type T1) are not rewritten, so they are not read here.
  const std::vector<T2> node_values =
      node_helper.GetAttrsOrDefault<T2>(VALUES_ATTR_NAME(T2));
  const T2 node_default =
      node_helper.GetAttr<T2>(DEFAULT_VALUE_ATTR_NAME(T2));

  const std::vector<T2> next_node_keys =
      next_node_helper.GetAttrsOrDefault<T2>(KEYS_ATTR_NAME(T2));
  const std::vector<T3> next_node_values =
      next_node_helper.GetAttrsOrDefault<T3>(VALUES_ATTR_NAME(T3));
  const T3 next_node_default =
      next_node_helper.GetAttr<T3>(DEFAULT_VALUE_ATTR_NAME(T3));

  // Values are paired with keys by position. With mismatched lengths the pairing
  // below would index past the end of `next_node_values`, so skip the fusion.
  if (next_node_keys.size() != next_node_values.size()) {
    return Status::OK();
  }

  // Lookup table of the second label encoder.
  std::unordered_map<T2, T3> mapping;
  mapping.reserve(next_node_keys.size());
  for (size_t i = 0; i < next_node_keys.size(); ++i) {
    mapping[next_node_keys[i]] = next_node_values[i];
  }

  // Single-lookup equivalent of "mapping.at(key) if present, else the second default".
  const auto map_through_next = [&mapping, &next_node_default](const T2& key) -> T3 {
    const auto it = mapping.find(key);
    return it != mapping.end() ? it->second : next_node_default;
  };

  // Perform value propagation through the second label encoder
  const T3 new_node_default = map_through_next(node_default);

  std::vector<T3> new_node_values;
  new_node_values.reserve(node_values.size());
  for (const T2& node_value : node_values) {
    new_node_values.push_back(map_through_next(node_value));
  }

  // Remove old attributes:
  // The keys attribute is correct, we just reroute
  // the values
  node.ClearAttribute(VALUES_ATTR_NAME(T2));
  node.ClearAttribute(DEFAULT_VALUE_ATTR_NAME(T2));

  node.AddAttribute(VALUES_ATTR_NAME(T3), new_node_values);
  node.AddAttribute(DEFAULT_VALUE_ATTR_NAME(T3), new_node_default);

  graph_utils::FinalizeNodeFusion(graph, node, next_node);

  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

  return Status::OK();
}
// Tries one <T1, T2, T3> combination: when the attributes of `node` and
// `next_node` match it, returns the result of the typed helper from the
// enclosing function. Expects `graph`, `node`, `next_node` and `rule_effect`
// to be in scope at the point of use.
#define FUSE_IF_VALID(T1, T2, T3)                                          \
  do {                                                                     \
    if (IsValidForFusion<T1, T2, T3>(node, next_node)) {                   \
      return ApplyHelper<T1, T2, T3>(graph, node, next_node, rule_effect); \
    }                                                                      \
  } while (false)

// Fuses `node` with its consumer by dispatching to the ApplyHelper
// instantiation that matches the key/value types found on both nodes.
// Returns Status::OK() without touching anything when no combination matches.
Status LabelEncoderFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
  // Look the consumer up through the graph to obtain a mutable reference to it.
  const auto consumer_index = node.OutputNodesBegin()->Index();
  auto& next_node = *graph.GetNode(consumer_index);

  // string keys on the first encoder
  FUSE_IF_VALID(std::string, std::string, std::string);
  FUSE_IF_VALID(std::string, std::string, int64_t);
  FUSE_IF_VALID(std::string, int64_t, std::string);
  FUSE_IF_VALID(std::string, int64_t, int64_t);

  // int64 keys on the first encoder
  FUSE_IF_VALID(int64_t, std::string, std::string);
  FUSE_IF_VALID(int64_t, std::string, int64_t);
  FUSE_IF_VALID(int64_t, int64_t, std::string);
  FUSE_IF_VALID(int64_t, int64_t, int64_t);

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class LabelEncoderFusion
Rewrite rule that fuses two LabelEncoder -> LabelEncoder nodes to a single
LabelEncoder node.

The rule is registered for the "LabelEncoder" op type. The fused node keeps the
keys of the first encoder; its values and default are the first encoder's
values and default mapped through the second encoder.
*/
class LabelEncoderFusion : public RewriteRule {
public:
// Registers the rule under the name "LabelEncoderFusion".
LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {}
// Op types this rule wants to be offered: only "LabelEncoder".
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"LabelEncoder"};
}
private:
// Returns true when `node` and its single consumer form a fusable LabelEncoder pair.
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
// Performs the fusion by dispatching to the ApplyHelper instantiation matching the nodes' attribute types.
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
// Typed fusion body. T1: key type of `node`; T2: value type of `node` and key type
// of `next_node`; T3: value type of `next_node`.
template <typename T1, typename T2, typename T3>
Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const;
// Checks that `node` has keys of T1 / values of T2 and `next` has keys of T2 / values of T3.
template <typename T1, typename T2, typename T3>
bool IsValidForFusion(const Node& node, const Node& next) const;
};
} // namespace onnxruntime

View file

@ -69,6 +69,7 @@
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/platform/env.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
@ -1901,6 +1902,68 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
ASSERT_TRUE(op_to_count["Mul"] == 2);
}
// Runs fusion/label_encoder.onnx with the default and with the maximum
// optimization level, checks the number of LabelEncoder nodes left in each
// graph (11 unfused, 7 fused) and verifies that every graph output is
// identical between the two runs. The test is skipped when the model cannot be
// loaded or initialized (the operator is not available in this build).
TEST_F(GraphTransformationTests, LabelEncoderFusion) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/label_encoder.onnx";

  // Input "A": the 26 lowercase letters, one single-character string each.
  NameMLValMap feeds;
  constexpr size_t ALPH = 26;
  OrtValue mlvalue_a;
  std::vector<int64_t> dims_a = {ALPH};
  std::vector<std::string> values_a = {};
  for (char letter = 'a'; letter <= 'z'; letter++) {
    values_a.emplace_back(1, letter);
  }
  CreateMLValue<std::string>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_a,
                             values_a, &mlvalue_a);
  feeds.insert(std::make_pair("A", mlvalue_a));

  bool is_implemented = true;

  auto run_model_test = [&](TransformerLevel level, std::vector<OrtValue>& fetches, const int requiredLabelEncoderCount) {
    SessionOptions session_options;
    session_options.graph_optimization_level = level;
    session_options.session_logid = "OptimizerTests";
    InferenceSessionWrapper session{session_options, GetEnvironment()};

    // If we did not initialize the session correctly, the operator is missing.
    if (!session.Load(model_uri).IsOK() || !session.Initialize().IsOK()) {
      is_implemented = false;
      return;
    }

    // Count if the number of LabelEncoders is as expected
    std::map<std::string, int> op_to_count = CountOpsInGraph(session.GetGraph());
    ASSERT_EQ(op_to_count["ai.onnx.ml.LabelEncoder"], requiredLabelEncoderCount);

    std::vector<std::string> output_names = {};
    for (const auto& output : session.GetGraph().GetOutputs()) {
      output_names.push_back(output->Name());
    }

    RunOptions run_options;
    ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
  };

  // run model with and w/o optimizations and compare the results
  std::vector<OrtValue> unoptimized_fetches;
  run_model_test(TransformerLevel::Default, unoptimized_fetches, 11);

  std::vector<OrtValue> optimized_fetches;
  run_model_test(TransformerLevel::MaxLevel, optimized_fetches, 7);

  // If there was a problem loading the model, do not compare the 2 results
  if (!is_implemented) {
    GTEST_SKIP();
    return;
  }

  // A fatal assertion inside the lambda only leaves the lambda, so the fetches
  // may be empty here. Stop before indexing into them.
  ASSERT_FALSE(unoptimized_fetches.empty());
  ASSERT_EQ(optimized_fetches.size(), unoptimized_fetches.size());

  // Compare every graph output, not just the first one, so that each encoder
  // chain in the model is covered.
  for (size_t i = 0; i < unoptimized_fetches.size(); ++i) {
    auto ret = CompareOrtValue(optimized_fetches[i], unoptimized_fetches[i], 0.0, 0.0, false);
    EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << "output #" << i << ": " << ret.second;
  }
}
TEST_F(GraphTransformationTests, NotWhereFusion) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/not_where.onnx";
std::shared_ptr<Model> model;

Binary file not shown.

View file

@ -0,0 +1,163 @@
from onnx import OperatorSetIdProto, TensorProto, helper, save
def _make_opset(domain, version):
    """Return an OperatorSetIdProto importing ``domain`` at ``version``."""
    opset = OperatorSetIdProto()
    opset.version = version
    opset.domain = domain
    return opset


# The empty string ("") or absence of the domain field implies the operator set
# that is defined as part of the ONNX specification.
onnxdomain = _make_opset("", 19)
msdomain = _make_opset("com.microsoft", 1)
ai_ml_domain = _make_opset("ai.onnx.ml", 4)

# Opset imports attached to the generated model, in declaration order.
opsets = [onnxdomain, msdomain, ai_ml_domain]

# Extra keyword arguments forwarded to helper.make_model.
kwargs = {"opset_imports": opsets}
def generate_model(model_name):
    """Build a model made of chained ai.onnx.ml LabelEncoder nodes and save it.

    The graph has one string input ``A`` and contains five groups of encoders:
    three chains of consecutive encoders, one pair whose intermediate result
    also feeds an Identity node that writes graph output ``Y``, and one pair
    whose intermediate result is also consumed by a Mul node.

    Args:
        model_name: path passed to ``onnx.save`` for the resulting model.
    """

    def label_encoder(source, name, **attributes):
        # Every encoder in this model names its node after its single output.
        return helper.make_node(
            "LabelEncoder",
            [source],
            [name],
            name,
            domain="ai.onnx.ml",
            **attributes,
        )

    # string -> int -> string
    string_int_string = [
        label_encoder(
            "A",
            "le_1_int_1",
            keys_strings=["a", "b", "c"],
            values_int64s=[0, 1, 2],
        ),
        label_encoder(
            "le_1_int_1",
            "le_1_string_2",
            keys_int64s=[2, 1, 0],
            values_strings=["a", "b", "c"],
            default_string="default",
        ),
    ]

    # string -> string -> string
    string_string_string = [
        label_encoder(
            "A",
            "le_2_string_1",
            keys_strings=["a", "b", "c"],
            values_strings=["C", "B", "A"],
            default_string="D",
        ),
        label_encoder(
            "le_2_string_1",
            "le_2_string_2",
            keys_strings=["A", "B", "C", "D"],
            values_strings=["a", "b", "c", "d"],
            default_string="default",
        ),
    ]

    # string -> string -> int -> string
    string_string_int_string = [
        label_encoder(
            "A",
            "le_3_string_1",
            keys_strings=["a", "b", "c"],
            values_strings=["C", "B", "A"],
        ),
        label_encoder(
            "le_3_string_1",
            "le_3_int_2",
            keys_strings=["A", "B", "C"],
            values_int64s=[1, 2, 3],
            default_int64=-1,
        ),
        label_encoder(
            "le_3_int_2",
            "le_3_string_3",
            keys_int64s=[1, 2, 3],
            values_strings=["a", "b", "c"],
            default_string="d",
        ),
    ]

    # middle encoder is graph output
    middle_is_graph_output = [
        label_encoder(
            "A",
            "le_4_int_1",
            keys_strings=["a", "b", "c"],
            values_int64s=[0, 1, 2],
        ),
        label_encoder(
            "le_4_int_1",
            "le_4_string_2",
            keys_int64s=[0, 1, 2],
            values_strings=["a", "b", "c"],
        ),
        helper.make_node("Identity", ["le_4_int_1"], ["Y"], "output"),
    ]

    # middle encoder is consumed twice
    middle_consumed_twice = [
        label_encoder(
            "A",
            "le_5_int_1",
            keys_strings=["a", "b", "c"],
            values_int64s=[0, 1, 2],
        ),
        label_encoder(
            "le_5_int_1",
            "le_5_string_2",
            keys_int64s=[0, 1, 2],
            values_strings=["a", "b", "c"],
        ),
        helper.make_node("Mul", ["le_5_int_1", "le_5_int_1"], ["mul_5"], "mul_5"),
    ]

    nodes = [
        *string_int_string,
        *string_string_string,
        *string_string_int_string,
        *middle_is_graph_output,
        *middle_consumed_twice,
    ]

    inputs = [helper.make_tensor_value_info("A", TensorProto.STRING, ["N"])]

    # (tensor name, element type) of every graph output, in output order.
    output_specs = [
        ("le_1_string_2", TensorProto.STRING),
        ("le_2_string_2", TensorProto.STRING),
        ("le_3_string_3", TensorProto.STRING),
        ("le_4_string_2", TensorProto.STRING),
        ("Y", TensorProto.INT64),
        ("mul_5", TensorProto.INT64),
    ]
    outputs = [helper.make_tensor_value_info(name, elem_type, ["N"]) for name, elem_type in output_specs]

    graph = helper.make_graph(
        nodes,
        "LabelEncoder",  # name
        inputs,
        outputs,
        [],
    )

    model = helper.make_model(graph, **kwargs)
    save(model, model_name)
if __name__ == "__main__":
    # Relative path: the model is written into the current working directory.
    output_path = "label_encoder.onnx"
    generate_model(output_path)