Eliminate Useless Cast during Transformer. (#3606)

* Remove Useless Cast during Transformer.

* Resolve comments.

* Check if graph can remove the node.

Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Vincent Wang 2020-04-22 16:36:46 +08:00 committed by GitHub
parent 5492d02c4e
commit d3a2ac5c5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 136 additions and 0 deletions

View file

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/optimizer/rewrite_rule.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
namespace onnxruntime {
Status CastElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();
}
bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}
const auto* input_type = node.InputDefs()[0]->TypeAsProto();
if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) {
return false;
}
return optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast<int64_t>(input_type->tensor_type().elem_type()));
}
} // namespace onnxruntime

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class CastElimination
Rewrite rule that eliminates Cast nodes if 'to' attribute has same data type as input tensor data type.
It is attempted to be triggered only on nodes with op type "Cast".
*/
class CastElimination : public RewriteRule {
public:
CastElimination() noexcept : RewriteRule("CastElimination") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Cast"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -27,6 +27,7 @@
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/cast_elimination.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
@ -46,6 +47,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
rules.push_back(onnxruntime::make_unique<EliminateSlice>());
rules.push_back(onnxruntime::make_unique<UnsqueezeElimination>());
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
rules.push_back(onnxruntime::make_unique<CastElimination>());
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());

View file

@ -40,6 +40,7 @@
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/attention_fusion.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/utils.h"
#include "core/platform/env.h"
#include "core/util/math.h"
@ -1038,6 +1039,24 @@ TEST(GraphTransformationTests, ReshapeFusionOneConstTest) {
}
}
TEST(GraphTransformationTests, CastElimination) {
auto model_uri = MODEL_FOLDER "cast_elimination.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Cast"] == 7);
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<CastElimination>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Cast"] == 4);
}
#ifndef DISABLE_CONTRIB_OPS
static void ValidateAttention(Graph& graph) {

Binary file not shown.

View file

@ -0,0 +1,48 @@
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4])
X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 1])
X3 = helper.make_tensor_value_info('x3', TensorProto.INT64, [4, 1])
Y = helper.make_tensor_value_info('output', TensorProto.INT64, [4, 4])
less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1')
less2 = helper.make_node('Less', ['x1', 'x3'], ['less2'], name='less2')
cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=9)
and_node = helper.make_node('And', ['cast1', 'less2'], ['and_node'], name='and_node')
cast2 = helper.make_node('Cast', ['and_node'], ['cast2'], name='cast2', to=9)
cast3 = helper.make_node('Cast', ['cast2'], ['cast3'], name='cast3', to=1)
cast4 = helper.make_node('Cast', ['x1'], ['cast4'], name='cast4', to=7)
cast5 = helper.make_node('Cast', ['cast4'], ['cast5'], name='cast5', to=1)
matmul = helper.make_node('MatMul', ['cast3', 'cast5'], ['matmul'], name='matmul')
cast6 = helper.make_node('Cast', ['matmul'], ['cast6'], name='cast6', to=7)
cast7 = helper.make_node('Cast', ['cast6'], ['output'], name='cast7', to=7)
# Create the graph (GraphProto)
graph_def = helper.make_graph(
[less1, less2, cast1, and_node, cast2, cast3, cast4, cast5, matmul, cast6, cast7],
'cast_elimination_model',
[X1, X2, X3],
[Y]
)
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'
opsets.append(msdomain)
kwargs={}
kwargs['opset_imports'] = opsets
# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
onnx.save(model_def, 'cast_elimination.onnx')

View file

@ -12,6 +12,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -53,6 +54,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
rule_transformer->Register(make_unique<InsertMaxPoolOutput>());
rule_transformer->Register(make_unique<AdjustBatchNormOutputs>());
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));