mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add option to run python unittests in parallel (#37180)
Summary: So far the results look quite promising: test_nn consists of purely sequential tests and can be accelerated 3x. Pull Request resolved: https://github.com/pytorch/pytorch/pull/37180 Differential Revision: D21437871 Pulled By: malfet fbshipit-source-id: 8679a8af355f839f2c9dae3bf36d2e102af05425
This commit is contained in:
parent
681c6fb60f
commit
72e5b7ae5b
2 changed files with 74 additions and 25 deletions
|
|
@ -110,6 +110,21 @@ ROCM_BLACKLIST = [
|
|||
'test_type_hints',
|
||||
]
|
||||
|
||||
# Test modules that are never run with `--run-parallel`: run_test() removes
# every argument starting with `--run-parallel` before launching them.
RUN_PARALLEL_BLACKLIST = [
    'test_cpp_extensions_jit',
    'test_docs_coverage',
    'test_expecttest',
    'test_jit_disabled',
    'test_mobile_optimizer',
    'test_multiprocessing',
    'test_multiprocessing_spawn',
    'test_namedtuple_return_api',
    'test_overrides',
    'test_show_pickle',
    'test_tensorexpr',
    'test_cuda_primary_ctx',
]
# Every test living under `distributed/` is excluded as well.
RUN_PARALLEL_BLACKLIST.extend(
    name for name in TESTS if name.startswith('distributed/')
)
|
||||
|
||||
# These tests are slow enough that it's worth calculating whether the patch
|
||||
# touched any related files first.
|
||||
SLOW_TESTS = [
|
||||
|
|
@ -186,6 +201,8 @@ def run_test(executable, test_module, test_directory, options, *extra_unittest_a
|
|||
unittest_args = options.additional_unittest_args
|
||||
if options.verbose:
|
||||
unittest_args.append('--verbose')
|
||||
if test_module in RUN_PARALLEL_BLACKLIST:
|
||||
unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')]
|
||||
# Can't call `python -m unittest test_*` here because it doesn't run code
|
||||
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
|
||||
argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
|
||||
|
|
|
|||
|
|
@ -122,6 +122,9 @@ parser.add_argument('--test_bailouts', action='store_true')
|
|||
parser.add_argument('--save-xml', nargs='?', type=str,
|
||||
const=_get_test_report_path(),
|
||||
default=_get_test_report_path() if bool(os.environ.get('IN_CIRCLECI')) else None)
|
||||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE if IS_SANDCASTLE else ProfilingMode.PROFILING
|
||||
args, remaining = parser.parse_known_args()
|
||||
|
|
@ -132,7 +135,10 @@ elif args.ge_config == 'profiling':
|
|||
else:
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
|
||||
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
TEST_DISCOVER = args.discover_tests
|
||||
TEST_IN_SUBPROCESS = args.subprocess
|
||||
TEST_SAVE_XML = args.save_xml
|
||||
REPEAT_COUNT = args.repeat
|
||||
|
|
@ -142,19 +148,7 @@ if not expecttest.ACCEPT:
|
|||
UNITTEST_ARGS = [sys.argv[0]] + remaining
|
||||
torch.manual_seed(SEED)
|
||||
|
||||
|
||||
def shell(command, cwd=None, env=None):
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
# The following cool snippet is copied from Py3 core library subprocess.call
|
||||
# only the with
|
||||
# 1. `except KeyboardInterrupt` block added for SIGINT handling.
|
||||
# 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
|
||||
# `p.wait()` in a `final` block for the code to be portable.
|
||||
#
|
||||
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
|
||||
assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
|
||||
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
|
||||
def wait_for_process(p):
|
||||
try:
|
||||
return p.wait()
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -173,6 +167,20 @@ def shell(command, cwd=None, env=None):
|
|||
# Always call p.wait() to ensure exit
|
||||
p.wait()
|
||||
|
||||
def shell(command, cwd=None, env=None):
    """Launch `command` as a child process and wait for it to finish.

    Args:
        command: list or tuple of argument tokens; a plain string is rejected.
        cwd: working directory handed to `subprocess.Popen`.
        env: environment mapping handed to `subprocess.Popen`.

    Returns:
        Whatever `wait_for_process` returns for the spawned process.
    """
    # Flush our own buffered output before the child starts.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

    command_is_string = isinstance(command, torch._six.string_classes)
    assert not command_is_string, "Command to shell should be a list or tuple of tokens"

    # The waiting logic lives in `wait_for_process`; per the original note it
    # is modelled on the Py3 core library's subprocess.call, with
    # KeyboardInterrupt handling added:
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    child = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
    return wait_for_process(child)
|
||||
|
||||
|
||||
# Used to run the same test with different tensor types
|
||||
def repeat_test_for_types(dtypes):
|
||||
|
|
@ -195,19 +203,31 @@ IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI'))
|
|||
PY3 = sys.version_info > (3, 0)
|
||||
PY34 = sys.version_info >= (3, 4)
|
||||
|
||||
def discover_test_cases_recursively(suite_or_case):
    """Flatten a (possibly nested) test suite into a list of test cases.

    Args:
        suite_or_case: a `unittest.TestCase`, or an iterable whose elements
            are themselves test cases or further iterables of them.

    Returns:
        A list of the `unittest.TestCase` instances found, in depth-first
        iteration order. A single test case yields a one-element list.
    """
    if not isinstance(suite_or_case, unittest.TestCase):
        # Anything that is not a test case is iterated and flattened.
        return [
            case
            for member in suite_or_case
            for case in discover_test_cases_recursively(member)
        ]
    return [suite_or_case]
|
||||
|
||||
def get_test_names(test_cases):
    """Return a short name for every test case.

    Each name is built from the last two dot-separated components of the
    case's `id()`; ids with fewer than two components are returned whole.
    """
    names = []
    for test in test_cases:
        id_parts = test.id().split('.')
        names.append('.'.join(id_parts[-2:]))
    return names
|
||||
|
||||
def chunk_list(lst, nchunks):
    """Split `lst` into `nchunks` interleaved slices.

    Chunk number `k` holds the elements at positions k, k + nchunks,
    k + 2 * nchunks, ... so chunk sizes differ by at most one. When `lst` is
    shorter than `nchunks`, the trailing chunks are empty.
    """
    chunks = []
    for offset in range(nchunks):
        chunks.append(lst[offset::nchunks])
    return chunks
|
||||
|
||||
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
if TEST_IN_SUBPROCESS:
|
||||
if TEST_DISCOVER:
|
||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||
test_cases = []
|
||||
|
||||
def add_to_test_cases(suite_or_case):
|
||||
if isinstance(suite_or_case, unittest.TestCase):
|
||||
test_cases.append(suite_or_case)
|
||||
else:
|
||||
for element in suite_or_case:
|
||||
add_to_test_cases(element)
|
||||
|
||||
add_to_test_cases(suite)
|
||||
test_cases = discover_test_cases_recursively(suite)
|
||||
for name in get_test_names(test_cases):
|
||||
print(name)
|
||||
elif TEST_IN_SUBPROCESS:
|
||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||
test_cases = discover_test_cases_recursively(suite)
|
||||
failed_tests = []
|
||||
for case in test_cases:
|
||||
test_case_full_name = case.id().split('.', 1)[1]
|
||||
|
|
@ -217,10 +237,22 @@ def run_tests(argv=UNITTEST_ARGS):
|
|||
|
||||
assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
|
||||
len(failed_tests), '\n\t'.join(failed_tests))
|
||||
elif RUN_PARALLEL > 1:
|
||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||
test_cases = discover_test_cases_recursively(suite)
|
||||
test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
|
||||
processes = []
|
||||
for i in range(RUN_PARALLEL):
|
||||
command = [sys.executable] + argv + ['--log-suffix=-shard-{}'.format(i + 1)] + test_batches[i]
|
||||
processes.append(subprocess.Popen(command, universal_newlines=True))
|
||||
failed = False
|
||||
for p in processes:
|
||||
failed |= wait_for_process(p) != 0
|
||||
assert not failed, "Some test shards have failed"
|
||||
elif TEST_SAVE_XML is not None:
|
||||
# import here so that non-CI doesn't need xmlrunner installed
|
||||
import xmlrunner
|
||||
test_report_path = TEST_SAVE_XML
|
||||
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
||||
os.makedirs(test_report_path, exist_ok=True)
|
||||
verbose = '--verbose' in argv or '-v' in argv
|
||||
if verbose:
|
||||
|
|
|
|||
Loading…
Reference in a new issue