mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Create NodeArg for all initializers if IR version is > 3. (#742)
Previously all initializers had to have matching graph inputs and the NodeArg was guaranteed to be created via graph input processing.
This commit is contained in:
parent
2674d9bd8a
commit
65c50bb25b
3 changed files with 22 additions and 0 deletions
|
|
@ -658,6 +658,18 @@ Graph::Graph(GraphProto* graph_proto,
|
|||
// Copy initial tensors to a map.
|
||||
for (auto& tensor : graph_proto_->initializer()) {
|
||||
name_to_initial_tensor_[tensor.name()] = &tensor;
|
||||
|
||||
// v4 does not require initializers to be inputs, so we need to ensure there is a NodeArg created for all
|
||||
// initializers in that case
|
||||
if (ir_version > 3) {
|
||||
TypeProto t;
|
||||
t.mutable_tensor_type()->set_elem_type(tensor.data_type());
|
||||
auto shape = t.mutable_tensor_type()->mutable_shape();
|
||||
for (auto dim : tensor.dims())
|
||||
shape->add_dim()->set_dim_value(dim);
|
||||
|
||||
GetOrCreateNodeArg(tensor.name(), &t);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all node arg name, type, shape information in the graph.
|
||||
|
|
|
|||
|
|
@ -146,5 +146,15 @@ INSTANTIATE_TEST_CASE_P(ONNXModelsTests,
|
|||
::testing::Values("bvlc_alexnet", "bvlc_googlenet", "bvlc_reference_caffenet", "bvlc_reference_rcnn_ilsvrc13", "densenet121", "emotion_ferplus", "inception_v1", "inception_v2", "mnist", "resnet50", "shufflenet", "squeezenet", "tiny_yolov2", "vgg19", "zfnet512"));
|
||||
|
||||
#endif
|
||||
|
||||
// test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs.
|
||||
// a NodeArg should be created for all initializers in this case.
|
||||
// the test model contains initializers that are used as implicit inputs in a subgraph, and the NodeArg is required
|
||||
// for Graph::Resolve to succeed when processing the subgraph.
|
||||
TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK());
|
||||
EXPECT_TRUE(model->MainGraph().Resolve().IsOK());
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/subgraph_implicit_input_from_initializer.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/subgraph_implicit_input_from_initializer.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue