From bd2d6af9ca47f21880fe121d3a7c5392f77d9826 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 26 Sep 2019 13:44:33 +1000 Subject: [PATCH] Filter out info from non-const initializers during shape inferencing (#1806) * Don't return shape for non-const initializer in InferenceContextImpl::getInputType Don't return initializer for non-const initializer in InferenceContextImpl::getInputData Update graph_utils to support these scenarios - fix GetConstantInitializer to make sure a name is for an outer scope value before checking a parent graph, as local name could shadow an outer scope initializer. --- .../onnxruntime/core/framework/tensor_shape.h | 8 +- include/onnxruntime/core/graph/graph.h | 5 + onnxruntime/core/framework/tensor_shape.cc | 24 --- onnxruntime/core/framework/utils.cc | 47 ++++++ onnxruntime/core/framework/utils.h | 7 + onnxruntime/core/graph/graph.cc | 142 +++++++++++------- onnxruntime/core/graph/graph_utils.cc | 17 ++- onnxruntime/core/graph/graph_utils.h | 7 +- onnxruntime/test/ir/graph_test.cc | 59 ++++++++ onnxruntime/test/util/compare_ortvalue.cc | 4 +- 10 files changed, 230 insertions(+), 90 deletions(-) diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index c280f61eb1..9a2609bc1f 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -9,10 +9,6 @@ #include #include "onnxruntime_config.h" -namespace ONNX_NAMESPACE { -class TensorShapeProto; -} - namespace onnxruntime { #ifdef __GNUC__ #pragma GCC diagnostic push @@ -142,8 +138,6 @@ class TensorShape : private std::vector { #pragma GCC diagnostic pop #endif // operator<< to nicely output to a stream -std::ostream& operator<<(std::ostream& out, const ::onnxruntime::TensorShape& shape); - -std::ostream& operator<<(std::ostream& out, const ONNX_NAMESPACE::TensorShapeProto& shape_proto); +std::ostream& operator<<(std::ostream& out, const TensorShape& shape); } // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 52fd6f477b..fe57e6d4fa 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -774,6 +774,11 @@ class Graph { /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */ const Node* ParentNode() const { return parent_node_; } + /** Returns true if the name is for a value that is coming from outer scope */ + bool IsOuterScopeValue(const std::string& name) const { + return resolve_context_.outer_scope_node_args.find(name) != resolve_context_.outer_scope_node_args.cend(); + } + /** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node. Inherits some properties from the parent graph. @param parent_graph The Graph containing the Node that has the GraphProto attribute. diff --git a/onnxruntime/core/framework/tensor_shape.cc b/onnxruntime/core/framework/tensor_shape.cc index 49de886e1d..402714951a 100644 --- a/onnxruntime/core/framework/tensor_shape.cc +++ b/onnxruntime/core/framework/tensor_shape.cc @@ -5,7 +5,6 @@ #include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" -#include "core/graph/onnx_protobuf.h" namespace onnxruntime { @@ -94,27 +93,4 @@ std::ostream& operator<<(std::ostream& out, const ::onnxruntime::TensorShape& sh return (out << shape.ToString()); } -std::ostream& operator<<(std::ostream& out, const ONNX_NAMESPACE::TensorShapeProto& shape_proto) { - std::string result; - result.reserve(128); - - result.append("{"); - bool first = true; - for (auto& dim : shape_proto.dim()) { - if (!first) { - result.append(","); - } - - if (utils::HasDimValue(dim)) - result.append(std::to_string(dim.dim_value())); - else if (utils::HasDimParam(dim)) - result.append(dim.dim_param()); - - first = false; - } - result.append("}"); - - return (out << result); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index fbb529bd84..5401b08e87 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -16,8 +16,55 @@ #include "core/framework/parallel_executor.h" #include "core/framework/session_state.h" #include "core/framework/sequential_executor.h" +#include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" +#include "core/graph/onnx_protobuf.h" + +namespace ONNX_NAMESPACE { +std::ostream& operator<<(std::ostream& out, const TensorShapeProto& shape_proto) { + std::string result; + result.reserve(128); + + result.append("{"); + bool first = true; + for (auto& dim : shape_proto.dim()) { + if (!first) { + result.append(","); + } + + if (onnxruntime::utils::HasDimValue(dim)) + result.append(std::to_string(dim.dim_value())); + else if (onnxruntime::utils::HasDimParam(dim)) + result.append(dim.dim_param()); + + first = false; + } + result.append("}"); + + return (out << result); +} + +std::ostream& operator<<(std::ostream& out, const TensorProto& tensor_proto) { + std::string result; + result.reserve(128); + + result.append("{"); + bool first = true; + for (auto& dim : tensor_proto.dims()) { + if (!first) { + result.append(","); + } + + result.append(std::to_string(dim)); + first = false; + } + result.append("}"); + + return (out << result); +} +} // namespace ONNX_NAMESPACE + namespace onnxruntime { namespace utils { void* DefaultAlloc(size_t size) { diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 8bf5563c5d..829b39fdd2 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -10,6 +10,13 @@ #include "core/framework/iexecutor.h" #include "core/framework/session_state.h" +namespace ONNX_NAMESPACE { +class TensorShapeProto; +class TensorProto; +std::ostream& operator<<(std::ostream& out, const TensorShapeProto& shape_proto); +std::ostream& operator<<(std::ostream& out, const TensorProto& tensor_proto); +} // namespace ONNX_NAMESPACE + namespace onnxruntime { class ExecutionProviders; struct FeedsFetchesInfo; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 88e0cc3d73..4eb77e0169 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -12,15 +12,20 @@ #include #include "gsl/pointers" +#include "core/common/logging/logging.h" +#include "core/framework/tensor_shape.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" #include "core/graph/function.h" #include "core/graph/function_impl.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" -#include "core/graph/op.h" -#include "core/common/logging/logging.h" -#include "onnx/checker.h" #include "core/graph/schema_registry.h" +#include "core/graph/op.h" + +#include "onnx/checker.h" + using namespace ONNX_NAMESPACE; using namespace ONNX_NAMESPACE::Utils; using namespace ONNX_NAMESPACE::checker; @@ -71,6 +76,16 @@ static void RemoveInvalidValues(ONNX_NAMESPACE::TypeProto& type) { } } +static TypeProto TypeProtoFromTensorProto(const TensorProto& tensor) { + TypeProto t; + t.mutable_tensor_type()->set_elem_type(tensor.data_type()); + auto shape = t.mutable_tensor_type()->mutable_shape(); + for (auto dim : tensor.dims()) + shape->add_dim()->set_dim_value(dim); + + return t; +} + NodeArg::NodeArg(const std::string& name, const TypeProto* p_node_arg_type) { node_arg_info_.set_name(name); // If the name is empty, it means the arg does not exist. @@ -650,8 +665,7 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map } // Copy constant nodes _value to name_to_initial_tensor_ - const gsl::not_null - tensor{graph_proto_->add_initializer()}; + const gsl::not_null tensor{graph_proto_->add_initializer()}; const AttributeProto& constant_attribute = node.attribute(0); // TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node // Discussion surrounding handling the SparseTensorProto must be had. @@ -671,31 +685,40 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map }), graph_mutable_nodes->end()); + // Collect all node arg name, type, shape information in the graph. + // type/shape information will be assigned to each node arg when going + // thru all nodes later. + + // process graph inputs first as we want the type/shape from them to be preferred if a graph input + // has a matching initializer + for (auto& graph_input : graph_proto_->input()) { + if (utils::HasName(graph_input) && utils::HasType(graph_input)) { + name_to_type_map[graph_input.name()] = graph_input.type(); + GetOrCreateNodeArg(graph_input.name(), &graph_input.type()); + } + } + // Copy initial tensors to a map. for (auto& tensor : graph_proto_->initializer()) { name_to_initial_tensor_[tensor.name()] = &tensor; - // v4 does not require initializers to be inputs, so we need to ensure there is a NodeArg created for all - // initializers in that case - if (ir_version_ > 3) { - TypeProto t; - t.mutable_tensor_type()->set_elem_type(tensor.data_type()); - auto shape = t.mutable_tensor_type()->mutable_shape(); - for (auto dim : tensor.dims()) - shape->add_dim()->set_dim_value(dim); + NodeArg* matching_graph_input = GetNodeArg(tensor.name()); + TypeProto t{TypeProtoFromTensorProto(tensor)}; - GetOrCreateNodeArg(tensor.name(), &t); - } - } - - // Collect all node arg name, type, shape information in the graph. - // type/shape information will be assigned to each node arg when going - // thru all nodes later. - for (auto& graph_input : graph_proto_->input()) { - if (utils::HasName(graph_input) && utils::HasType(graph_input)) { - name_to_type_map[graph_input.name()] = graph_input.type(); - // always create a NodeArg for graph input in case its from an initializer - GetOrCreateNodeArg(graph_input.name(), &graph_input.type()); + if (ir_version_ < 4) { + // initializers can have matching graph inputs but are treated as constant, + // so we prefer the shape from the initializer + name_to_type_map[tensor.name()] = t; + if (matching_graph_input != nullptr) { + ORT_THROW_IF_ERROR(matching_graph_input->UpdateTypeAndShape(t)); + } + } else { + // v4 and later allows a constant initializer with no matching graph input. create a NodeArg for these. + // otherwise we prefer the shape from the graph input so leave matching_graph_input as is. + if (matching_graph_input == nullptr) { + name_to_type_map[tensor.name()] = t; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); + } } } @@ -1227,7 +1250,7 @@ class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer { // Perform inferencing on the graph contained in GraphInferencer. // Returns the graph output types post-inferencing. - // We ignore input_data currently. Re-consider if InferenceContextImpl::getInputData gets implemented + // We ignore input_data currently as the inferencing happens prior to receiving user input. std::vector doInferencing(const std::vector& input_types, const std::vector& /*input_data*/) override { std::vector output_types; @@ -1255,10 +1278,10 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { public: InferenceContextImpl(Node& node, SubgraphInferencingFunc subgraph_inferencing_func, - const InitializedTensorSet& initialized_tensor_set = {}) noexcept + const Graph& graph) noexcept : node_(node), subgraph_inferencing_func_(subgraph_inferencing_func), - initialized_tensor_set_(initialized_tensor_set) { + graph_(graph) { node_output_types_.resize(node.OutputDefs().size()); } @@ -1285,15 +1308,13 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } const TypeProto* getInputType(size_t index) const override { + const TypeProto* type = nullptr; auto p_node_arg = node_.InputDefs().at(index); if ((nullptr != p_node_arg) && p_node_arg->Exists()) { - return p_node_arg->TypeAsProto(); - // auto p_type_proto = p_node_arg->TypeAsProto(); - //if ((p_type_proto != nullptr) && p_type_proto->has_tensor_type()) { - // return &p_type_proto->tensor_type(); - //} + type = p_node_arg->TypeAsProto(); } - return nullptr; + + return type; } size_t getNumOutputs() const noexcept override { @@ -1308,9 +1329,11 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { auto def = node_.InputDefs()[index]; if (!def) return nullptr; - if (initialized_tensor_set_.count(def->Name()) == 0) - return nullptr; - return initialized_tensor_set_.at(def->Name()); + + // only return data if it's for a constant initializer. checks for outer scope initializers + // if this is a subgraph and the name isn't found locally. + const TensorProto* initializer = graph_utils::GetConstantInitializer(graph_, def->Name(), true); + return initializer; } GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) override { @@ -1336,7 +1359,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { std::vector node_output_types_; SubgraphInferencingFunc subgraph_inferencing_func_; std::vector> graph_inferencers_; - const InitializedTensorSet& initialized_tensor_set_; + const Graph& graph_; }; Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, @@ -1512,7 +1535,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op) { // Once that completes, the outputs from the node containing the subgraph will be updated, and the final values // returned here. SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes); - InferenceContextImpl context(node, func, name_to_initial_tensor_); + InferenceContextImpl context(node, func, *this); try { context.RunInferencing(); @@ -1611,9 +1634,6 @@ common::Status Graph::TypeCheckInputsAndInitializers() { } } - // Note: The ONNX spec requires every initializer to be included in the graph input, - // but onnxruntime relaxes this requirement for various reasons. - // Infer/check type and shape for all initializers from their values for (auto& initializer_pair : name_to_initial_tensor_) { const std::string& name = initializer_pair.first; @@ -1637,22 +1657,33 @@ common::Status Graph::TypeCheckInputsAndInitializers() { for (auto dim : tensor_proto->dims()) { inferred_shape.add_dim()->set_dim_value(dim); } + const TensorShapeProto* p_existing_shape = node_arg->Shape(); - if (nullptr == p_existing_shape) - node_arg->SetShape(inferred_shape); - else { - if (p_existing_shape->dim_size() != tensor_proto->dims_size()) - return Status(ONNXRUNTIME, FAIL, - "Type Error: Shape of initializer " + name + " does not match its type."); + if (nullptr == p_existing_shape) { + // use the inferred shape if this is a constant initializer (cannot be overridden). + // if not it has a matching graph input, and we prefer the shape info (or lack of info) from the graph input + if (graph_utils::IsConstantInitializer(*this, name, false)) { + node_arg->SetShape(inferred_shape); + } + } else { + if (p_existing_shape->dim_size() != tensor_proto->dims_size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Type Error: Shape of initializer ", name, " does not match. ", + *p_existing_shape, " != ", *tensor_proto); + } + for (int i = 0; i < p_existing_shape->dim_size(); ++i) { auto& d = p_existing_shape->dim(i); - if (utils::HasDimValue(d) && (d.dim_value() != tensor_proto->dims(i))) - return Status(ONNXRUNTIME, FAIL, - "Type Error: Shape of initializer " + initializer_pair.first + " does not match its type."); + if (utils::HasDimValue(d) && (d.dim_value() != tensor_proto->dims(i))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Type Error: Shape of initializer ", name, " does not match. ", + *p_existing_shape, " != ", *tensor_proto); + } } } } } + return Status::OK(); } @@ -1943,13 +1974,12 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { *(tensor_added) = tensor; name_to_initial_tensor_[tensor.name()] = tensor_added; - if (!GraphLoadedFromModelFile(graph_proto_)) { - // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs will add it to the graph inputs + if (!GraphLoadedFromModelFile(graph_proto_) && GetNodeArg(tensor.name()) == nullptr) { + // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. + // the shape will be set to the correct value in TypeCheckInputsAndInitializers as we don't yet know whether there + // will be a matching graph input for this initializer (we prefer shape info from the graph input). TypeProto t; t.mutable_tensor_type()->set_elem_type(tensor.data_type()); - auto shape = t.mutable_tensor_type()->mutable_shape(); - for (auto dim : tensor.dims()) - shape->add_dim()->set_dim_value(dim); ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); } diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 2ac2a15303..2ea4fef610 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -391,12 +391,27 @@ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, co } } } else if (check_outer_scope && graph.IsSubgraph()) { - initializer = GetConstantInitializer(*graph.ParentGraph(), initializer_name); + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + if (graph.IsOuterScopeValue(initializer_name)) { + initializer = GetConstantInitializer(*graph.ParentGraph(), initializer_name, check_outer_scope); + } } return initializer; } +bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) { + bool is_initializer = false; + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(name, initializer)) { + is_initializer = true; + } else if (check_outer_scope && graph.IsSubgraph() && graph.IsOuterScopeValue(name)) { + is_initializer = IsInitializer(*graph.ParentGraph(), name, check_outer_scope); + } + + return is_initializer; +} + bool IsConstantInitializer(const Graph& graph, const std::string& initializer_name, bool check_outer_scope) { const ONNX_NAMESPACE::TensorProto* initializer = GetConstantInitializer(graph, initializer_name, check_outer_scope); return initializer != nullptr; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 9d60e70fbd..1df7cb44ed 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -38,8 +38,13 @@ bool IsOutputUsed(const Node& node, int index); /** Returns true if the graph has the given input.*/ bool IsGraphInput(const Graph& graph, const NodeArg* input); +/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true. +@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. +*/ +bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope); + /** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. -@param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. +@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. */ bool IsConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 70f2bfe306..739351810f 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -814,6 +814,65 @@ TEST(TypeInferenceTest, VariadicOutput) { CheckTensorEltType(Z.TypeAsProto(), TensorProto_DataType_FLOAT); } +// test that we prefer the graph input shape for a non-const initializer (initializer with matching graph input) +TEST(TypeInferenceTest, NonConstInitializer) { + Model model("graph_1"); + auto& graph = model.MainGraph(); + + TypeProto tensor_type_no_shape; + tensor_type_no_shape.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + // tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + auto& X = graph.GetOrCreateNodeArg("X", &tensor_type_no_shape); + auto& Y = graph.GetOrCreateNodeArg("Y_Initializer", &tensor_type_no_shape); + auto& Z = graph.GetOrCreateNodeArg("Z", nullptr); + + // 2 graph inputs, both without shapes + graph.SetInputs({&X, &Y}); + + // add initializer for the Y input with shape + TensorProto t; + t.set_data_type(TensorProto_DataType_FLOAT); + t.add_float_data(0.1f); + t.add_float_data(0.2f); + t.add_dims(2); + t.set_name("Y_Initializer"); + graph.AddInitializedTensor(t); + + graph.AddNode("node_1", "Add", "node 1.", {&X, &Y}, {&Z}); + + auto resolve_and_validate = [](Graph& g) { + auto status = g.Resolve(); + EXPECT_TRUE(status.IsOK()) << status; + + const auto* current_Y = g.GetNodeArg("Y_Initializer"); + const auto* current_Z = g.GetNodeArg("Z"); + + // the graph input should still have no shape as we don't want to infer the shape from the initializer + // as inputs have priority + EXPECT_TRUE(current_Y != nullptr && current_Y->Shape() == nullptr); + + // and we should have type but no shape for Z after type/shape inferencing + EXPECT_TRUE(current_Z != nullptr && current_Z->Type() == current_Y->Type()); + EXPECT_TRUE(current_Z->Shape() == nullptr); + }; + + resolve_and_validate(graph); + + // save and reload to validate same happens when graph is loaded from proto + std::string s1; + ModelProto model_proto; + std::shared_ptr p_model; + ASSERT_TRUE(model.ToProto().SerializeToString(&s1)); + ASSERT_TRUE(model_proto.ParseFromString(s1)); + + auto status = onnxruntime::Model::Load(model_proto, p_model, nullptr); + ASSERT_TRUE(status.IsOK()) << status; + + auto& graph2 = p_model->MainGraph(); + resolve_and_validate(graph2); +} + // Test that Graph::Resolve identifies name-duplication across initializer and node-output-arg TEST(NameResolutionTest, DuplicateName) { Model model("graph_1"); diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index c526832d23..52d3f140c1 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -13,6 +13,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" #include "Eigen/Core" #include "Eigen/src/Core/arch/GPU/Half.h" @@ -297,7 +298,8 @@ bool AreShapesEqual(const std::vector& real_shape, const ::ONNX_NAMESPA continue; break; // This is for unlikely case when we add new oneof value - default : assert(false); + default: + assert(false); break; } }