Don't remove an unused initializer if it is overridable. (#7649)

This commit is contained in:
Scott McKay 2021-05-12 07:03:54 +10:00 committed by GitHub
parent c5aeaa9419
commit 8deca24b1a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 18 deletions

View file

@ -3168,13 +3168,21 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
void Graph::CleanUnusedInitializers(const std::unordered_set<std::string>* initializer_names_to_preserve) {
std::unordered_set<std::string> used_args;
// anything that provides a required graph input (GetInputs), an optional graph input (GetOverridableInitializers)
// or a graph output (GetOutputs) cannot be removed
const auto& inputs = GetInputs();
const auto& overridable_initializers = GetOverridableInitializers();
const auto& outputs = GetOutputs();
std::for_each(inputs.cbegin(), inputs.cend(), [&used_args](const NodeArg* input) {
ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name()));
});
std::for_each(overridable_initializers.cbegin(), overridable_initializers.cend(),
[&used_args](const NodeArg* input) {
ORT_IGNORE_RETURN_VALUE(used_args.insert(input->Name()));
});
std::for_each(outputs.cbegin(), outputs.cend(), [&used_args](const NodeArg* output) {
ORT_IGNORE_RETURN_VALUE(used_args.insert(output->Name()));
});
@ -3214,23 +3222,28 @@ void Graph::CleanUnusedInitializers(const std::unordered_set<std::string>* initi
[this](const std::string& name) {
RemoveInitializedTensor(name);
// handle edge case where the unused initializer has a matching graph input
auto& proto_inputs = *graph_proto_->mutable_input();
auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(),
[&name](const ONNX_NAMESPACE::ValueInfoProto& input) {
return input.name() == name;
});
// handle edge case where the unused initializer has a matching graph input.
// this can only happen when initializers cannot be overridden via an optional graph input.
// (otherwise this initializer wouldn't be allowed to be removed due to it backing an optional
// graph input).
if (CanOverrideInitializer() == false) {
auto& proto_inputs = *graph_proto_->mutable_input();
auto i = std::find_if(proto_inputs.begin(), proto_inputs.end(),
[&name](const ONNX_NAMESPACE::ValueInfoProto& input) {
return input.name() == name;
});
if (i != proto_inputs.end()) {
RemoveRepeatedFieldEntry(proto_inputs, i);
}
if (i != proto_inputs.end()) {
RemoveRepeatedFieldEntry(proto_inputs, i);
}
auto& inputs_including_initializers = graph_inputs_including_initializers_;
auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(),
[&name](const NodeArg* input) { return input->Name() == name; });
auto& inputs_including_initializers = graph_inputs_including_initializers_;
auto j = std::find_if(inputs_including_initializers.begin(), inputs_including_initializers.end(),
[&name](const NodeArg* input) { return input->Name() == name; });
if (j != inputs_including_initializers.end()) {
inputs_including_initializers.erase(j);
if (j != inputs_including_initializers.end()) {
inputs_including_initializers.erase(j);
}
}
});
}

View file

@ -932,7 +932,7 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr
node_9 (Merge)
|
*/
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
@ -962,10 +962,10 @@ TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompr
auto& compress_node1 = graph.AddNode("compress_1", "Identity_Fake", "compress node 1", {&output_arg1}, {&output_arg5});
compress_node1.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_HIGH));
auto& decompress_node1 = graph.AddNode("decompress_1", "Identity_Fake", "decompress node 1", {&output_arg5}, {&output_arg6});
decompress_node1.SetPriority(10); // lower number means high priority
decompress_node1.SetPriority(10); // lower number means high priority
graph.AddNode("node_7", "Identity_Fake", "node 7", {&output_arg4}, {&output_arg7});
graph.AddNode("node_8", "Merge_Fake", "node 8", {&output_arg7, &output_arg6}, {&output_arg8});
graph.AddNode("node_9", "Merge_Fake", "node 9", {&output_arg8, &output_arg3}, {&output_arg9});
@ -1908,5 +1908,21 @@ TEST_F(GraphTest, LoadModelMissingInput) {
"initializer, or output of a previous node."));
}
// if an initializer is backing an optional graph input, it can't be removed even if unused in the graph.
TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) {
const std::string unused_initializer_name("truncation:0");
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(ORT_TSTR("testdata/unused_initializer.onnx"), model, nullptr, *logger_));
auto& graph = model->MainGraph();
const auto& inputs_including_initializers = graph.GetInputsIncludingInitializers();
auto j = std::find_if(inputs_including_initializers.cbegin(), inputs_including_initializers.cend(),
[&unused_initializer_name](const NodeArg* input) {
return input->Name() == unused_initializer_name;
});
ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed.";
}
} // namespace test
} // namespace onnxruntime

Binary file not shown.