From d3a2ac5c5cfdb78da5c9c8f770e940307f958af4 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 22 Apr 2020 16:36:46 +0800 Subject: [PATCH] Eliminate Useless Cast during Transformer. (#3606) * Remove Useless Cast during Transformer. * Resolve comments. * Check if graph can remove the node. Co-authored-by: Vincent Wang --- .../core/optimizer/cast_elimination.cc | 34 +++++++++++++ onnxruntime/core/optimizer/cast_elimination.h | 31 +++++++++++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 19 +++++++ .../testdata/transform/cast_elimination.onnx | Bin 0 -> 583 bytes .../testdata/transform/cast_elimination.py | 48 ++++++++++++++++++ .../core/optimizer/graph_transformer_utils.cc | 2 + 7 files changed, 136 insertions(+) create mode 100644 onnxruntime/core/optimizer/cast_elimination.cc create mode 100644 onnxruntime/core/optimizer/cast_elimination.h create mode 100644 onnxruntime/test/testdata/transform/cast_elimination.onnx create mode 100644 onnxruntime/test/testdata/transform/cast_elimination.py diff --git a/onnxruntime/core/optimizer/cast_elimination.cc b/onnxruntime/core/optimizer/cast_elimination.cc new file mode 100644 index 0000000000..bbcd93472e --- /dev/null +++ b/onnxruntime/core/optimizer/cast_elimination.cc @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/cast_elimination.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status CastElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + + return Status::OK(); +} + +bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + + const auto* input_type = node.InputDefs()[0]->TypeAsProto(); + if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) { + return false; + } + + return optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast(input_type->tensor_type().elem_type())); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/cast_elimination.h b/onnxruntime/core/optimizer/cast_elimination.h new file mode 100644 index 0000000000..f1b880d678 --- /dev/null +++ b/onnxruntime/core/optimizer/cast_elimination.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class CastElimination + +Rewrite rule that eliminates Cast nodes if 'to' attribute has same data type as input tensor data type. + +It is attempted to be triggered only on nodes with op type "Cast". +*/ +class CastElimination : public RewriteRule { + public: + CastElimination() noexcept : RewriteRule("CastElimination") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Cast"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5d250478dd..bb2255bfcd 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -27,6 +27,7 @@ #include "core/optimizer/embed_layer_norm_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/attention_fusion.h" +#include "core/optimizer/cast_elimination.h" #include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -46,6 +47,7 @@ std::vector> GenerateRewriteRules(TransformerLevel rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); + rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 54d8cdfa03..7447893946 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -40,6 +40,7 @@ #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/fast_gelu_fusion.h" +#include "core/optimizer/cast_elimination.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" #include "core/util/math.h" @@ -1038,6 +1039,24 @@ TEST(GraphTransformationTests, ReshapeFusionOneConstTest) { } } +TEST(GraphTransformationTests, CastElimination) { + auto model_uri = MODEL_FOLDER "cast_elimination.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 7); + + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 4); +} + #ifndef DISABLE_CONTRIB_OPS static void ValidateAttention(Graph& graph) { diff --git a/onnxruntime/test/testdata/transform/cast_elimination.onnx b/onnxruntime/test/testdata/transform/cast_elimination.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0ec37d79769ccb7103d36cb9961ea432d13fc7c5 GIT binary patch literal 583 zcmZXQ%}T^D6oqZ=bUK$RV-OjNs|t#O!t_sY=eTuVz)h)bg+hJ?o5H+{8z0pqHxybI z;Wp=dJ%^jXC!yf{_}CsRcA)Lcj|bP_9=GZwWI@}$-$poM?_J$SH$^u=H?i(5AWRl@ z)l2oO8L{`KZVxMPrP!=~9T)BZ9mEq7RJ$^c={MMAyd^c hYC2axH(@JAXojiQE7H literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/cast_elimination.py b/onnxruntime/test/testdata/transform/cast_elimination.py new file mode 100644 index 0000000000..3d546c85e9 --- /dev/null +++ b/onnxruntime/test/testdata/transform/cast_elimination.py @@ -0,0 +1,48 @@ +import onnx +from onnx import helper +from onnx import TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4]) +X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 1]) +X3 = helper.make_tensor_value_info('x3', TensorProto.INT64, [4, 1]) +Y = helper.make_tensor_value_info('output', TensorProto.INT64, [4, 4]) + +less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') +less2 = helper.make_node('Less', ['x1', 'x3'], ['less2'], name='less2') +cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=9) +and_node = helper.make_node('And', ['cast1', 'less2'], ['and_node'], name='and_node') +cast2 = helper.make_node('Cast', ['and_node'], ['cast2'], name='cast2', to=9) +cast3 = helper.make_node('Cast', ['cast2'], ['cast3'], name='cast3', to=1) +cast4 = helper.make_node('Cast', ['x1'], ['cast4'], name='cast4', to=7) +cast5 = helper.make_node('Cast', ['cast4'], ['cast5'], name='cast5', to=1) +matmul = helper.make_node('MatMul', ['cast3', 'cast5'], ['matmul'], name='matmul') +cast6 = helper.make_node('Cast', ['matmul'], ['cast6'], name='cast6', to=7) +cast7 = helper.make_node('Cast', ['cast6'], ['output'], name='cast7', to=7) + +# Create the graph (GraphProto) +graph_def = helper.make_graph( + [less1, less2, cast1, and_node, cast2, cast3, cast4, cast5, matmul, cast6, cast7], + 'cast_elimination_model', + [X1, X2, X3], + [Y] +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +onnx.save(model_def, 'cast_elimination.onnx') diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index df3088efdb..27bfdb29cf 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -12,6 +12,7 @@ #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/unsqueeze_elimination.h" +#include "core/optimizer/cast_elimination.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -53,6 +54,7 @@ std::vector> GeneratePreTrainingTransformers(T rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); + rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); transformers.emplace_back(onnxruntime::make_unique(compatible_eps));