From 8deca24b1a66bd97b3593a3f02fab47a809e0ded Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 12 May 2021 07:03:54 +1000 Subject: [PATCH] Don't remove an unused initializer if it is overridable. (#7649) --- onnxruntime/core/graph/graph.cc | 41 ++++++++++++------ onnxruntime/test/ir/graph_test.cc | 24 ++++++++-- .../test/testdata/unused_initializer.onnx | Bin 0 -> 145 bytes 3 files changed, 47 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/test/testdata/unused_initializer.onnx diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 321c48025b..905affee11 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3168,13 +3168,21 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const void Graph::CleanUnusedInitializers(const std::unordered_set* initializer_names_to_preserve) { std::unordered_set used_args; + // anything that provides a required graph input (GetInputs), an optional graph input (GetOverridableInitializers) + // or a graph output (GetOutputs) cannot be removed const auto& inputs = GetInputs(); + const auto& overridable_initializers = GetOverridableInitializers(); const auto& outputs = GetOutputs(); std::for_each(inputs.cbegin(), inputs.cend(), [&used_args](const NodeArg* input) { ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name())); }); + std::for_each(overridable_initializers.cbegin(), overridable_initializers.cend(), + [&used_args](const NodeArg* input) { + ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name())); + }); + std::for_each(outputs.cbegin(), outputs.cend(), [&used_args](const NodeArg* output) { ORT_IGNORE_RETURN_VALUE(used_args.insert(output->Name())); }); @@ -3214,23 +3222,28 @@ void Graph::CleanUnusedInitializers(const std::unordered_set* initi [this](const std::string& name) { RemoveInitializedTensor(name); - // handle edge case where the unused initializer has a matching graph input - auto& proto_inputs = *graph_proto_->mutable_input(); - auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(), - [&name](const ONNX_NAMESPACE::ValueInfoProto& input) { - return input.name() == name; - }); + // handle edge case where the unused initializer has a matching graph input. + // this can only happen when initializers cannot be overridden via an optional graph input. + // (otherwise this initializer wouldn't be allowed to be removed due to it backing an optional + // graph input). + if (CanOverrideInitializer() == false) { + auto& proto_inputs = *graph_proto_->mutable_input(); + auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(), + [&name](const ONNX_NAMESPACE::ValueInfoProto& input) { + return input.name() == name; + }); - if (i != proto_inputs.end()) { - RemoveRepeatedFieldEntry(proto_inputs, i); - } + if (i != proto_inputs.end()) { + RemoveRepeatedFieldEntry(proto_inputs, i); + } - auto& inputs_including_initializers = graph_inputs_including_initializers_; - auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(), - [&name](const NodeArg* input) { return input->Name() == name; }); + auto& inputs_including_initializers = graph_inputs_including_initializers_; + auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(), + [&name](const NodeArg* input) { return input->Name() == name; }); - if (j != inputs_including_initializers.end()) { - inputs_including_initializers.erase(j); + if (j != inputs_including_initializers.end()) { + inputs_including_initializers.erase(j); + } } }); } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index a29f4adb95..75f5933505 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -932,7 +932,7 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr node_9 (Merge) | */ - + TypeProto tensor_int32; tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); @@ -962,10 +962,10 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr auto& compress_node1 = graph.AddNode("compress_1", "Identity_Fake", "compress node 1", {&output_arg1}, {&output_arg5}); compress_node1.SetPriority(static_cast(ExecutionPriority::LOCAL_HIGH)); - + auto& decompress_node1 = graph.AddNode("decompress_1", "Identity_Fake", "decompress node 1", {&output_arg5}, {&output_arg6}); - decompress_node1.SetPriority(10); // lower number means high priority - + decompress_node1.SetPriority(10); // lower number means high priority + graph.AddNode("node_7", "Identity_Fake", "node 7", {&output_arg4}, {&output_arg7}); graph.AddNode("node_8", "Merge_Fake", "node 8", {&output_arg7, &output_arg6}, {&output_arg8}); graph.AddNode("node_9", "Merge_Fake", "node 9", {&output_arg8, &output_arg3}, {&output_arg9}); @@ -1908,5 +1908,21 @@ TEST_F(GraphTest, LoadModelMissingInput) { "initializer, or output of a previous node.")); } +// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph. +TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) { + const std::string unused_initializer_name("truncation:0"); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ORT_TSTR("testdata/unused_initializer.onnx"), model, nullptr, *logger_)); + + auto& graph = model->MainGraph(); + const auto& inputs_including_initializers = graph.GetInputsIncludingInitializers(); + auto j = std::find_if(inputs_including_initializers.cbegin(), inputs_including_initializers.cend(), + [&unused_initializer_name](const NodeArg* input) { + return input->Name() == unused_initializer_name; + }); + + ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed."; +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/unused_initializer.onnx b/onnxruntime/test/testdata/unused_initializer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2fc76ba767c84f5f40b1b9a5333a210a24366420 GIT binary patch literal 145 zcmd;J7h*3-Gs@4)tB_(f)U(htuqua(Xo(3hI`NbgmF6WUmSpDVSs8e-FfcUON6B(A zS6UeeiE)W=FbWB9aS3sh=4Hpn8<`t&F>$a)iE*K55@O?G0jgt2!mrJVg^NLe7XWIO B9eDr% literal 0 HcmV?d00001