From 93e239747f3073cb3a1ff41381feea7433fcbbee Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 11 Nov 2021 15:13:28 -0800 Subject: [PATCH] Construct valid graphs for ONNX checker for IR version < 4. (#9665) * Construct valid graphs for ONNX checker for IR version < 4. Previously the constructed graph was not guaranteed to have its initializers be a subset of its inputs, which is required for IR version < 4. This resulted in spurious failures. Fixes #9663 --- include/onnxruntime/core/graph/node_arg.h | 2 +- onnxruntime/core/graph/graph.cc | 6 ++ onnxruntime/test/ir/graph_test.cc | 72 +++++++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index a766213c92..ecb5bdc131 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -99,7 +99,7 @@ class NodeArg { #endif // !defined(ORT_MINIMAL_BUILD) - /** Gets this NodeArg as a ValueInfoProto. */ + /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } /** Gets a flag indicating whether this NodeArg exists or not. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index eba88da49e..761ea6f2f1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1104,6 +1104,12 @@ Graph::Graph(const Model& owning_model, const gsl::not_null tensor{graph_proto_->add_initializer()}; auto status = utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor); ORT_ENFORCE(status.IsOK(), status.ToString()); + // Ensure initializers are also graph inputs. + if (ir_version_ < 4) { + TypeProto t{TypeProtoFromTensorProto(*tensor)}; + const NodeArg& node_arg = GetOrCreateNodeArg(tensor->name(), &t); + *(graph_proto_->add_input()) = node_arg.ToProto(); + } #if !defined(DISABLE_SPARSE_TENSORS) if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { auto p = sparse_tensor_names_.emplace(tensor->name()); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 5faae75209..a52549e0f8 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1961,5 +1961,77 @@ TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) { ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed."; } + +// The Model class internally: +// 1. Converts the ONNX Model to an ORT Model +// 2. Converts the ORT Model's Graph back to an ONNX Graph so that it +// can run the ONNX Checker on it. +// Previously this was buggy for models containing subgraphs with IR Version < 4. +// We didn't always ensure that initializers are a subset of inputs, which is +// required for IR version < 4. +TEST_F(GraphTest, ConstantsBecomeInitializersAndInputs) { + ModelProto m; + m.set_ir_version(ONNX_NAMESPACE::IR_VERSION_2017_11_3); + ImportOpset(m, "", 13); + GraphProto* g = m.mutable_graph(); + g->set_name("test"); + + // Construct "output = if x: 2.0 else 1.0" + ValueInfoProto* x = g->add_input(); + x->set_name("x"); + SetTypeAndShape(x->mutable_type()->mutable_tensor_type(), TensorProto_DataType_BOOL, {1}); + + NodeProto* if_node = g->add_node(); + if_node->set_op_type("If"); + if_node->set_name("If"); + if_node->add_input("x"); + if_node->add_output("output"); + + AttributeProto* if_attr = if_node->add_attribute(); + if_attr->set_name("then_branch"); + if_attr->set_type(AttributeProto_AttributeType_GRAPH); + GraphProto* then_g = if_attr->mutable_g(); + then_g->set_name("then"); + ValueInfoProto* then_out = then_g->add_output(); + then_out->set_name("then_out"); + SetTypeAndShape(then_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1}); + NodeProto* two_node = then_g->add_node(); + two_node->set_op_type("Constant"); + AttributeProto* two_attr = two_node->add_attribute(); + two_attr->set_name("value"); + two_attr->set_type(AttributeProto_AttributeType_TENSOR); + two_attr->mutable_t()->add_float_data(2.0); + two_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT); + two_attr->mutable_t()->add_dims(1); + two_node->set_name("Constant_two"); + two_node->add_output("then_out"); + + AttributeProto* else_attr = if_node->add_attribute(); + else_attr->set_name("else_branch"); + else_attr->set_type(AttributeProto_AttributeType_GRAPH); + GraphProto* else_g = else_attr->mutable_g(); + else_g->set_name("else"); + ValueInfoProto* else_out = else_g->add_output(); + else_out->set_name("else_out"); + SetTypeAndShape(else_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1}); + NodeProto* one_node = else_g->add_node(); + one_node->set_op_type("Constant"); + AttributeProto* one_attr = one_node->add_attribute(); + one_attr->set_name("value"); + one_attr->set_type(AttributeProto_AttributeType_TENSOR); + one_attr->mutable_t()->add_float_data(1.0); + one_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT); + one_attr->mutable_t()->add_dims(1); + one_node->set_name("Constant_one"); + one_node->add_output("else_out"); + + ValueInfoProto* output = g->add_output(); + output->set_name("output"); + SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1}); + + std::shared_ptr model; + Status st = Model::Load(std::move(m), model, nullptr, *logger_); + ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); +} } // namespace test } // namespace onnxruntime