From 72e5b7ae5b3abfcc672ec2677e6238baa242bcd6 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Wed, 6 May 2020 22:08:45 -0700
Subject: [PATCH] Add option to run python unittests in parallel (#37180)

Summary:
So far the results look quite promising: test_nn consists purely of sequential tests and can be accelerated 3x
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37180

Differential Revision: D21437871

Pulled By: malfet

fbshipit-source-id: 8679a8af355f839f2c9dae3bf36d2e102af05425
---
Review notes (free-form text placed here, before the diffstat, is ignored by `git am`):
  * Restored one record per line; the patch had been collapsed onto four
    physical lines and could not be parsed by `git am` / `git apply`.
  * run_tests(): empty shards are now dropped. With --run-parallel larger
    than the number of discovered tests, chunk_list() yields empty batches,
    and a child started without test names is expected to run every test.
  * Fixed comment typos in the re-added shell() ("only the with", "`final`")
    and the grammar of the commit summary.
  * Added-line counts per hunk are unchanged, so the hunk headers and the
    diffstat still hold. The "index" blob ids are the upstream ones and no
    longer describe the edited post-image exactly.

 test/run_test.py                        | 17 +++++
 torch/testing/_internal/common_utils.py | 82 +++++++++++++++++--------
 2 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/test/run_test.py b/test/run_test.py
index d6fa2fc83f3..10da6e8fa1e 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -110,6 +110,21 @@ ROCM_BLACKLIST = [
     'test_type_hints',
 ]
 
+RUN_PARALLEL_BLACKLIST = [
+    'test_cpp_extensions_jit',
+    'test_docs_coverage',
+    'test_expecttest',
+    'test_jit_disabled',
+    'test_mobile_optimizer',
+    'test_multiprocessing',
+    'test_multiprocessing_spawn',
+    'test_namedtuple_return_api',
+    'test_overrides',
+    'test_show_pickle',
+    'test_tensorexpr',
+    'test_cuda_primary_ctx',
+] + [test for test in TESTS if test.startswith('distributed/')]
+
 # These tests are slow enough that it's worth calculating whether the patch
 # touched any related files first.
 SLOW_TESTS = [
@@ -186,6 +201,8 @@ def run_test(executable, test_module, test_directory, options, *extra_unittest_a
     unittest_args = options.additional_unittest_args
     if options.verbose:
         unittest_args.append('--verbose')
+    if test_module in RUN_PARALLEL_BLACKLIST:
+        unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')]
     # Can't call `python -m unittest test_*` here because it doesn't run code
     # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
     argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 9b56daa719f..eaf2eb54bc8 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -122,6 +122,9 @@ parser.add_argument('--test_bailouts', action='store_true')
 parser.add_argument('--save-xml', nargs='?', type=str,
                     const=_get_test_report_path(),
                     default=_get_test_report_path() if bool(os.environ.get('IN_CIRCLECI')) else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)
 
 GRAPH_EXECUTOR = ProfilingMode.SIMPLE if IS_SANDCASTLE else ProfilingMode.PROFILING
 args, remaining = parser.parse_known_args()
@@ -132,7 +135,10 @@ elif args.ge_config == 'profiling':
 else:
     GRAPH_EXECUTOR = ProfilingMode.SIMPLE
 
+LOG_SUFFIX = args.log_suffix
+RUN_PARALLEL = args.run_parallel
 TEST_BAILOUTS = args.test_bailouts
+TEST_DISCOVER = args.discover_tests
 TEST_IN_SUBPROCESS = args.subprocess
 TEST_SAVE_XML = args.save_xml
 REPEAT_COUNT = args.repeat
@@ -142,19 +148,7 @@ if not expecttest.ACCEPT:
 UNITTEST_ARGS = [sys.argv[0]] + remaining
 torch.manual_seed(SEED)
 
-
-def shell(command, cwd=None, env=None):
-    sys.stdout.flush()
-    sys.stderr.flush()
-    # The following cool snippet is copied from Py3 core library subprocess.call
-    # only the with
-    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
-    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
-    #      `p.wait()` in a `final` block for the code to be portable.
-    #
-    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
-    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
-    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
+def wait_for_process(p):
     try:
         return p.wait()
     except KeyboardInterrupt:
@@ -173,6 +167,20 @@ def shell(command, cwd=None, env=None):
         # Always call p.wait() to ensure exit
         p.wait()
 
+def shell(command, cwd=None, env=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following snippet is copied from the Py3 core library's
+    # subprocess.call, with only these changes:
+    #   1. `except KeyboardInterrupt` block added for SIGINT handling.
+    #   2. In Py2, subprocess.Popen doesn't return a context manager, so we do
+    #      `p.wait()` in a `finally` block for the code to be portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
+    return wait_for_process(p)
+
 
 # Used to run the same test with different tensor types
 def repeat_test_for_types(dtypes):
@@ -195,19 +203,31 @@ IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI'))
 PY3 = sys.version_info > (3, 0)
 PY34 = sys.version_info >= (3, 4)
 
+def discover_test_cases_recursively(suite_or_case):
+    if isinstance(suite_or_case, unittest.TestCase):
+        return [suite_or_case]
+    rc = []
+    for element in suite_or_case:
+        rc.extend(discover_test_cases_recursively(element))
+    return rc
+
+def get_test_names(test_cases):
+    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+
+def chunk_list(lst, nchunks):
+    return [lst[i::nchunks] for i in range(nchunks)]
+
+
+
 def run_tests(argv=UNITTEST_ARGS):
-    if TEST_IN_SUBPROCESS:
+    if TEST_DISCOVER:
         suite = unittest.TestLoader().loadTestsFromModule(__main__)
-        test_cases = []
-
-        def add_to_test_cases(suite_or_case):
-            if isinstance(suite_or_case, unittest.TestCase):
-                test_cases.append(suite_or_case)
-            else:
-                for element in suite_or_case:
-                    add_to_test_cases(element)
-
-        add_to_test_cases(suite)
+        test_cases = discover_test_cases_recursively(suite)
+        for name in get_test_names(test_cases):
+            print(name)
+    elif TEST_IN_SUBPROCESS:
+        suite = unittest.TestLoader().loadTestsFromModule(__main__)
+        test_cases = discover_test_cases_recursively(suite)
         failed_tests = []
         for case in test_cases:
             test_case_full_name = case.id().split('.', 1)[1]
@@ -217,10 +237,22 @@ def run_tests(argv=UNITTEST_ARGS):
 
         assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
             len(failed_tests), '\n\t'.join(failed_tests))
+    elif RUN_PARALLEL > 1:
+        suite = unittest.TestLoader().loadTestsFromModule(__main__)
+        test_cases = discover_test_cases_recursively(suite)
+        # Drop empty batches (fewer tests than shards): a child given no test
+        # names is expected to fall back to running every test in the module.
+        test_batches = [batch for batch in chunk_list(get_test_names(test_cases), RUN_PARALLEL) if batch]
+        processes = []
+        for i, batch in enumerate(test_batches):
+            command = [sys.executable] + argv + ['--log-suffix=-shard-{}'.format(i + 1)] + batch
+            processes.append(subprocess.Popen(command, universal_newlines=True))
+        exit_codes = [wait_for_process(p) for p in processes]
+        assert not any(exit_codes), "Some test shards have failed"
     elif TEST_SAVE_XML is not None:
         # import here so that non-CI doesn't need xmlrunner installed
         import xmlrunner
-        test_report_path = TEST_SAVE_XML
+        test_report_path = TEST_SAVE_XML + LOG_SUFFIX
         os.makedirs(test_report_path, exist_ok=True)
         verbose = '--verbose' in argv or '-v' in argv
         if verbose: