diff --git a/pytest.ini b/pytest.ini index b7e76081694..e2ab2ebd0cc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -19,3 +19,6 @@ filterwarnings = ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning xfail_strict = True + +markers = + serial: marks tests as needs to be run serially (deselect with '-m "not serial"') diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 786428e6905..1bda97a3fed 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -70,6 +70,7 @@ from torch.testing._internal.common_utils import ( IS_WINDOWS, IS_X86, parametrize, + serialTest, skipIfRocm, subtest, TEST_WITH_ASAN, @@ -9278,6 +9279,7 @@ class CommonTemplate: @config.patch( "triton.autotune_pointwise", True ) # needed to introduce config that exceed max shared memory usage + @serialTest() def test_large_block_sizes(self): """ Inductor will try triton configs like x = 64 and y = 1024 which will diff --git a/test/run_test.py b/test/run_test.py index 8952c78d6bf..cafa60bd1c8 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -246,9 +246,6 @@ CI_SERIAL_LIST = [ "test_module_hooks", # OOM "inductor/test_max_autotune", "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps - "inductor/test_torchinductor", # OOM on test_large_block_sizes - "inductor/test_torchinductor_dynamic_shapes", # OOM on test_large_block_sizes - "inductor/test_torchinductor_codegen_dynamic_shapes", # OOM on test_large_block_sizes "test_profiler", # test_source_multithreaded is probably not compatible with parallelism ] # A subset of onnx tests that cannot run in parallel due to high memory usage. @@ -1591,6 +1588,11 @@ def run_tests( ): pool.terminate() + keep_going_message = ( + "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n" + "If running on CI, add the 'keep-going' label to your PR and rerun your jobs." + ) + try: for test in selected_tests_serial: options_clone = copy.deepcopy(options) @@ -1603,19 +1605,29 @@ def run_tests( and not options.continue_through_error and not RERUN_DISABLED_TESTS ): - raise RuntimeError( - failure.message - + "\n\nTip: You can keep running tests even on failure by " - "passing --keep-going to run_test.py.\n" - "If running on CI, add the 'keep-going' label to " - "your PR and rerun your jobs." - ) + raise RuntimeError(failure.message + keep_going_message) + + # Run tests marked as serial first + for test in selected_tests_parallel: + options_clone = copy.deepcopy(options) + if can_run_in_pytest(test): + options_clone.pytest = True + options_clone.additional_unittest_args.extend(["-m", "serial"]) + failure = run_test_module(test, test_directory, options_clone) + test_failed = handle_error_messages(failure) + if ( + test_failed + and not options.continue_through_error + and not RERUN_DISABLED_TESTS + ): + raise RuntimeError(failure.message + keep_going_message) os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) for test in selected_tests_parallel: options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True + options_clone.additional_unittest_args.extend(["-m", "not serial"]) pool.apply_async( run_test_module, args=(test, test_directory, options_clone), @@ -1718,6 +1730,7 @@ def main(): if IS_CI: gen_ci_artifact([x.to_json() for x in include], [x.to_json() for x in exclude]) + print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes") print_to_stderr(test_batch) print_to_stderr(test_batch_exclude) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 776a9d2d3ea..55273870806 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -97,6 +97,11 @@ from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree from .composite_compliance import no_dispatch +try: + import pytest + has_pytest = True +except ImportError: + has_pytest = False # Class to keep track of test flags configurable by environment variables. @@ -1384,6 +1389,15 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor", return decorator +def serialTest(condition=True): + """ + Decorator for running tests serially. Requires pytest + """ + def decorator(fn): + if has_pytest and condition: + return pytest.mark.serial(fn) + return fn + return decorator def unMarkDynamoStrictTest(cls=None): def decorator(cls):