diff --git a/test/run_test.py b/test/run_test.py
index e6daef1d3ff..8f2b22ac215 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -305,6 +305,20 @@ CORE_TEST_LIST = [
     "test_torch"
 ]
 
+# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
+# among all available shards to speed up the test. The list of backends is copied from the tests themselves
+DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS = {
+    "distributed/test_distributed_spawn": [
+        "gloo",
+        "nccl",
+        "ucc",
+    ],
+    "distributed/algorithms/quantization/test_quantization": [
+        "gloo",
+        "nccl",
+    ],
+}
+
 # if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
 SLOW_TEST_THRESHOLD = 300
 
@@ -520,12 +534,26 @@ def test_distributed(test_module, test_directory, options):
     ) == 0 and sys.version_info < (3, 9)
     if options.verbose and not mpi_available:
         print_to_stderr("MPI not available -- MPI backend tests will be skipped")
+
+    if options.shard:
+        which_shard, num_shards = options.shard
+    else:
+        which_shard = num_shards = 1
+    # Round-robin all backends to different shards; unlisted test modules get an empty map and run on shard 1
+    backend_to_shard = {backend: i % num_shards + 1
+                        for i, backend in enumerate(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.get(test_module, []))}
+    print_to_stderr(f"Map different backends to different shards for {test_module}: {backend_to_shard}")
+
     config = DISTRIBUTED_TESTS_CONFIG
     for backend, env_vars in config.items():
         if sys.platform == "win32" and backend != "gloo":
             continue
         if backend == "mpi" and not mpi_available:
             continue
+        # Default to the first shard if seeing an unrecognized backend
+        if which_shard != backend_to_shard.get(backend, 1):
+            print_to_stderr(f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}")
+            continue
         for with_init_file in {True, False}:
             if sys.platform == "win32" and not with_init_file:
                 continue
@@ -534,8 +562,8 @@ def test_distributed(test_module, test_directory, options):
             init_str = "with {} init_method"
             with_init = init_str.format("file" if with_init_file else "env")
             print_to_stderr(
-                "Running distributed tests for the {} backend {}".format(
-                    backend, with_init
+                "Running distributed tests for the {} backend {} in shard {} of {}".format(
+                    backend, with_init, which_shard, num_shards
                 )
             )
             old_environ = dict(os.environ)
@@ -978,7 +1006,9 @@ def get_selected_tests(options):
 
     if options.distributed_tests:
         selected_tests = list(
-            filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
+            filter(lambda test_name: (test_name in DISTRIBUTED_TESTS and
+                   test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS),
+                   selected_tests)
         )
 
     # Filter to only run core tests when --core option is specified
@@ -1088,6 +1118,11 @@ def get_selected_tests(options):
     selected_tests = exclude_tests(TESTS_REQUIRING_LAPACK, selected_tests,
                                    "PyTorch is built without LAPACK support.")
 
+    if options.distributed_tests:
+        # Run distributed tests with multiple backends across all shards, one per backend
+        selected_tests.extend(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.keys())
+        selected_tests.reverse()
+
     return selected_tests
 
 