mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Label encoder fusion (#19761)
### Description Created a new `LabelEncoderFusion` pass. This is useful in model that result from automatic conversion tools related to data-science: sometimes the produced model contains consecutive `LabelEncoder`-s. To merge 2 `LabelEncoder`-s the optimizer propagates the outputs of the first encoder through the second one. ### Motivation and Context This enhances the capabilities of the `onnxruntime::optimizer` by fusing consecutive `LabelEncoder` nodes. ### Fusion examples ``` Applying fusion node1: (a,C) (b,B) (c,A) -> Default: _Unused node2: (A,1) (B,2) (C,3) -> Default: -1 fused: (a,3) (b,2) (c,1) -> Default: -1 Applying fusion node1: (a,C) (b,B) (c,A) -> Default: D node2: (A,a) (B,b) (C,c) (D,d) -> Default: default fused: (a,c) (b,b) (c,a) -> Default: d Applying fusion node1: (a,0) (b,1) (c,2) -> Default: -1 node2: (2,a) (1,b) (0,c) -> Default: default fused: (a,c) (b,b) (c,a) -> Default: default Applying fusion node1: (a,3) (b,2) (c,1) -> Default: -1 node2: (1,a) (2,b) (3,c) -> Default: d fused: (a,c) (b,b) (c,a) -> Default: d ``` --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
parent
523ef04240
commit
9d06e1bfa4
6 changed files with 425 additions and 0 deletions
|
|
@ -45,6 +45,7 @@
|
|||
#include "core/optimizer/identical_children_consolidation.h"
|
||||
#include "core/optimizer/identity_elimination.h"
|
||||
#include "core/optimizer/layer_norm_fusion.h"
|
||||
#include "core/optimizer/label_encoder_fusion.h"
|
||||
#include "core/optimizer/matmul_activation_fusion.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/matmul_integer_to_float.h"
|
||||
|
|
@ -133,6 +134,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
|
|||
rules.push_back(std::make_unique<MatmulBNFusion>());
|
||||
rules.push_back(std::make_unique<ClipQuantFusion>());
|
||||
rules.push_back(std::make_unique<ReluQuantFusion>());
|
||||
rules.push_back(std::make_unique<LabelEncoderFusion>());
|
||||
break;
|
||||
|
||||
case TransformerLevel::Level2:
|
||||
|
|
|
|||
162
onnxruntime/core/optimizer/label_encoder_fusion.cc
Normal file
162
onnxruntime/core/optimizer/label_encoder_fusion.cc
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "core/optimizer/label_encoder_fusion.h"
|
||||
#include "core/framework/op_node_proto_helper.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
#define KEYS_ATTR_NAME(T) ("keys_" + GetTypename<T>() + "s")
|
||||
#define VALUES_ATTR_NAME(T) ("values_" + GetTypename<T>() + "s")
|
||||
#define DEFAULT_VALUE_ATTR_NAME(T) ("default_" + GetTypename<T>())
|
||||
|
||||
// May be needed somewhere else
|
||||
// Think about moving into utils
|
||||
template <typename>
|
||||
[[maybe_unused]] constexpr bool false_for_T = false;
|
||||
|
||||
template <typename T>
|
||||
std::string GetTypename() {
|
||||
if constexpr (std::is_same<T, int64_t>()) {
|
||||
return "int64";
|
||||
} else if constexpr (std::is_same<T, std::string>()) {
|
||||
return "string";
|
||||
} else if constexpr (std::is_same<T, float>()) {
|
||||
return "float";
|
||||
} else {
|
||||
static_assert(false_for_T<T>, "Unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
bool LabelEncoderFusion::IsValidForFusion(const Node& node, const Node& next_node) const {
|
||||
return (node.GetAttributes().find(KEYS_ATTR_NAME(T1)) != node.GetAttributes().end() &&
|
||||
node.GetAttributes().find(VALUES_ATTR_NAME(T2)) != node.GetAttributes().end() &&
|
||||
next_node.GetAttributes().find(KEYS_ATTR_NAME(T2)) != next_node.GetAttributes().end() &&
|
||||
next_node.GetAttributes().find(VALUES_ATTR_NAME(T3)) != next_node.GetAttributes().end());
|
||||
}
|
||||
|
||||
/**
|
||||
Transform that fuses two consecutive LabelEncoder nodes
|
||||
into one LabelEncoder node.
|
||||
*/
|
||||
bool LabelEncoderFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(
|
||||
node, "LabelEncoder", {2, 4}, "ai.onnx.ml") ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LabelEncoder", {4}, "ai.onnx.ml") ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Is one of the supported operations
|
||||
return IsValidForFusion<std::string, std::string, std::string>(node, next_node) ||
|
||||
IsValidForFusion<std::string, std::string, int64_t>(node, next_node) ||
|
||||
IsValidForFusion<std::string, int64_t, std::string>(node, next_node) ||
|
||||
IsValidForFusion<std::string, int64_t, int64_t>(node, next_node) ||
|
||||
IsValidForFusion<int64_t, std::string, std::string>(node, next_node) ||
|
||||
IsValidForFusion<int64_t, std::string, int64_t>(node, next_node) ||
|
||||
IsValidForFusion<int64_t, int64_t, std::string>(node, next_node) ||
|
||||
IsValidForFusion<int64_t, int64_t, int64_t>(node, next_node);
|
||||
}
|
||||
|
||||
/**
|
||||
Since we need to be polymorphic on the datatype
|
||||
we will dispatch to this method from the main Apply
|
||||
*/
|
||||
template <typename T1, typename T2, typename T3>
|
||||
Status LabelEncoderFusion::ApplyHelper(
|
||||
Graph& graph,
|
||||
Node& node,
|
||||
Node& next_node,
|
||||
RewriteRuleEffect& rule_effect) const {
|
||||
ProtoHelperNodeContext node_helper_ctx(node);
|
||||
OpNodeProtoHelper<ProtoHelperNodeContext> node_helper(&node_helper_ctx);
|
||||
|
||||
ProtoHelperNodeContext next_node_helper_ctx(next_node);
|
||||
OpNodeProtoHelper<ProtoHelperNodeContext> next_node_helper(&next_node_helper_ctx);
|
||||
|
||||
const std::vector<T1> node_keys =
|
||||
node_helper.GetAttrsOrDefault<T1>(KEYS_ATTR_NAME(T1));
|
||||
const std::vector<T2> node_values =
|
||||
node_helper.GetAttrsOrDefault<T2>(VALUES_ATTR_NAME(T2));
|
||||
const T2 node_default =
|
||||
node_helper.GetAttr<T2>(DEFAULT_VALUE_ATTR_NAME(T2));
|
||||
|
||||
const std::vector<T2> next_node_keys =
|
||||
next_node_helper.GetAttrsOrDefault<T2>(KEYS_ATTR_NAME(T2));
|
||||
const std::vector<T3> next_node_values =
|
||||
next_node_helper.GetAttrsOrDefault<T3>(VALUES_ATTR_NAME(T3));
|
||||
const T3 next_node_default =
|
||||
next_node_helper.GetAttr<T3>(DEFAULT_VALUE_ATTR_NAME(T3));
|
||||
|
||||
const auto getFromMapDefault = [](const auto& mp, const auto key, const auto def) {
|
||||
return (mp.find(key) == mp.end()) ? def : mp.at(key);
|
||||
};
|
||||
|
||||
// Perform value propagation through the second label encoder
|
||||
std::unordered_map<T2, T3> mapping = {};
|
||||
for (size_t i = 0; i < next_node_keys.size(); i++) {
|
||||
mapping[next_node_keys[i]] = next_node_values[i];
|
||||
}
|
||||
|
||||
std::vector<T3> new_node_values = {};
|
||||
const auto new_node_default = getFromMapDefault(mapping, node_default, next_node_default);
|
||||
|
||||
for (const T2& node_value : node_values) {
|
||||
new_node_values.push_back(getFromMapDefault(mapping, node_value, next_node_default));
|
||||
}
|
||||
|
||||
// Remove old attributes:
|
||||
// The keys attribute is correct, we just reroute
|
||||
// the values
|
||||
node.ClearAttribute(VALUES_ATTR_NAME(T2));
|
||||
node.ClearAttribute(DEFAULT_VALUE_ATTR_NAME(T2));
|
||||
|
||||
node.AddAttribute(VALUES_ATTR_NAME(T3), new_node_values);
|
||||
node.AddAttribute(DEFAULT_VALUE_ATTR_NAME(T3), new_node_default);
|
||||
|
||||
graph_utils::FinalizeNodeFusion(graph, node, next_node);
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define FUSE_IF_VALID(T1, T2, T3) \
|
||||
if (IsValidForFusion<T1, T2, T3>(node, next_node)) { \
|
||||
return ApplyHelper<T1, T2, T3>( \
|
||||
graph, node, next_node, rule_effect); \
|
||||
}
|
||||
|
||||
Status LabelEncoderFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
|
||||
auto& next_node = *graph.GetNode(node.OutputNodesBegin()->Index());
|
||||
|
||||
FUSE_IF_VALID(std::string, std::string, std::string);
|
||||
FUSE_IF_VALID(std::string, std::string, int64_t);
|
||||
FUSE_IF_VALID(std::string, int64_t, std::string);
|
||||
FUSE_IF_VALID(std::string, int64_t, int64_t);
|
||||
FUSE_IF_VALID(int64_t, std::string, std::string);
|
||||
FUSE_IF_VALID(int64_t, std::string, int64_t);
|
||||
FUSE_IF_VALID(int64_t, int64_t, std::string);
|
||||
FUSE_IF_VALID(int64_t, int64_t, int64_t);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
35
onnxruntime/core/optimizer/label_encoder_fusion.h
Normal file
35
onnxruntime/core/optimizer/label_encoder_fusion.h
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
/**
|
||||
@Class LabelEncoderFusion
|
||||
|
||||
Rewrite rule that fuses two LabelEncoder -> LabelEncoder nodes to a single
|
||||
LabelEncoder node.
|
||||
|
||||
*/
|
||||
class LabelEncoderFusion : public RewriteRule {
|
||||
public:
|
||||
LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"LabelEncoder"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const;
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
bool IsValidForFusion(const Node& node, const Node& next) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -69,6 +69,7 @@
|
|||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/label_encoder_fusion.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
|
|
@ -1901,6 +1902,68 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
|
|||
ASSERT_TRUE(op_to_count["Mul"] == 2);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, LabelEncoderFusion) {
|
||||
using common::INVALID_GRAPH;
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/label_encoder.onnx";
|
||||
|
||||
NameMLValMap feeds;
|
||||
|
||||
constexpr size_t ALPH = 26;
|
||||
OrtValue mlvalue_a;
|
||||
std::vector<int64_t> dims_a = {ALPH};
|
||||
std::vector<std::string> values_a = {};
|
||||
for (char letter = 'a'; letter <= 'z'; letter++) {
|
||||
values_a.emplace_back(1, letter);
|
||||
}
|
||||
CreateMLValue<std::string>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_a,
|
||||
values_a, &mlvalue_a);
|
||||
feeds.insert(std::make_pair("A", mlvalue_a));
|
||||
|
||||
bool is_implemented = true;
|
||||
|
||||
auto run_model_test = [&](TransformerLevel level, std::vector<OrtValue>& fetches, const int requiredLabelEncoderCount) {
|
||||
SessionOptions session_options;
|
||||
session_options.graph_optimization_level = level;
|
||||
session_options.session_logid = "OptimizerTests";
|
||||
InferenceSessionWrapper session{session_options, GetEnvironment()};
|
||||
|
||||
// If we did not initialize the session correctly, the operator is missing.
|
||||
if (!session.Load(model_uri).IsOK() || !session.Initialize().IsOK()) {
|
||||
is_implemented = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Count if the number of LabelEncoders is as expected
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
ASSERT_TRUE(op_to_count["ai.onnx.ml.LabelEncoder"] == requiredLabelEncoderCount);
|
||||
|
||||
std::vector<std::string> output_names = {};
|
||||
for (const auto& output : session.GetGraph().GetOutputs()) {
|
||||
output_names.push_back(output->Name());
|
||||
}
|
||||
|
||||
RunOptions run_options;
|
||||
ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
|
||||
};
|
||||
|
||||
// run model with and w/o optimizations and compare the results
|
||||
std::vector<OrtValue> unoptimized_fetches;
|
||||
run_model_test(TransformerLevel::Default, unoptimized_fetches, 11);
|
||||
|
||||
std::vector<OrtValue> optimized_fetches;
|
||||
run_model_test(TransformerLevel::MaxLevel, optimized_fetches, 7);
|
||||
|
||||
// If there was a problem loading the model, do not compare the 2 results
|
||||
if (!is_implemented) {
|
||||
GTEST_SKIP();
|
||||
return;
|
||||
}
|
||||
|
||||
// Compare results
|
||||
auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], 0.0, 0.0, false);
|
||||
EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, NotWhereFusion) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/not_where.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/label_encoder.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/label_encoder.onnx
vendored
Normal file
Binary file not shown.
163
onnxruntime/test/testdata/transform/fusion/label_encoder.py
vendored
Normal file
163
onnxruntime/test/testdata/transform/fusion/label_encoder.py
vendored
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
from onnx import OperatorSetIdProto, TensorProto, helper, save
|
||||
|
||||
opsets = []
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 19
|
||||
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
opsets.append(onnxdomain)
|
||||
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = "com.microsoft"
|
||||
opsets.append(msdomain)
|
||||
|
||||
ai_ml_domain = OperatorSetIdProto()
|
||||
ai_ml_domain.version = 4
|
||||
ai_ml_domain.domain = "ai.onnx.ml"
|
||||
opsets.append(ai_ml_domain)
|
||||
|
||||
kwargs = {}
|
||||
kwargs["opset_imports"] = opsets
|
||||
|
||||
|
||||
def generate_model(model_name):
|
||||
# Create models with consecutive label encoders
|
||||
nodes = [ # subgraph
|
||||
# string -> int -> string
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["A"],
|
||||
["le_1_int_1"],
|
||||
"le_1_int_1",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["a", "b", "c"],
|
||||
values_int64s=[0, 1, 2],
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_1_int_1"],
|
||||
["le_1_string_2"],
|
||||
"le_1_string_2",
|
||||
domain="ai.onnx.ml",
|
||||
keys_int64s=[2, 1, 0],
|
||||
values_strings=["a", "b", "c"],
|
||||
default_string="default",
|
||||
),
|
||||
# string -> string -> string
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["A"],
|
||||
["le_2_string_1"],
|
||||
"le_2_string_1",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["a", "b", "c"],
|
||||
values_strings=["C", "B", "A"],
|
||||
default_string="D",
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_2_string_1"],
|
||||
["le_2_string_2"],
|
||||
"le_2_string_2",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["A", "B", "C", "D"],
|
||||
values_strings=["a", "b", "c", "d"],
|
||||
default_string="default",
|
||||
),
|
||||
# string -> string -> int -> string
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["A"],
|
||||
["le_3_string_1"],
|
||||
"le_3_string_1",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["a", "b", "c"],
|
||||
values_strings=["C", "B", "A"],
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_3_string_1"],
|
||||
["le_3_int_2"],
|
||||
"le_3_int_2",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["A", "B", "C"],
|
||||
values_int64s=[1, 2, 3],
|
||||
default_int64=-1,
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_3_int_2"],
|
||||
["le_3_string_3"],
|
||||
"le_3_string_3",
|
||||
domain="ai.onnx.ml",
|
||||
keys_int64s=[1, 2, 3],
|
||||
values_strings=["a", "b", "c"],
|
||||
default_string="d",
|
||||
),
|
||||
# middle encoder is graph output
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["A"],
|
||||
["le_4_int_1"],
|
||||
"le_4_int_1",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["a", "b", "c"],
|
||||
values_int64s=[0, 1, 2],
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_4_int_1"],
|
||||
["le_4_string_2"],
|
||||
"le_4_string_2",
|
||||
domain="ai.onnx.ml",
|
||||
keys_int64s=[0, 1, 2],
|
||||
values_strings=["a", "b", "c"],
|
||||
),
|
||||
helper.make_node("Identity", ["le_4_int_1"], ["Y"], "output"),
|
||||
# middle encoder is consumed twice
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["A"],
|
||||
["le_5_int_1"],
|
||||
"le_5_int_1",
|
||||
domain="ai.onnx.ml",
|
||||
keys_strings=["a", "b", "c"],
|
||||
values_int64s=[0, 1, 2],
|
||||
),
|
||||
helper.make_node(
|
||||
"LabelEncoder",
|
||||
["le_5_int_1"],
|
||||
["le_5_string_2"],
|
||||
"le_5_string_2",
|
||||
domain="ai.onnx.ml",
|
||||
keys_int64s=[0, 1, 2],
|
||||
values_strings=["a", "b", "c"],
|
||||
),
|
||||
helper.make_node("Mul", ["le_5_int_1", "le_5_int_1"], ["mul_5"], "mul_5"),
|
||||
]
|
||||
|
||||
inputs = [ # inputs
|
||||
helper.make_tensor_value_info("A", TensorProto.STRING, ["N"]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"LabelEncoder", # name
|
||||
inputs,
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info("le_1_string_2", TensorProto.STRING, ["N"]),
|
||||
helper.make_tensor_value_info("le_2_string_2", TensorProto.STRING, ["N"]),
|
||||
helper.make_tensor_value_info("le_3_string_3", TensorProto.STRING, ["N"]),
|
||||
helper.make_tensor_value_info("le_4_string_2", TensorProto.STRING, ["N"]),
|
||||
helper.make_tensor_value_info("Y", TensorProto.INT64, ["N"]),
|
||||
helper.make_tensor_value_info("mul_5", TensorProto.INT64, ["N"]),
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
model = helper.make_model(graph, **kwargs)
|
||||
save(model, model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_model("label_encoder.onnx")
|
||||
Loading…
Reference in a new issue