From 65c50bb25b2f5a77572fae7a3142ecc07088c84a Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 5 Apr 2019 14:09:27 +1000 Subject: [PATCH] Create NodeArg for all initializers if IR version is > 3. (#742) Previously all initializers had to have matching graph inputs and the NodeArg was guaranteed to be created via graph input processing. --- onnxruntime/core/graph/graph.cc | 12 ++++++++++++ onnxruntime/test/ir/onnx_model_test.cc | 10 ++++++++++ .../subgraph_implicit_input_from_initializer.onnx | Bin 0 -> 365 bytes 3 files changed, 22 insertions(+) create mode 100644 onnxruntime/test/testdata/subgraph_implicit_input_from_initializer.onnx diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 23a3071264..528fda4851 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -658,6 +658,18 @@ Graph::Graph(GraphProto* graph_proto, // Copy initial tensors to a map. for (auto& tensor : graph_proto_->initializer()) { name_to_initial_tensor_[tensor.name()] = &tensor; + + // v4 does not require initializers to be inputs, so we need to ensure there is a NodeArg created for all + // initializers in that case + if (ir_version > 3) { + TypeProto t; + t.mutable_tensor_type()->set_elem_type(tensor.data_type()); + auto shape = t.mutable_tensor_type()->mutable_shape(); + for (auto dim : tensor.dims()) + shape->add_dim()->set_dim_value(dim); + + GetOrCreateNodeArg(tensor.name(), &t); + } } // Collect all node arg name, type, shape information in the graph. diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 8c39718778..4e832d6628 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -146,5 +146,15 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests, ::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512")); #endif + +// test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs. +// a NodeArg should be created for all initializers in this case. +// the test model contains initializers that are used as implicit inputs in a subgraph, and the NodeArg is required +// for Graph::Resolve to succeed when processing the subgraph. +TEST(ONNXModelsTest, TestIRv4NonInputInitializers) { + std::shared_ptr model; + ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK()); + EXPECT_TRUE(model->MainGraph().Resolve().IsOK()); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/subgraph_implicit_input_from_initializer.onnx b/onnxruntime/test/testdata/subgraph_implicit_input_from_initializer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9e5d44d37fbe643bab5a494b9a27f97858d2892e GIT binary patch literal 365 zcmZ{g!3u&v6h&tSv*{HjLl8&{acQ*BWt+lXU(j-Bbj*R$&~eeOPwA`rjJ_ca!H`-# z9``=ZeU}<+B68vAR62<&)6J8>3xgTuSy?5D5v6hO?iE0n<|zW4_vyhr&jvLpmT^N^!yBfQp^k>lsxdDLDiwMA1I$q zKT03vf1^4_gX*>>S`WoQ%Tli|c`VP~jACo#VHCC%cneA_>{5j8TZGUqZ2$HD34xeH LTiZ7$Bo@2@gy3P8 literal 0 HcmV?d00001