mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Convert Shape operator to initializer (#1159)
This PR introduces a rewrite rule that replaces a Shape node with an initializer when the shape of the input is statically known through shape inference.
This commit is contained in:
parent
cdb27de090
commit
32c6c71e86
8 changed files with 219 additions and 24 deletions
|
|
@ -539,6 +539,31 @@ TensorProto::DataType GetTensorProtoType(const Tensor& tensor) {
|
|||
return dtype;
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
|
||||
const onnx::TypeProto& tensor_proto_type) {
|
||||
// Given we are using the raw_data field in the protobuf, this will work only for little-endian format.
|
||||
ORT_ENFORCE(IsLittleEndianOrder());
|
||||
|
||||
// Set name, dimensions, type, and data of the TensorProto.
|
||||
ONNX_NAMESPACE::TensorProto tensor_proto;
|
||||
|
||||
tensor_proto.set_name(tensor_proto_name);
|
||||
|
||||
for (auto& dim : tensor.Shape().GetDims()) {
|
||||
tensor_proto.add_dims(dim);
|
||||
}
|
||||
|
||||
// TODO Once utils::GetTensorProtoType supports all data types, you can get the tensor proto type from the tensor,
|
||||
// as follows (which will allow us to get rid of the tensor_proto_type argument).
|
||||
//tensor_proto.set_data_type(utils::GetTensorProtoType(tensor));
|
||||
|
||||
tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type());
|
||||
|
||||
tensor_proto.set_raw_data(tensor.DataRaw(), tensor.Size());
|
||||
|
||||
return tensor_proto;
|
||||
}
|
||||
|
||||
template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto,
|
||||
size_t* out);
|
||||
template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
|
||||
|
|
|
|||
|
|
@ -37,6 +37,18 @@ common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_prot
|
|||
// This function doesn't support string tensors
|
||||
ONNX_NAMESPACE::TensorProto::DataType GetTensorProtoType(const Tensor& tensor);
|
||||
|
||||
/** Creates a TensorProto from a Tensor.
|
||||
@param[in] tensor the Tensor whose data and shape will be used to create the TensorProto.
|
||||
@param[in] tensor_proto_name the name of the TensorProto.
|
||||
@param[in] tensor_proto_type the type of the TensorProto.
|
||||
@return the TensorProto.
|
||||
|
||||
Note: Method currently requires that data is in little-endian format.
|
||||
TODO Once the GetTensorProtoType supports all data types, we can remove the tensor_proto_type parameter and
|
||||
instead get the type from the tensor. */
|
||||
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
|
||||
const onnx::TypeProto& tensor_proto_type);
|
||||
|
||||
ONNXTensorElementDataType CApiElementTypeFromProtoType(int type);
|
||||
ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto);
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/optimizer_execution_frame.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
|
|
@ -61,9 +61,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
OrtValue& ort_value = fetches[fetch_idx];
|
||||
|
||||
// Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph.
|
||||
ONNX_NAMESPACE::TensorProto out_tensorproto;
|
||||
const auto* constant_arg_out = node->OutputDefs()[fetch_idx];
|
||||
BuildTensorProtoForInitializer(ort_value, *constant_arg_out, out_tensorproto);
|
||||
ORT_ENFORCE(ort_value.IsTensor());
|
||||
const Tensor& out_tensor = ort_value.Get<Tensor>();
|
||||
ONNX_NAMESPACE::TensorProto out_tensorproto =
|
||||
utils::TensorToTensorProto(out_tensor, constant_arg_out->Name(), *constant_arg_out->TypeAsProto());
|
||||
|
||||
graph.AddInitializedTensor(out_tensorproto);
|
||||
}
|
||||
|
|
@ -81,23 +83,4 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
return Status::OK();
|
||||
} // namespace onnxruntime
|
||||
|
||||
void ConstantFolding::BuildTensorProtoForInitializer(const OrtValue& ort_value, const NodeArg& constant_node_arg,
|
||||
ONNX_NAMESPACE::TensorProto& tensorproto) const {
|
||||
ORT_ENFORCE(ort_value.IsTensor());
|
||||
const Tensor& out_tensor = ort_value.Get<Tensor>();
|
||||
|
||||
// Set name, dimensions, type, and data of the TensorProto.
|
||||
tensorproto.set_name(constant_node_arg.Name());
|
||||
|
||||
for (auto& dim : out_tensor.Shape().GetDims()) {
|
||||
tensorproto.add_dims(dim);
|
||||
}
|
||||
auto tensorproto_type = constant_node_arg.TypeAsProto()->tensor_type().elem_type();
|
||||
|
||||
tensorproto.set_data_type(tensorproto_type);
|
||||
auto tensor_shape_size = out_tensor.Shape().Size();
|
||||
auto data_size = out_tensor.DataType()->Size() * tensor_shape_size;
|
||||
tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
rules.push_back(std::make_unique<UnsqueezeElimination>());
|
||||
rules.push_back(std::make_unique<EliminateDropout>());
|
||||
rules.push_back(std::make_unique<FuseReluClip>());
|
||||
rules.push_back(std::make_unique<ShapeToInitializer>());
|
||||
break;
|
||||
|
||||
case TransformerLevel::Level2:
|
||||
|
|
|
|||
78
onnxruntime/core/optimizer/shape_to_initializer.cc
Normal file
78
onnxruntime/core/optimizer/shape_to_initializer.cc
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/optimizer_execution_frame.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
// Store the statically inferred shape of the input to the Shape operator.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape();
|
||||
std::vector<int64_t> input_dims;
|
||||
int num_dimensions = input_shape_proto->dim_size();
|
||||
for (int i = 0; i < num_dimensions; i++) {
|
||||
input_dims.push_back(gsl::narrow_cast<int64_t>(input_shape_proto->dim(i).dim_value()));
|
||||
}
|
||||
|
||||
// Create the TensorProto that will be used as initializer in place of the Shape operator.
|
||||
const auto* shape_out_def = node.OutputDefs()[0];
|
||||
|
||||
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
|
||||
|
||||
shape_initializer_proto.set_name(shape_out_def->Name());
|
||||
|
||||
TensorShape tensor_shape({gsl::narrow_cast<int64_t>(num_dimensions)});
|
||||
for (auto& dim : tensor_shape.GetDims()) {
|
||||
shape_initializer_proto.add_dims(dim);
|
||||
}
|
||||
|
||||
auto tensor_proto_data_type = shape_out_def->TypeAsProto()->tensor_type().elem_type();
|
||||
|
||||
shape_initializer_proto.set_data_type(tensor_proto_data_type);
|
||||
|
||||
// Here we expect little-indian format to set raw data of the TensorProto.
|
||||
shape_initializer_proto.set_raw_data(input_dims.data(),
|
||||
input_dims.size() * sizeof(decltype(input_dims)::value_type));
|
||||
|
||||
// Remove the output edges of the Shape node, then remove the node itself, and replace it with the initializer.
|
||||
graph_utils::RemoveNodeOutputEdges(graph, node);
|
||||
|
||||
if (graph.RemoveNode(node.Index())) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
graph.AddInitializedTensor(shape_initializer_proto);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1}) ||
|
||||
// Making sure we are not left with a graph with no nodes.
|
||||
graph.IsNodeOutputsInGraphOutputs(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The shape of the input has to be statically known. Moreover, each dimension should have a specific value
|
||||
// (the rule cannot be applied if one of the dimension is a symbolic variable).
|
||||
const auto* input_shape = node.InputDefs()[0]->Shape();
|
||||
if (!input_shape) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) {
|
||||
if (!input_shape->dim(i).has_dim_value()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
32
onnxruntime/core/optimizer/shape_to_initializer.h
Normal file
32
onnxruntime/core/optimizer/shape_to_initializer.h
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class ShapeToInitializer
|
||||
|
||||
When the input to a Shape operator is statically known (through shape inference), this rule replaces the Shape node
|
||||
with an initializer to the downstream nodes.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Shape".
|
||||
*/
|
||||
class ShapeToInitializer : public RewriteRule {
|
||||
public:
|
||||
ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Shape"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -28,6 +28,7 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -65,7 +66,7 @@ TEST(GraphTransformationTests, IdentityElimination) {
|
|||
ASSERT_TRUE(op_to_count["Identity"] == 0);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, DropoutEliminationSingleOutput) {
|
||||
TEST(GraphTransformationTests, DropoutElimination) {
|
||||
string model_uri = MODEL_FOLDER + "dropout.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
|
|
@ -107,7 +108,7 @@ TEST(GraphTransformationTests, SliceElimination) {
|
|||
ASSERT_TRUE(op_to_count["Slice"] == 4);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, ConstantFolding1) {
|
||||
TEST(GraphTransformationTests, ConstantFolding) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
|
|
@ -124,6 +125,26 @@ TEST(GraphTransformationTests, ConstantFolding1) {
|
|||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, ShapeToInitializer) {
|
||||
string model_uri = MODEL_FOLDER + "shape-add.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 3);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
rule_transformer_L1->Register(std::make_unique<ShapeToInitializer>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
// One of the Shapes is not eliminated because it inlcludes a symbolic dimension.
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 1);
|
||||
}
|
||||
|
||||
// Check transformations in the case of a subgraph with constant inputs.
|
||||
TEST(GraphTransformationTests, SubgraphWithConstantInputs) {
|
||||
string model_uri = MODEL_FOLDER + "constant-subgraph.onnx";
|
||||
|
|
|
|||
42
onnxruntime/test/testdata/transform/shape-add.onnx
vendored
Normal file
42
onnxruntime/test/testdata/transform/shape-add.onnx
vendored
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
lotus-transfomrs:â
|
||||
|
||||
AC"Shape
|
||||
|
||||
BD"Shape
|
||||
|
||||
C
|
||||
DE"Add
|
||||
|
||||
EF"Shape
|
||||
|
||||
FG"Identity shape-addZ
|
||||
A
|
||||
|
||||
|
||||
|
||||
Z
|
||||
B
|
||||
|
||||
|
||||
N
|
||||
b
|
||||
G
|
||||
|
||||
|
||||
j
|
||||
C
|
||||
|
||||
|
||||
j
|
||||
D
|
||||
|
||||
|
||||
j
|
||||
E
|
||||
|
||||
|
||||
j
|
||||
F
|
||||
|
||||
|
||||
B
|
||||
Loading…
Reference in a new issue