mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Don't remove an unused initializer if it is overridable. (#7649)
This commit is contained in:
parent
c5aeaa9419
commit
8deca24b1a
3 changed files with 47 additions and 18 deletions
|
|
@ -3168,13 +3168,21 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
|
|||
void Graph::CleanUnusedInitializers(const std::unordered_set<std::string>* initializer_names_to_preserve) {
|
||||
std::unordered_set<std::string> used_args;
|
||||
|
||||
// anything that provides a required graph input (GetInputs), an optional graph input (GetOverridableInitializers)
|
||||
// or a graph output (GetOutputs) cannot be removed
|
||||
const auto& inputs = GetInputs();
|
||||
const auto& overridable_initializers = GetOverridableInitializers();
|
||||
const auto& outputs = GetOutputs();
|
||||
|
||||
std::for_each(inputs.cbegin(), inputs.cend(), [&used_args](const NodeArg* input) {
|
||||
ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name()));
|
||||
});
|
||||
|
||||
std::for_each(overridable_initializers.cbegin(), overridable_initializers.cend(),
|
||||
[&used_args](const NodeArg* input) {
|
||||
ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name()));
|
||||
});
|
||||
|
||||
std::for_each(outputs.cbegin(), outputs.cend(), [&used_args](const NodeArg* output) {
|
||||
ORT_IGNORE_RETURN_VALUE(used_args.insert(output->Name()));
|
||||
});
|
||||
|
|
@ -3214,23 +3222,28 @@ void Graph::CleanUnusedInitializers(const std::unordered_set<std::string>* initi
|
|||
[this](const std::string& name) {
|
||||
RemoveInitializedTensor(name);
|
||||
|
||||
// handle edge case where the unused initializer has a matching graph input
|
||||
auto& proto_inputs = *graph_proto_->mutable_input();
|
||||
auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(),
|
||||
[&name](const ONNX_NAMESPACE::ValueInfoProto& input) {
|
||||
return input.name() == name;
|
||||
});
|
||||
// handle edge case where the unused initializer has a matching graph input.
|
||||
// this can only happen when initializers cannot be overridden via an optional graph input.
|
||||
// (otherwise this initializer wouldn't be allowed to be removed due to it backing an optional
|
||||
// graph input).
|
||||
if (CanOverrideInitializer() == false) {
|
||||
auto& proto_inputs = *graph_proto_->mutable_input();
|
||||
auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(),
|
||||
[&name](const ONNX_NAMESPACE::ValueInfoProto& input) {
|
||||
return input.name() == name;
|
||||
});
|
||||
|
||||
if (i != proto_inputs.end()) {
|
||||
RemoveRepeatedFieldEntry(proto_inputs, i);
|
||||
}
|
||||
if (i != proto_inputs.end()) {
|
||||
RemoveRepeatedFieldEntry(proto_inputs, i);
|
||||
}
|
||||
|
||||
auto& inputs_including_initializers = graph_inputs_including_initializers_;
|
||||
auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(),
|
||||
[&name](const NodeArg* input) { return input->Name() == name; });
|
||||
auto& inputs_including_initializers = graph_inputs_including_initializers_;
|
||||
auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(),
|
||||
[&name](const NodeArg* input) { return input->Name() == name; });
|
||||
|
||||
if (j != inputs_including_initializers.end()) {
|
||||
inputs_including_initializers.erase(j);
|
||||
if (j != inputs_including_initializers.end()) {
|
||||
inputs_including_initializers.erase(j);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -932,7 +932,7 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr
|
|||
node_9 (Merge)
|
||||
|
|
||||
*/
|
||||
|
||||
|
||||
TypeProto tensor_int32;
|
||||
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
|
@ -962,10 +962,10 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr
|
|||
|
||||
auto& compress_node1 = graph.AddNode("compress_1", "Identity_Fake", "compress node 1", {&output_arg1}, {&output_arg5});
|
||||
compress_node1.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_HIGH));
|
||||
|
||||
|
||||
auto& decompress_node1 = graph.AddNode("decompress_1", "Identity_Fake", "decompress node 1", {&output_arg5}, {&output_arg6});
|
||||
decompress_node1.SetPriority(10); // lower number means high priority
|
||||
|
||||
decompress_node1.SetPriority(10); // lower number means high priority
|
||||
|
||||
graph.AddNode("node_7", "Identity_Fake", "node 7", {&output_arg4}, {&output_arg7});
|
||||
graph.AddNode("node_8", "Merge_Fake", "node 8", {&output_arg7, &output_arg6}, {&output_arg8});
|
||||
graph.AddNode("node_9", "Merge_Fake", "node 9", {&output_arg8, &output_arg3}, {&output_arg9});
|
||||
|
|
@ -1908,5 +1908,21 @@ TEST_F(GraphTest, LoadModelMissingInput) {
|
|||
"initializer, or output of a previous node."));
|
||||
}
|
||||
|
||||
// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph.
|
||||
TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) {
|
||||
const std::string unused_initializer_name("truncation:0");
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_STATUS_OK(Model::Load(ORT_TSTR("testdata/unused_initializer.onnx"), model, nullptr, *logger_));
|
||||
|
||||
auto& graph = model->MainGraph();
|
||||
const auto& inputs_including_initializers = graph.GetInputsIncludingInitializers();
|
||||
auto j = std::find_if(inputs_including_initializers.cbegin(), inputs_including_initializers.cend(),
|
||||
[&unused_initializer_name](const NodeArg* input) {
|
||||
return input->Name() == unused_initializer_name;
|
||||
});
|
||||
|
||||
ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed.";
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/unused_initializer.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/unused_initializer.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue