fix training distributed ci failure (#9273)

This commit is contained in:
Guoyu Wang 2021-10-05 15:36:44 -07:00 committed by GitHub
parent 35c2102cfa
commit a4d53c4ab5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 3 deletions

View file

@ -3348,6 +3348,8 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve) {
// Node Args being used
std::unordered_set<const NodeArg*> used_args;
used_args.reserve(node_args_.size());
// Node Args we want to preserve even if they are not being used
std::unordered_set<const NodeArg*> node_args_to_preserve;
if (initializer_names_to_preserve) {
@ -3461,8 +3463,12 @@ void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::str
auto current_entry = it++;
const auto* current_node_arg = current_entry->second.get();
const auto& node_arg_name = current_entry->first;
// For some reason, some code still holds raw pointers to the unused NodeArgs.
// Remove only the NodeArgs with no type for now.
// TODO: investigate the issue when running using mpirun
if (!node_arg_name.empty() && used_args.find(current_node_arg) == used_args_end &&
node_args_to_preserve.find(current_node_arg) == node_args_to_preserve_end) {
node_args_to_preserve.find(current_node_arg) == node_args_to_preserve_end &&
!current_node_arg->ToProto().has_type()) {
LOGS(logger_, INFO) << "Removing NodeArg '" << node_arg_name << "'. It is no longer used by any node.";
// Need to remove the NodeArg from both value_info_ and node_args_
value_info_.erase(current_node_arg);

View file

@ -1296,8 +1296,11 @@ TEST_F(GraphTest, UnusedInitializerAndNodeArgsAreIgnored) {
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
ASSERT_TRUE(graph.GetAllInitializedTensors().empty());
ASSERT_EQ(nullptr, graph.GetNodeArg(unused_node_arg_name));
// Verify NodeArg from the unused initializer is deleted as well
ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name));
// TODO, enable this when we can remove unused NodeArgs with type
// See Graph::CleanUnusedInitializersAndNodeArgs
// ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name));
// Serialize and reload so that we also exercise the loaded-from-proto path in SetGraphInputsOutputs
auto proto = model.ToProto();
@ -1317,7 +1320,9 @@ TEST_F(GraphTest, UnusedInitializerAndNodeArgsAreIgnored) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
ASSERT_TRUE(graph.GetAllInitializedTensors().empty());
ASSERT_EQ(nullptr, graph.GetNodeArg(unused_node_arg_name));
ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name));
// TODO, enable this when we can remove unused NodeArgs with type
// See Graph::CleanUnusedInitializersAndNodeArgs
// ASSERT_EQ(nullptr, graph.GetNodeArg(unused_initializer_name));
}
#if !defined(DISABLE_SPARSE_TENSORS)