onnxruntime/onnxruntime/core/optimizer/label_encoder_fusion.h
Atanas Dimitrov 9d06e1bfa4
Label encoder fusion (#19761)
### Description
Created a new `LabelEncoderFusion` pass. This is useful for models that
result from automatic conversion tools used in data science: such models
sometimes contain consecutive `LabelEncoder` nodes.
To merge two `LabelEncoder` nodes, the optimizer propagates the output values
of the first encoder, including its default value, through the second one.


### Motivation and Context
This enhances the capabilities of the `onnxruntime::optimizer` by fusing
consecutive `LabelEncoder` nodes.


### Fusion examples
```
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: _Unused
node2: (A,1) (B,2) (C,3) -> Default: -1
fused: (a,3) (b,2) (c,1) -> Default: -1
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: D
node2: (A,a) (B,b) (C,c) (D,d) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: d
Applying fusion
node1: (a,0) (b,1) (c,2) -> Default: -1
node2: (2,a) (1,b) (0,c) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: default
Applying fusion
node1: (a,3) (b,2) (c,1) -> Default: -1
node2: (1,a) (2,b) (3,c) -> Default: d
fused: (a,c) (b,b) (c,a) -> Default: d
```
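The default values in these examples follow from composing the two lookups. The notation here is introduced for illustration and does not come from the source. Write node1 as $f$, with key set $K_1$, table $t_1$ and default $d_1$. Write node2 as $g$, with $K_2$, $t_2$ and $d_2$:

$$
f(x) = \begin{cases} t_1(x) & x \in K_1 \\ d_1 & x \notin K_1 \end{cases}
\qquad
g(y) = \begin{cases} t_2(y) & y \in K_2 \\ d_2 & y \notin K_2 \end{cases}
$$

$$
(g \circ f)(x) = \begin{cases} g(t_1(x)) & x \in K_1 \\ g(d_1) & x \notin K_1 \end{cases}
$$

The fused node therefore keeps node1's keys, maps each key $x$ to $g(t_1(x))$, and uses $g(d_1)$ as its default.

- In the first example, `_Unused` is not a key of node2, so the fused default is $d_2 = -1$.
- In the second example, `D` is a key of node2, so the fused default is $t_2(D) =$ `d`.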

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
2024-04-01 09:41:10 -07:00


// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class LabelEncoderFusion
Rewrite rule that fuses two consecutive LabelEncoder nodes
(LabelEncoder -> LabelEncoder) into a single LabelEncoder node.
*/
class LabelEncoderFusion : public RewriteRule {
 public:
  LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {}

  std::vector<std::string> TargetOpTypes() const noexcept override {
    return {"LabelEncoder"};
  }

 private:
  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;

  template <typename T1, typename T2, typename T3>
  Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const;

  template <typename T1, typename T2, typename T3>
  bool IsValidForFusion(const Node& node, const Node& next) const;
};
} // namespace onnxruntime