Convert Shape operator to initializer (#1159)

This PR introduces a rewrite rule that replaces a Shape node with an initializer when the shape of the input is statically known through shape inference.
This commit is contained in:
Konstantinos Karanasos 2019-06-07 14:15:19 -07:00 committed by Changming Sun
parent cdb27de090
commit 32c6c71e86
8 changed files with 219 additions and 24 deletions

View file

@ -539,6 +539,31 @@ TensorProto::DataType GetTensorProtoType(const Tensor& tensor) {
return dtype;
}
// Builds a TensorProto that mirrors the given Tensor.
//   tensor            - source of the dimensions and of the raw data buffer.
//   tensor_proto_name - name assigned to the resulting TensorProto.
//   tensor_proto_type - TypeProto whose tensor element type becomes the data type of the result.
// Returns the populated TensorProto by value.
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
                                                const onnx::TypeProto& tensor_proto_type) {
  // The payload is copied verbatim into the raw_data field, so this only works for little-endian format.
  ORT_ENFORCE(IsLittleEndianOrder());

  ONNX_NAMESPACE::TensorProto result;
  result.set_name(tensor_proto_name);

  // TODO Once utils::GetTensorProtoType supports all data types, the element type can be derived from the
  // tensor itself via utils::GetTensorProtoType(tensor), which makes the tensor_proto_type argument unnecessary.
  const auto element_type = tensor_proto_type.tensor_type().elem_type();
  result.set_data_type(element_type);

  // Copy the dimensions of the tensor, in order.
  for (const auto& extent : tensor.Shape().GetDims()) {
    result.add_dims(extent);
  }

  // Copy the tensor buffer; tensor.Size() is passed as the length of the raw buffer.
  result.set_raw_data(tensor.DataRaw(), tensor.Size());
  return result;
}
// Explicit instantiations of GetSizeInBytesFromTensorProto for the template arguments 256 and 0.
// NOTE(review): the template argument is presumably an alignment in bytes -- confirm against the template definition.
template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto,
size_t* out);
template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);

View file

@ -37,6 +37,18 @@ common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_prot
// This function doesn't support string tensors
ONNX_NAMESPACE::TensorProto::DataType GetTensorProtoType(const Tensor& tensor);
/** Creates a TensorProto from a Tensor.
@param[in] tensor the Tensor whose data and shape will be used to create the TensorProto.
@param[in] tensor_proto_name the name of the TensorProto.
@param[in] tensor_proto_type the type of the TensorProto; its tensor element type is used as the data type
of the resulting TensorProto.
@return the TensorProto.
Note: The method currently requires that data is in little-endian format, because the tensor data is stored in
the raw_data field of the TensorProto.
TODO Once GetTensorProtoType supports all data types, we can remove the tensor_proto_type parameter and
instead get the type from the tensor. */
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name,
const onnx::TypeProto& tensor_proto_type);
ONNXTensorElementDataType CApiElementTypeFromProtoType(int type);
ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto);

View file

@ -5,7 +5,7 @@
#include "core/graph/graph_utils.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/framework/op_kernel.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensorprotoutils.h"
using namespace onnxruntime::common;
@ -61,9 +61,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
OrtValue& ort_value = fetches[fetch_idx];
// Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph.
ONNX_NAMESPACE::TensorProto out_tensorproto;
const auto* constant_arg_out = node->OutputDefs()[fetch_idx];
BuildTensorProtoForInitializer(ort_value, *constant_arg_out, out_tensorproto);
ORT_ENFORCE(ort_value.IsTensor());
const Tensor& out_tensor = ort_value.Get<Tensor>();
ONNX_NAMESPACE::TensorProto out_tensorproto =
utils::TensorToTensorProto(out_tensor, constant_arg_out->Name(), *constant_arg_out->TypeAsProto());
graph.AddInitializedTensor(out_tensorproto);
}
@ -81,23 +83,4 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level)
return Status::OK();
} // namespace onnxruntime
void ConstantFolding::BuildTensorProtoForInitializer(const OrtValue& ort_value, const NodeArg& constant_node_arg,
ONNX_NAMESPACE::TensorProto& tensorproto) const {
ORT_ENFORCE(ort_value.IsTensor());
const Tensor& out_tensor = ort_value.Get<Tensor>();
// Set name, dimensions, type, and data of the TensorProto.
tensorproto.set_name(constant_node_arg.Name());
for (auto& dim : out_tensor.Shape().GetDims()) {
tensorproto.add_dims(dim);
}
auto tensorproto_type = constant_node_arg.TypeAsProto()->tensor_type().elem_type();
tensorproto.set_data_type(tensorproto_type);
auto tensor_shape_size = out_tensor.Shape().Size();
auto data_size = out_tensor.DataType()->Size() * tensor_shape_size;
tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size);
}
} // namespace onnxruntime

View file

@ -13,6 +13,7 @@
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/shape_to_initializer.h"
namespace onnxruntime {
@ -32,6 +33,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
rules.push_back(std::make_unique<UnsqueezeElimination>());
rules.push_back(std::make_unique<EliminateDropout>());
rules.push_back(std::make_unique<FuseReluClip>());
rules.push_back(std::make_unique<ShapeToInitializer>());
break;
case TransformerLevel::Level2:

View file

@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/shape_to_initializer.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
#include "core/graph/op.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensorprotoutils.h"
namespace onnxruntime {
// Replaces the given Shape node with an initializer that holds the statically inferred shape of its input.
//   graph       - the graph that owns the node; the node is removed from it and the initializer is added to it.
//   node        - the Shape node to replace.
//   rule_effect - set to kRemovedCurrentNode when the node was removed.
// Always returns Status::OK().
Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
  // Gather the statically inferred dimension values of the input to the Shape operator.
  const ONNX_NAMESPACE::TensorShapeProto* inferred_shape = node.InputDefs()[0]->Shape();
  const int rank = inferred_shape->dim_size();
  std::vector<int64_t> dim_values;
  for (int d = 0; d < rank; ++d) {
    dim_values.push_back(gsl::narrow_cast<int64_t>(inferred_shape->dim(d).dim_value()));
  }

  // Describe the initializer: it takes the name and element type of the Shape output and is a
  // one-dimensional tensor with one entry per input dimension.
  const auto* output_def = node.OutputDefs()[0];
  ONNX_NAMESPACE::TensorProto initializer;
  initializer.set_name(output_def->Name());
  initializer.add_dims(gsl::narrow_cast<int64_t>(rank));
  initializer.set_data_type(output_def->TypeAsProto()->tensor_type().elem_type());

  // The values are written in the in-memory byte order; little-endian format is expected here.
  initializer.set_raw_data(dim_values.data(), dim_values.size() * sizeof(int64_t));

  // Detach the Shape node from its consumers, then drop it. The initializer is only added
  // (and the effect only reported) when the node was actually removed.
  graph_utils::RemoveNodeOutputEdges(graph, node);
  if (!graph.RemoveNode(node.Index())) {
    return Status::OK();
  }

  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  graph.AddInitializedTensor(initializer);
  return Status::OK();
}
// Decides whether the Shape node can be replaced with an initializer.
//   graph - the graph containing the node; used to check whether the node feeds a graph output.
//   node  - the candidate node.
// Returns true only if the node is a supported Shape operator, none of its outputs is a graph output,
// and every dimension of its input has a concrete, non-negative value.
bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node) const {
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1}) ||
      // Making sure we are not left with a graph with no nodes.
      graph.IsNodeOutputsInGraphOutputs(node)) {
    return false;
  }

  // The shape of the input has to be statically known.
  const auto* input_shape = node.InputDefs()[0]->Shape();
  if (!input_shape) {
    return false;
  }

  // Each dimension must have a specific value: the rule cannot be applied if one of the dimensions is a
  // symbolic variable. A negative value is also rejected, since it cannot be a real extent (it is
  // sometimes used as a placeholder for an unknown dimension) and must not be frozen into a constant.
  for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) {
    const auto& input_dim = input_shape->dim(i);
    if (!input_dim.has_dim_value() || input_dim.dim_value() < 0) {
      return false;
    }
  }

  return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@class ShapeToInitializer

Rewrite rule that replaces a Shape node with an initializer feeding the downstream nodes, applicable when
the shape of the Shape operator's input is statically known (through shape inference).

The rule targets only nodes with op type "Shape".
*/
class ShapeToInitializer : public RewriteRule {
 public:
  ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {}

  /** Op types this rule is registered for: just "Shape". */
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return std::vector<std::string>{"Shape"};
  }

 private:
  /** Performs the replacement of the Shape node with an initializer. */
  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;

  /** Checks whether the rule can be applied to the given node. */
  bool SatisfyCondition(const Graph& graph, const Node& node) const override;
};
} // namespace onnxruntime

View file

@ -28,6 +28,7 @@
#include "gtest/gtest.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/shape_to_initializer.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -65,7 +66,7 @@ TEST(GraphTransformationTests, IdentityElimination) {
ASSERT_TRUE(op_to_count["Identity"] == 0);
}
TEST(GraphTransformationTests, DropoutEliminationSingleOutput) {
TEST(GraphTransformationTests, DropoutElimination) {
string model_uri = MODEL_FOLDER + "dropout.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
@ -107,7 +108,7 @@ TEST(GraphTransformationTests, SliceElimination) {
ASSERT_TRUE(op_to_count["Slice"] == 4);
}
TEST(GraphTransformationTests, ConstantFolding1) {
TEST(GraphTransformationTests, ConstantFolding) {
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
@ -124,6 +125,26 @@ TEST(GraphTransformationTests, ConstantFolding1) {
ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
}
// Checks that the ShapeToInitializer rule removes the Shape nodes whose input shape is fully known.
TEST(GraphTransformationTests, ShapeToInitializer) {
  // Load the test model and take its main graph.
  std::shared_ptr<Model> model;
  string model_uri = MODEL_FOLDER + "shape-add.onnx";
  ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
  Graph& graph = model->MainGraph();

  // The model starts out with three Shape nodes.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Shape"] == 3);

  // Register a rule-based transformer containing only the rule under test, at Level1.
  auto shape_rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
  shape_rule_transformer->Register(std::make_unique<ShapeToInitializer>());
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(std::move(shape_rule_transformer), TransformerLevel::Level1);

  ASSERT_TRUE(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());

  // One of the Shapes is not eliminated because it includes a symbolic dimension.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Shape"] == 1);
}
// Check transformations in the case of a subgraph with constant inputs.
TEST(GraphTransformationTests, SubgraphWithConstantInputs) {
string model_uri = MODEL_FOLDER + "constant-subgraph.onnx";

View file

@ -0,0 +1,42 @@
lotus-transfomrs:â
AC"Shape
BD"Shape

C
DE"Add
EF"Shape

FG"Identity shape-addZ
A



Z
B


N
b
G

j
C

j
D

j
E

j
F

B