diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 321c48025b..905affee11 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3168,13 +3168,21 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const void Graph::CleanUnusedInitializers(const std::unordered_set* initializer_names_to_preserve) { std::unordered_set used_args; + // anything that provides a required graph input (GetInputs), an optional graph input (GetOverridableInitializers) + // or a graph output (GetOutputs) cannot be removed const auto& inputs = GetInputs(); + const auto& overridable_initializers = GetOverridableInitializers(); const auto& outputs = GetOutputs(); std::for_each(inputs.cbegin(), inputs.cend(), [&used_args](const NodeArg* input) { ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name())); }); + std::for_each(overridable_initializers.cbegin(), overridable_initializers.cend(), + [&used_args](const NodeArg* input) { + ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name())); + }); + std::for_each(outputs.cbegin(), outputs.cend(), [&used_args](const NodeArg* output) { ORT_IGNORE_RETURN_VALUE(used_args.insert(output->Name())); }); @@ -3214,23 +3222,28 @@ void Graph::CleanUnusedInitializers(const std::unordered_set* initi [this](const std::string& name) { RemoveInitializedTensor(name); - // handle edge case where the unused initializer has a matching graph input - auto& proto_inputs = *graph_proto_->mutable_input(); - auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(), - [&name](const ONNX_NAMESPACE::ValueInfoProto& input) { - return input.name() == name; - }); + // handle edge case where the unused initializer has a matching graph input. + // this can only happen when initializers cannot be overridden via an optional graph input. + // (otherwise this initializer wouldn't be allowed to be removed due to it backing an optional + // graph input). + if (CanOverrideInitializer() == false) { + auto& proto_inputs = *graph_proto_->mutable_input(); + auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(), + [&name](const ONNX_NAMESPACE::ValueInfoProto& input) { + return input.name() == name; + }); - if (i != proto_inputs.end()) { - RemoveRepeatedFieldEntry(proto_inputs, i); - } + if (i != proto_inputs.end()) { + RemoveRepeatedFieldEntry(proto_inputs, i); + } - auto& inputs_including_initializers = graph_inputs_including_initializers_; - auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(), - [&name](const NodeArg* input) { return input->Name() == name; }); + auto& inputs_including_initializers = graph_inputs_including_initializers_; + auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(), + [&name](const NodeArg* input) { return input->Name() == name; }); - if (j != inputs_including_initializers.end()) { - inputs_including_initializers.erase(j); + if (j != inputs_including_initializers.end()) { + inputs_including_initializers.erase(j); + } } }); } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index a29f4adb95..75f5933505 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -932,7 +932,7 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr node_9 (Merge) | */ - + TypeProto tensor_int32; tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); @@ -962,10 +962,10 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr auto& compress_node1 = graph.AddNode("compress_1", "Identity_Fake", "compress node 1", {&output_arg1}, {&output_arg5}); compress_node1.SetPriority(static_cast(ExecutionPriority::LOCAL_HIGH)); - + auto& decompress_node1 = graph.AddNode("decompress_1", "Identity_Fake", "decompress node 1", {&output_arg5}, {&output_arg6}); - decompress_node1.SetPriority(10); // lower number means high priority - + decompress_node1.SetPriority(10); // lower number means high priority + graph.AddNode("node_7", "Identity_Fake", "node 7", {&output_arg4}, {&output_arg7}); graph.AddNode("node_8", "Merge_Fake", "node 8", {&output_arg7, &output_arg6}, {&output_arg8}); graph.AddNode("node_9", "Merge_Fake", "node 9", {&output_arg8, &output_arg3}, {&output_arg9}); @@ -1908,5 +1908,21 @@ TEST_F(GraphTest, LoadModelMissingInput) { "initializer, or output of a previous node.")); } +// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph. +TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) { + const std::string unused_initializer_name("truncation:0"); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ORT_TSTR("testdata/unused_initializer.onnx"), model, nullptr, *logger_)); + + auto& graph = model->MainGraph(); + const auto& inputs_including_initializers = graph.GetInputsIncludingInitializers(); + auto j = std::find_if(inputs_including_initializers.cbegin(), inputs_including_initializers.cend(), + [&unused_initializer_name](const NodeArg* input) { + return input->Name() == unused_initializer_name; + }); + + ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed."; +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/unused_initializer.onnx b/onnxruntime/test/testdata/unused_initializer.onnx new file mode 100644 index 0000000000..2fc76ba767 Binary files /dev/null and b/onnxruntime/test/testdata/unused_initializer.onnx differ