diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 9b9278f62b..83f18161c2 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -539,6 +539,31 @@ TensorProto::DataType GetTensorProtoType(const Tensor& tensor) { return dtype; } +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name, + const onnx::TypeProto& tensor_proto_type) { + // Given we are using the raw_data field in the protobuf, this will work only for little-endian format. + ORT_ENFORCE(IsLittleEndianOrder()); + + // Set name, dimensions, type, and data of the TensorProto. + ONNX_NAMESPACE::TensorProto tensor_proto; + + tensor_proto.set_name(tensor_proto_name); + + for (auto& dim : tensor.Shape().GetDims()) { + tensor_proto.add_dims(dim); + } + + // TODO Once utils::GetTensorProtoType supports all data types, you can get the tensor proto type from the tensor, + // as follows (which will allow us to get rid of the tensor_proto_type argument). + //tensor_proto.set_data_type(utils::GetTensorProtoType(tensor)); + + tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type()); + + tensor_proto.set_raw_data(tensor.DataRaw(), tensor.Size()); + + return tensor_proto; +} + template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 2e7a3b1214..cd11184b34 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -37,6 +37,18 @@ common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_prot // This function doesn't support string tensors ONNX_NAMESPACE::TensorProto::DataType GetTensorProtoType(const Tensor& tensor); +/** Creates a TensorProto from a Tensor. + @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. + @param[in] tensor_proto_name the name of the TensorProto. + @param[in] tensor_proto_type the type of the TensorProto. + @return the TensorProto. + + Note: Method currently requires that data is in little-endian format. + TODO Once the GetTensorProtoType supports all data types, we can remove the tensor_proto_type parameter and + instead get the type from the tensor. */ +ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name, + const onnx::TypeProto& tensor_proto_type); + ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 8e5f90d63c..db5cfc2d76 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -5,7 +5,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" #include "core/framework/op_kernel.h" -#include "core/framework/ml_value.h" +#include "core/framework/tensorprotoutils.h" using namespace onnxruntime::common; @@ -61,9 +61,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) OrtValue& ort_value = fetches[fetch_idx]; // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. - ONNX_NAMESPACE::TensorProto out_tensorproto; const auto* constant_arg_out = node->OutputDefs()[fetch_idx]; - BuildTensorProtoForInitializer(ort_value, *constant_arg_out, out_tensorproto); + ORT_ENFORCE(ort_value.IsTensor()); + const Tensor& out_tensor = ort_value.Get(); + ONNX_NAMESPACE::TensorProto out_tensorproto = + utils::TensorToTensorProto(out_tensor, constant_arg_out->Name(), *constant_arg_out->TypeAsProto()); graph.AddInitializedTensor(out_tensorproto); } @@ -81,23 +83,4 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level) return Status::OK(); } // namespace onnxruntime -void ConstantFolding::BuildTensorProtoForInitializer(const OrtValue& ort_value, const NodeArg& constant_node_arg, - ONNX_NAMESPACE::TensorProto& tensorproto) const { - ORT_ENFORCE(ort_value.IsTensor()); - const Tensor& out_tensor = ort_value.Get(); - - // Set name, dimensions, type, and data of the TensorProto. - tensorproto.set_name(constant_node_arg.Name()); - - for (auto& dim : out_tensor.Shape().GetDims()) { - tensorproto.add_dims(dim); - } - auto tensorproto_type = constant_node_arg.TypeAsProto()->tensor_type().elem_type(); - - tensorproto.set_data_type(tensorproto_type); - auto tensor_shape_size = out_tensor.Shape().Size(); - auto data_size = out_tensor.DataType()->Size() * tensor_shape_size; - tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5c03fc3538..262241c599 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/matmul_add_fusion.h" #include "core/optimizer/dropout_elimination.h" #include "core/optimizer/relu_clip_fusion.h" +#include "core/optimizer/shape_to_initializer.h" namespace onnxruntime { @@ -32,6 +33,7 @@ std::vector> GenerateRewriteRules(TransformerLevel rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: diff --git a/onnxruntime/core/optimizer/shape_to_initializer.cc b/onnxruntime/core/optimizer/shape_to_initializer.cc new file mode 100644 index 0000000000..66a8a77575 --- /dev/null +++ b/onnxruntime/core/optimizer/shape_to_initializer.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/shape_to_initializer.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" +#include "core/graph/op.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/optimizer_execution_frame.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" + +namespace onnxruntime { + +Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { + // Store the statically inferred shape of the input to the Shape operator. + const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape(); + std::vector input_dims; + int num_dimensions = input_shape_proto->dim_size(); + for (int i = 0; i < num_dimensions; i++) { + input_dims.push_back(gsl::narrow_cast(input_shape_proto->dim(i).dim_value())); + } + + // Create the TensorProto that will be used as initializer in place of the Shape operator. + const auto* shape_out_def = node.OutputDefs()[0]; + + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + + shape_initializer_proto.set_name(shape_out_def->Name()); + + TensorShape tensor_shape({gsl::narrow_cast(num_dimensions)}); + for (auto& dim : tensor_shape.GetDims()) { + shape_initializer_proto.add_dims(dim); + } + + auto tensor_proto_data_type = shape_out_def->TypeAsProto()->tensor_type().elem_type(); + + shape_initializer_proto.set_data_type(tensor_proto_data_type); + + // Here we expect little-indian format to set raw data of the TensorProto. + shape_initializer_proto.set_raw_data(input_dims.data(), + input_dims.size() * sizeof(decltype(input_dims)::value_type)); + + // Remove the output edges of the Shape node, then remove the node itself, and replace it with the initializer. + graph_utils::RemoveNodeOutputEdges(graph, node); + + if (graph.RemoveNode(node.Index())) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + graph.AddInitializedTensor(shape_initializer_proto); + } + + return Status::OK(); +} + +bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1}) || + // Making sure we are not left with a graph with no nodes. + graph.IsNodeOutputsInGraphOutputs(node)) { + return false; + } + + // The shape of the input has to be statically known. Moreover, each dimension should have a specific value + // (the rule cannot be applied if one of the dimension is a symbolic variable). + const auto* input_shape = node.InputDefs()[0]->Shape(); + if (!input_shape) { + return false; + } + + for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) { + if (!input_shape->dim(i).has_dim_value()) { + return false; + } + } + + return true; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/shape_to_initializer.h b/onnxruntime/core/optimizer/shape_to_initializer.h new file mode 100644 index 0000000000..751bd22d96 --- /dev/null +++ b/onnxruntime/core/optimizer/shape_to_initializer.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class ShapeToInitializer + +When the input to a Shape operator is statically known (through shape inference), this rule replaces the Shape node +with an initializer to the downstream nodes. + +It is attempted to be triggered only on nodes with op type "Shape". +*/ +class ShapeToInitializer : public RewriteRule { + public: + ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Shape"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e8a33e5fb6..16febfd106 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -28,6 +28,7 @@ #include "gtest/gtest.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/constant_folding.h" +#include "core/optimizer/shape_to_initializer.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -65,7 +66,7 @@ TEST(GraphTransformationTests, IdentityElimination) { ASSERT_TRUE(op_to_count["Identity"] == 0); } -TEST(GraphTransformationTests, DropoutEliminationSingleOutput) { +TEST(GraphTransformationTests, DropoutElimination) { string model_uri = MODEL_FOLDER + "dropout.onnx"; std::shared_ptr model; ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); @@ -107,7 +108,7 @@ TEST(GraphTransformationTests, SliceElimination) { ASSERT_TRUE(op_to_count["Slice"] == 4); } -TEST(GraphTransformationTests, ConstantFolding1) { +TEST(GraphTransformationTests, ConstantFolding) { string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; std::shared_ptr model; ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); @@ -124,6 +125,26 @@ TEST(GraphTransformationTests, ConstantFolding1) { ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); } +TEST(GraphTransformationTests, ShapeToInitializer) { + string model_uri = MODEL_FOLDER + "shape-add.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 3); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + rule_transformer_L1->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + + op_to_count = CountOpsInGraph(graph); + // One of the Shapes is not eliminated because it inlcludes a symbolic dimension. + ASSERT_TRUE(op_to_count["Shape"] == 1); +} + // Check transformations in the case of a subgraph with constant inputs. TEST(GraphTransformationTests, SubgraphWithConstantInputs) { string model_uri = MODEL_FOLDER + "constant-subgraph.onnx"; diff --git a/onnxruntime/test/testdata/transform/shape-add.onnx b/onnxruntime/test/testdata/transform/shape-add.onnx new file mode 100644 index 0000000000..bbe25966a2 --- /dev/null +++ b/onnxruntime/test/testdata/transform/shape-add.onnx @@ -0,0 +1,42 @@ +lotus-transfomrs:â + +AC"Shape + +BD"Shape + +C +DE"Add + +EF"Shape + +FG"Identity shape-addZ +A + + + +Z +B + + +N +b +G + + +j +C + + +j +D + + +j +E + + +j +F + + +B