mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Eliminate Useless Cast during Transformer. (#3606)
* Remove Useless Cast during Transformer.
* Resolve comments.
* Check if graph can remove the node.

Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
5492d02c4e
commit
d3a2ac5c5c
7 changed files with 136 additions and 0 deletions
34
onnxruntime/core/optimizer/cast_elimination.cc
Normal file
34
onnxruntime/core/optimizer/cast_elimination.cc
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Removes `node` from `graph`.
//
// The removal is delegated to graph_utils::RemoveNode. `rule_effect` is set to
// RewriteRuleEffect::kRemovedCurrentNode only when that call reports success;
// otherwise it is left untouched. The logger argument is unused and
// Status::OK() is returned in both cases.
Status CastElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  if (!graph_utils::RemoveNode(graph, node)) {
    // Nothing was removed, so the caller's rule_effect stays as it was.
    return Status::OK();
  }

  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  return Status::OK();
}
|
||||
|
||||
// Returns true when `node` can be removed from `graph` and its "to" attribute
// equals the element type already recorded for its first input.
//
// Returns false when graph_utils::CanRemoveNode rejects the node, or when the
// first input has no type information / no element type to compare against.
bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
  if (!graph_utils::CanRemoveNode(graph, node, logger)) return false;

  // Type information of the tensor being cast (input 0 of the node).
  const auto* src_type = node.InputDefs()[0]->TypeAsProto();
  const bool elem_type_known = src_type != nullptr && src_type->tensor_type().has_elem_type();
  if (!elem_type_known) return false;

  // The attribute helper is given an int64_t, so widen the element type first.
  const auto src_elem_type = static_cast<int64_t>(src_type->tensor_type().elem_type());
  return optimizer_utils::IsAttributeWithExpectedValue(node, "to", src_elem_type);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
31
onnxruntime/core/optimizer/cast_elimination.h
Normal file
31
onnxruntime/core/optimizer/cast_elimination.h
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class CastElimination
|
||||
|
||||
Rewrite rule that eliminates Cast nodes if 'to' attribute has same data type as input tensor data type.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Cast".
|
||||
*/
|
||||
class CastElimination : public RewriteRule {
|
||||
public:
|
||||
CastElimination() noexcept : RewriteRule("CastElimination") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Cast"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -27,6 +27,7 @@
|
|||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -46,6 +47,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
rules.push_back(onnxruntime::make_unique<EliminateSlice>());
|
||||
rules.push_back(onnxruntime::make_unique<UnsqueezeElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
|
||||
rules.push_back(onnxruntime::make_unique<CastElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
|
||||
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
|
||||
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/attention_fusion.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -1038,6 +1039,24 @@ TEST(GraphTransformationTests, ReshapeFusionOneConstTest) {
|
|||
}
|
||||
}
|
||||
|
||||
// Loads cast_elimination.onnx, applies only the CastElimination rewrite rule at
// Level1, and checks that the number of Cast nodes drops from 7 to 4.
TEST(GraphTransformationTests, CastElimination) {
  auto model_uri = MODEL_FOLDER "cast_elimination.onnx";
  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
  Graph& graph = model->MainGraph();

  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) reports the actual count when
  // the expectation fails.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Cast"], 7);

  // Register the rule under test as the only Level1 transformer.
  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(onnxruntime::make_unique<CastElimination>());
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());

  // Recount after the transformation.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Cast"], 4);
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
static void ValidateAttention(Graph& graph) {
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/cast_elimination.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/cast_elimination.onnx
vendored
Normal file
Binary file not shown.
48
onnxruntime/test/testdata/transform/cast_elimination.py
vendored
Normal file
48
onnxruntime/test/testdata/transform/cast_elimination.py
vendored
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
# Builds the model loaded by the CastElimination graph-transformation test and
# writes it to 'cast_elimination.onnx' in the current working directory.

# Element types used below; these enum members carry the same integer values
# as the literals 9, 1 and 7 respectively.
BOOL = TensorProto.BOOL
FLOAT = TensorProto.FLOAT
INT64 = TensorProto.INT64


def _cast(name, source, to, output=None):
    """Return a Cast node called `name` that reads `source`.

    The node's single output is named `output`, or `name` when not given.
    """
    return helper.make_node('Cast', [source], [output or name], name=name, to=to)


def _opset(domain, version):
    """Return an OperatorSetIdProto for the given domain and version."""
    entry = OperatorSetIdProto()
    entry.version = version
    entry.domain = domain
    return entry


graph_inputs = [
    helper.make_tensor_value_info('x1', INT64, [4, 4]),
    helper.make_tensor_value_info('x2', INT64, [4, 1]),
    helper.make_tensor_value_info('x3', INT64, [4, 1]),
]
graph_outputs = [helper.make_tensor_value_info('output', INT64, [4, 4])]

# Seven Cast nodes in total. Some cast to the type their input already has:
# for example cast4 casts the INT64 input x1 to INT64, and cast7 casts the
# INT64 result of cast6 to INT64 again.
nodes = [
    helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1'),
    helper.make_node('Less', ['x1', 'x3'], ['less2'], name='less2'),
    _cast('cast1', 'less1', BOOL),
    helper.make_node('And', ['cast1', 'less2'], ['and_node'], name='and_node'),
    _cast('cast2', 'and_node', BOOL),
    _cast('cast3', 'cast2', FLOAT),
    _cast('cast4', 'x1', INT64),
    _cast('cast5', 'cast4', FLOAT),
    helper.make_node('MatMul', ['cast3', 'cast5'], ['matmul'], name='matmul'),
    _cast('cast6', 'matmul', INT64),
    _cast('cast7', 'cast6', INT64, output='output'),
]

# Create the graph (GraphProto)
graph_def = helper.make_graph(nodes, 'cast_elimination_model', graph_inputs, graph_outputs)

# The empty string ("") or absence of the domain field implies the operator
# set that is defined as part of the ONNX specification.
opset_imports = [_opset("", 12), _opset('com.microsoft', 1)]

# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opset_imports)
onnx.save(model_def, 'cast_elimination.onnx')
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/optimizer/conv_add_fusion.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/cast_elimination.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
|
|
@ -53,6 +54,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
|
|||
rule_transformer->Register(make_unique<InsertMaxPoolOutput>());
|
||||
rule_transformer->Register(make_unique<AdjustBatchNormOutputs>());
|
||||
rule_transformer->Register(make_unique<UnsqueezeElimination>());
|
||||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
|
||||
|
|
|
|||
Loading…
Reference in a new issue