From a4d53c4ab5d097497da16fe3e788d034d2c372a5 Mon Sep 17 00:00:00 2001 From: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com> Date: Tue, 5 Oct 2021 15:36:44 -0700 Subject: [PATCH] fix training distributed ci failure (#9273) --- onnxruntime/core/graph/graph.cc | 8 +++++++- onnxruntime/test/ir/graph_test.cc | 9 +++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 97be170865..1b989621f8 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3348,6 +3348,8 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve) { // Node Args being used std::unordered_set used_args; + used_args.reserve(node_args_.size()); + //Node Args we want to preserved even not being used std::unordered_set node_args_to_preserve; if (initializer_names_to_preserve) { @@ -3461,8 +3463,12 @@ void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_setsecond.get(); const auto& node_arg_name = current_entry->first; + // For some reason, we still have some code hold the raw pointer to the unused NodeArgs, + // Remove only the NodeArgs with no type for now + // TODO, investigate the issue when running using mpirun if (!node_arg_name.empty() && used_args.find(current_node_arg) == used_args_end && - node_args_to_preserve.find(current_node_arg) == node_args_to_preserve_end) { + node_args_to_preserve.find(current_node_arg) == node_args_to_preserve_end && + !current_node_arg->ToProto().has_type()) { LOGS(logger_, INFO) << "Removing NodeArg '" << node_arg_name << "'. It is no longer used by any node."; // Need to remove the NodeArg from both value_info_ and node_args_ value_info_.erase(current_node_arg); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 9403f4ac12..5faae75209 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1296,8 +1296,11 @@ TEST_F(GraphTest, UnusedInitializerAndNodeArgsAreIgnored) { ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); ASSERT_TRUE(graph.GetAllInitializedTensors().empty()); ASSERT_EQ(nullptr, graph.GetNodeArg(unused_node_arg_name)); + // Verify NodeArg from the unused initializer is deleted as well - ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name)); + // TODO, enable this when we can remove unused NodeArgs with type + // See Graph::CleanUnusedInitializersAndNodeArgs + // ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name)); // serialize and reload so we check the loaded from proto path in SetGraphInputsOutputs auto proto = model.ToProto(); @@ -1317,7 +1320,9 @@ TEST_F(GraphTest, UnusedInitializerAndNodeArgsAreIgnored) { EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); ASSERT_TRUE(graph.GetAllInitializedTensors().empty()); ASSERT_EQ(nullptr, graph.GetNodeArg(unused_node_arg_name)); - ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name)); + // TODO, enable this when we can remove unused NodeArgs with type + // See Graph::CleanUnusedInitializersAndNodeArgs + // ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name)); } #if !defined(DISABLE_SPARSE_TENSORS)