mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
### Description
Created a new `LabelEncoderFusion` pass. This is useful for models produced by automatic conversion tools used in data science: such models sometimes contain consecutive `LabelEncoder` nodes. To merge two `LabelEncoder` nodes, the optimizer propagates the outputs of the first encoder through the second one.

### Motivation and Context
This extends the capabilities of the `onnxruntime::optimizer` by fusing consecutive `LabelEncoder` nodes.

### Fusion examples
```
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: _Unused
node2: (A,1) (B,2) (C,3) -> Default: -1
fused: (a,3) (b,2) (c,1) -> Default: -1

Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: D
node2: (A,a) (B,b) (C,c) (D,d) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: d

Applying fusion
node1: (a,0) (b,1) (c,2) -> Default: -1
node2: (2,a) (1,b) (0,c) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: default

Applying fusion
node1: (a,3) (b,2) (c,1) -> Default: -1
node2: (1,a) (2,b) (3,c) -> Default: d
fused: (a,c) (b,b) (c,a) -> Default: d
```
---------
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
35 lines
1.1 KiB
C++
35 lines
1.1 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "core/optimizer/rewrite_rule.h"
|
|
|
|
namespace onnxruntime {
|
|
/**
|
|
@Class LabelEncoderFusion
|
|
|
|
Rewrite rule that fuses two LabelEncoder -> LabelEncoder nodes to a single
|
|
LabelEncoder node.
|
|
|
|
*/
|
|
class LabelEncoderFusion : public RewriteRule {
|
|
public:
|
|
LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {}
|
|
|
|
std::vector<std::string> TargetOpTypes() const noexcept override {
|
|
return {"LabelEncoder"};
|
|
}
|
|
|
|
private:
|
|
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
|
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
|
|
|
template <typename T1, typename T2, typename T3>
|
|
Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const;
|
|
|
|
template <typename T1, typename T2, typename T3>
|
|
bool IsValidForFusion(const Node& node, const Node& next) const;
|
|
};
|
|
|
|
} // namespace onnxruntime
|