Create NodeArg for all initializers if IR version is > 3. (#742)

Previously all initializers had to have matching graph inputs and the NodeArg was guaranteed to be created via graph input processing.
This commit is contained in:
Scott McKay 2019-04-05 14:09:27 +10:00 committed by GitHub
parent 2674d9bd8a
commit 65c50bb25b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 0 deletions

View file

@ -658,6 +658,18 @@ Graph::Graph(GraphProto* graph_proto,
// Copy initial tensors to a map.
for (auto& tensor : graph_proto_->initializer()) {
name_to_initial_tensor_[tensor.name()] = &tensor;
// v4 does not require initializers to be inputs, so we need to ensure there is a NodeArg created for all
// initializers in that case
if (ir_version > 3) {
TypeProto t;
t.mutable_tensor_type()->set_elem_type(tensor.data_type());
auto shape = t.mutable_tensor_type()->mutable_shape();
for (auto dim : tensor.dims())
shape->add_dim()->set_dim_value(dim);
GetOrCreateNodeArg(tensor.name(), &t);
}
}
// Collect all node arg name, type, shape information in the graph.

View file

@ -146,5 +146,15 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512"));
#endif
// test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs.
// a NodeArg should be created for all initializers in this case.
// the test model contains initializers that are used as implicit inputs in a subgraph, and the NodeArg is required
// for Graph::Resolve to succeed when processing the subgraph.
TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK());
EXPECT_TRUE(model->MainGraph().Resolve().IsOK());
}
} // namespace test
} // namespace onnxruntime