Fix DEBUG_NODE_INPUTS_OUTPUTS test by putting it in a separate process, clean up unused test_main.cc files. (#5949)

Move the DEBUG_NODE_INPUTS_OUTPUTS test into its own process. The implementation uses static variables which do not interact well with other tests.
Clean up old test_main.cc files which are no longer used.
This commit is contained in:
Edward Chen 2020-12-11 11:36:58 -08:00 committed by GitHub
parent a53f4dd379
commit c8ac34d6a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 87 additions and 176 deletions

View file

@ -139,6 +139,9 @@ function(AddTest)
endif()
endfunction(AddTest)
# general program entrypoint for C++ unit tests
set(onnxruntime_unittest_main_src "${TEST_SRC_DIR}/unittest_main/test_main.cc")
#Do not add '${TEST_SRC_DIR}/util/include' to your include directories directly
#Use onnxruntime_add_include_to_target or target_link_libraries, so that compile definitions
#can propagate correctly.
@ -571,65 +574,58 @@ endif()
set(all_dependencies ${onnxruntime_test_providers_dependencies} )
if (onnxruntime_ENABLE_TRAINING)
list(APPEND all_tests ${onnxruntime_test_training_src})
endif()
if (onnxruntime_ENABLE_TRAINING)
list(APPEND all_tests ${onnxruntime_test_training_src})
endif()
if (onnxruntime_USE_TVM)
list(APPEND all_tests ${onnxruntime_test_tvm_src})
endif()
if (onnxruntime_USE_OPENVINO)
list(APPEND all_tests ${onnxruntime_test_openvino_src})
endif()
# we can only have one 'main', so remove them all and add back the providers test_main as it sets
# up everything we need for all tests
file(GLOB_RECURSE test_mains CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/*/test_main.cc"
)
list(REMOVE_ITEM all_tests ${test_mains})
list(APPEND all_tests "${TEST_SRC_DIR}/providers/test_main.cc")
if (onnxruntime_USE_TVM)
list(APPEND all_tests ${onnxruntime_test_tvm_src})
endif()
if (onnxruntime_USE_OPENVINO)
list(APPEND all_tests ${onnxruntime_test_openvino_src})
endif()
# this is only added to onnxruntime_test_framework_libs above, but we use onnxruntime_test_providers_libs for the onnxruntime_test_all target.
# for now, add it here. better is probably to have onnxruntime_test_providers_libs use the full onnxruntime_test_framework_libs
# list given it's built on top of that library and needs all the same dependencies.
if(WIN32)
list(APPEND onnxruntime_test_providers_libs Advapi32)
endif()
# this is only added to onnxruntime_test_framework_libs above, but we use onnxruntime_test_providers_libs for the onnxruntime_test_all target.
# for now, add it here. better is probably to have onnxruntime_test_providers_libs use the full onnxruntime_test_framework_libs
# list given it's built on top of that library and needs all the same dependencies.
if(WIN32)
list(APPEND onnxruntime_test_providers_libs Advapi32)
endif()
AddTest(
TARGET onnxruntime_test_all
SOURCES ${all_tests}
LIBS onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} re2::re2 onnx_test_data_proto
DEPENDS ${all_dependencies}
)
AddTest(
TARGET onnxruntime_test_all
SOURCES ${all_tests} ${onnxruntime_unittest_main_src}
LIBS onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} re2::re2 onnx_test_data_proto
DEPENDS ${all_dependencies}
)
# the default logger tests conflict with the need to have an overall default logger
# so skip in this type of
target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
target_compile_options(onnxruntime_test_all PUBLIC "-Wno-unused-const-variable")
endif()
if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE)
target_compile_definitions(onnxruntime_test_all PUBLIC -DRUN_MODELTEST_IN_DEBUG_MODE)
endif()
if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
endif()
if (onnxruntime_USE_FEATURIZERS)
target_include_directories(onnxruntime_test_all PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/FeaturizersLibrary/src)
endif()
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
target_link_libraries(onnxruntime_test_all PRIVATE onnxruntime_language_interop onnxruntime_pyop)
endif()
# the default logger tests conflict with the need to have an overall default logger
# so skip in this type of
target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
target_compile_options(onnxruntime_test_all PUBLIC "-Wno-unused-const-variable")
endif()
if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE)
target_compile_definitions(onnxruntime_test_all PUBLIC -DRUN_MODELTEST_IN_DEBUG_MODE)
endif()
if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
endif()
if (onnxruntime_USE_FEATURIZERS)
target_include_directories(onnxruntime_test_all PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/FeaturizersLibrary/src)
endif()
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
target_link_libraries(onnxruntime_test_all PRIVATE onnxruntime_language_interop onnxruntime_pyop)
endif()
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_test_all PRIVATE ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
endif()
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_test_all PRIVATE ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
endif()
set(test_data_target onnxruntime_test_all)
set(test_data_target onnxruntime_test_all)
#
@ -872,7 +868,7 @@ if (onnxruntime_BUILD_SHARED_LIB)
endif()
AddTest(DYN
TARGET onnxruntime_shared_lib_test
SOURCES ${onnxruntime_shared_lib_test_SRC} ${TEST_SRC_DIR}/providers/test_main.cc
SOURCES ${onnxruntime_shared_lib_test_SRC} ${onnxruntime_unittest_main_src}
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)
@ -905,6 +901,24 @@ if (onnxruntime_BUILD_SHARED_LIB)
endif()
endif()
# the debug node IO functionality uses static variables, so it is best tested
# in its own process
if(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
AddTest(
TARGET onnxruntime_test_debug_node_inputs_outputs
SOURCES
"${TEST_SRC_DIR}/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc"
"${TEST_SRC_DIR}/framework/TestAllocatorManager.cc"
"${TEST_SRC_DIR}/providers/provider_test_utils.cc"
${onnxruntime_unittest_main_src}
LIBS ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs}
DEPENDS ${all_dependencies}
)
target_compile_definitions(onnxruntime_test_debug_node_inputs_outputs
PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
endif(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
#some ETW tools
if(WIN32 AND onnxruntime_ENABLE_INSTRUMENT)
add_executable(generate_perf_report_from_etl ${ONNXRUNTIME_ROOT}/tool/etw/main.cc

View file

@ -176,41 +176,30 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
static const NodeDumpOptions node_dump_options = []() {
namespace env_vars = debug_node_inputs_outputs_env_vars;
auto get_bool_env_var = [](const char* env_var) {
const auto val = Env::Default().GetEnvironmentVar(env_var);
if (val.empty()) return false;
std::istringstream s{val};
int i;
ORT_ENFORCE(
s >> i && s.eof(),
"Failed to parse environment variable ", env_var, ": ", val);
return i != 0;
};
NodeDumpOptions opts{};
// Preserve existing behavior of printing the shapes by default. Turn it off only if the user has requested so
// explicitly by setting the value of the env variable to 0.
opts.dump_flags = NodeDumpOptions::DumpFlags::None;
if (ParseEnvironmentVariable<bool>(env_vars::kDumpShapeData, true)) {
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpShapeData, true)) {
opts.dump_flags |= NodeDumpOptions::DumpFlags::Shape;
}
if (get_bool_env_var(env_vars::kDumpInputData)) {
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpInputData, false)) {
opts.dump_flags |= NodeDumpOptions::DumpFlags::InputData;
}
if (get_bool_env_var(env_vars::kDumpOutputData)) {
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpOutputData, false)) {
opts.dump_flags |= NodeDumpOptions::DumpFlags::OutputData;
}
opts.filter.name_pattern = Env::Default().GetEnvironmentVar(env_vars::kNameFilter);
opts.filter.op_type_pattern = Env::Default().GetEnvironmentVar(env_vars::kOpTypeFilter);
if (get_bool_env_var(env_vars::kDumpDataToFiles)) {
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpDataToFiles, false)) {
opts.data_destination = NodeDumpOptions::DataDestination::TensorProtoFiles;
}
if (get_bool_env_var(env_vars::kAppendRankToFileName)) {
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kAppendRankToFileName, false)) {
std::string rank = Env::Default().GetEnvironmentVar("OMPI_COMM_WORLD_RANK");
if (rank.empty()) {
opts.file_suffix = "_default_rank_0";
@ -229,7 +218,7 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
opts.data_destination == NodeDumpOptions::DataDestination::TensorProtoFiles &&
opts.filter.name_pattern.empty() && opts.filter.op_type_pattern.empty()) {
ORT_ENFORCE(
get_bool_env_var(env_vars::kDumpingDataToFilesForAllNodesIsOk),
ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpingDataToFilesForAllNodesIsOk, false),
"The current environment variable configuration will dump node input or output data to files for every node. "
"This may cause a lot of files to be generated. Set the environment variable ",
env_vars::kDumpingDataToFilesForAllNodesIsOk, " to confirm this is what you want.");

View file

@ -33,7 +33,7 @@ optional<T> ParseEnvironmentVariable(const std::string& name) {
* Parses an environment variable value or returns the given default if unavailable.
*/
template <typename T>
T ParseEnvironmentVariable(const std::string& name, const T& default_value) {
T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
const auto parsed = ParseEnvironmentVariable<T>(name);
if (parsed.has_value()) {
return parsed.value();

View file

@ -1,24 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/test_environment.h"
GTEST_API_ int main(int argc, char** argv) {
int status = 0;
ORT_TRY {
const bool create_default_logger = false;
onnxruntime::test::TestEnvironment environment{argc, argv, create_default_logger};
status = RUN_ALL_TESTS();
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
std::cerr << ex.what();
status = -1;
});
}
return status;
}

View file

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
#include "core/framework/debug_node_inputs_outputs_utils.h"
#include <fstream>
@ -41,11 +39,14 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
TemporaryDirectory temp_dir{ORT_TSTR("debug_node_inputs_outputs_utils_test")};
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{env_vars::kDumpInputData, {"1"}},
{env_vars::kDumpOutputData, {"1"}},
{env_vars::kDumpDataToFiles, {"1"}},
{env_vars::kOutputDir, {ToMBString(temp_dir.Path())}},
{env_vars::kDumpingDataToFilesForAllNodesIsOk, {"1"}},
{env_vars::kDumpInputData, "1"},
{env_vars::kDumpOutputData, "1"},
{env_vars::kNameFilter, nullopt},
{env_vars::kOpTypeFilter, nullopt},
{env_vars::kDumpDataToFiles, "1"},
{env_vars::kAppendRankToFileName, nullopt},
{env_vars::kOutputDir, ToMBString(temp_dir.Path())},
{env_vars::kDumpingDataToFilesForAllNodesIsOk, "1"},
}};
OpTester tester{"Round", 11, kOnnxDomain};
@ -56,8 +57,10 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
auto verify_file_data =
[&temp_dir, &input, &output](
const std::vector<OrtValue>& /*fetches*/,
const std::vector<OrtValue>& fetches,
const std::string& /*provider_type*/) {
ASSERT_EQ(fetches.size(), 1u);
FetchTensor(fetches[0]);
VerifyTensorProtoFileData(
temp_dir.Path() + ORT_TSTR("/x.tensorproto"),
gsl::make_span(input));
@ -73,5 +76,3 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
} // namespace test
} // namespace onnxruntime
#endif

View file

@ -1,25 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/session/environment.h"
#include "gtest/gtest.h"
#include "test/test_environment.h"
#include "core/session/onnxruntime_cxx_api.h"
int main(int argc, char** argv) {
int status = 0;
ORT_TRY {
onnxruntime::test::TestEnvironment test_environment{argc, argv};
status = RUN_ALL_TESTS();
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
std::cerr << ex.what();
status = -1;
});
}
return status;
}

View file

@ -1,31 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/test_environment.h"
#include "core/graph/constants.h"
#include "core/graph/op.h"
GTEST_API_ int main(int argc, char** argv) {
int status = 0;
ORT_TRY {
onnxruntime::test::TestEnvironment test_environment{argc, argv};
// Register Microsoft domain with min/max op_set version as 1/1.
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(onnxruntime_ir::kMSDomain, 1, 1);
// Register Microsoft domain ops.
onnxruntime_ir::MsOpRegistry::RegisterMsOps();
status = RUN_ALL_TESTS();
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&ex]() {
std::cerr << ex.what();
status = -1;
});
}
return status;
}

View file

@ -1,13 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <gtest/gtest.h>
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
//TODO: Linker on Mac OS X is kind of strange. The next line of code will trigger a crash
#ifndef __APPLE__
::google::protobuf::ShutdownProtobufLibrary();
#endif
return ret;
}

View file

@ -24,7 +24,7 @@ RandomSeedType GetTestRandomSeed() {
};
static const auto use_cached =
!ParseEnvironmentVariable<bool>(test_random_seed_env_vars::kDoNotCache, false);
!ParseEnvironmentVariableWithDefault<bool>(test_random_seed_env_vars::kDoNotCache, false);
if (use_cached) {
// initially generate from current time
static const auto static_random_seed = generate_from_time();