Add option to run python unittests in parallel (#37180)

Summary:
So far the results look quite promising: test_nn consists of purely sequential tests and can be accelerated 3x.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37180

Differential Revision: D21437871

Pulled By: malfet

fbshipit-source-id: 8679a8af355f839f2c9dae3bf36d2e102af05425
Author: Nikita Shulga
Date:   2020-05-06 22:08:45 -07:00
Committed by: Facebook GitHub Bot
Parent: 681c6fb60f
Commit: 72e5b7ae5b
2 files changed, 74 insertions(+), 25 deletions(-)
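As a usage sketch (module name hypothetical), the new flags added below can be exercised directly on a test script:

    # Print the discovered test names, one per line, without running them:
    python test_nn.py --discover-tests
    # Deal the module's tests out round-robin across 4 shard processes:
    python test_nn.py --run-parallel 4

`--log-suffix` is mainly passed to each shard internally, so that per-shard XML test reports don't collide.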

diff --git a/test/run_test.py b/test/run_test.py
--- a/test/run_test.py
+++ b/test/run_test.py

@@ -110,6 +110,21 @@ ROCM_BLACKLIST = [
     'test_type_hints',
 ]

+RUN_PARALLEL_BLACKLIST = [
+    'test_cpp_extensions_jit',
+    'test_docs_coverage',
+    'test_expecttest',
+    'test_jit_disabled',
+    'test_mobile_optimizer',
+    'test_multiprocessing',
+    'test_multiprocessing_spawn',
+    'test_namedtuple_return_api',
+    'test_overrides',
+    'test_show_pickle',
+    'test_tensorexpr',
+    'test_cuda_primary_ctx',
+] + [test for test in TESTS if test.startswith('distributed/')]
+
 # These tests are slow enough that it's worth calculating whether the patch
 # touched any related files first.
 SLOW_TESTS = [
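(Presumably the blacklisted modules either spawn their own subprocesses, as the multiprocessing and distributed tests do, or depend on in-process ordering and global state that cannot safely be split across shard processes; the diff itself does not record the reason.)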
@@ -186,6 +201,8 @@ def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
     unittest_args = options.additional_unittest_args
     if options.verbose:
         unittest_args.append('--verbose')
+    if test_module in RUN_PARALLEL_BLACKLIST:
+        unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')]
     # Can't call `python -m unittest test_*` here because it doesn't run code
     # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
     argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
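To illustrate the stripping above with hypothetical values, for a blacklisted module any `--run-parallel` argument is removed before the test script is invoked:

    unittest_args = ['--verbose', '--run-parallel=4']   # hypothetical input
    unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')]
    assert unittest_args == ['--verbose']               # the module now runs sequentially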

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py

@@ -122,6 +122,9 @@ parser.add_argument('--test_bailouts', action='store_true')
 parser.add_argument('--save-xml', nargs='?', type=str,
                     const=_get_test_report_path(),
                     default=_get_test_report_path() if bool(os.environ.get('IN_CIRCLECI')) else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)

 GRAPH_EXECUTOR = ProfilingMode.SIMPLE if IS_SANDCASTLE else ProfilingMode.PROFILING
 args, remaining = parser.parse_known_args()
@@ -132,7 +135,10 @@ elif args.ge_config == 'profiling':
 else:
     GRAPH_EXECUTOR = ProfilingMode.SIMPLE

+LOG_SUFFIX = args.log_suffix
+RUN_PARALLEL = args.run_parallel
 TEST_BAILOUTS = args.test_bailouts
+TEST_DISCOVER = args.discover_tests
 TEST_IN_SUBPROCESS = args.subprocess
 TEST_SAVE_XML = args.save_xml
 REPEAT_COUNT = args.repeat
@@ -142,19 +148,7 @@ if not expecttest.ACCEPT:
 UNITTEST_ARGS = [sys.argv[0]] + remaining
 torch.manual_seed(SEED)

-def shell(command, cwd=None, env=None):
-    sys.stdout.flush()
-    sys.stderr.flush()
-    # The following cool snippet is copied from the Py3 core library's subprocess.call,
-    # with only these changes:
-    #  1. an `except KeyboardInterrupt` block added for SIGINT handling;
-    #  2. in Py2, subprocess.Popen doesn't return a context manager, so we do
-    #     `p.wait()` in a `finally` block to keep the code portable.
-    #
-    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
-    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
-    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
+def wait_for_process(p):
     try:
         return p.wait()
     except KeyboardInterrupt:
@@ -173,6 +167,20 @@ def shell(command, cwd=None, env=None):
         # Always call p.wait() to ensure exit
         p.wait()

+def shell(command, cwd=None, env=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following cool snippet is copied from the Py3 core library's subprocess.call,
+    # with only these changes:
+    #  1. an `except KeyboardInterrupt` block added for SIGINT handling;
+    #  2. in Py2, subprocess.Popen doesn't return a context manager, so we do
+    #     `p.wait()` in a `finally` block to keep the code portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
+    return wait_for_process(p)
+
 # Used to run the same test with different tensor types
 def repeat_test_for_types(dtypes):
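Note the refactoring above: the SIGINT-safe waiting logic is hoisted out of `shell` into `wait_for_process`, so the parallel path in `run_tests` below can reuse it to wait on each shard's `subprocess.Popen` handle.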
@@ -195,19 +203,31 @@ IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI'))
 PY3 = sys.version_info > (3, 0)
 PY34 = sys.version_info >= (3, 4)

+def discover_test_cases_recursively(suite_or_case):
+    if isinstance(suite_or_case, unittest.TestCase):
+        return [suite_or_case]
+    rc = []
+    for element in suite_or_case:
+        rc.extend(discover_test_cases_recursively(element))
+    return rc
+
+def get_test_names(test_cases):
+    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+
+def chunk_list(lst, nchunks):
+    return [lst[i::nchunks] for i in range(nchunks)]
+
 def run_tests(argv=UNITTEST_ARGS):
-    if TEST_IN_SUBPROCESS:
+    if TEST_DISCOVER:
         suite = unittest.TestLoader().loadTestsFromModule(__main__)
-        test_cases = []
-
-        def add_to_test_cases(suite_or_case):
-            if isinstance(suite_or_case, unittest.TestCase):
-                test_cases.append(suite_or_case)
-            else:
-                for element in suite_or_case:
-                    add_to_test_cases(element)
-
-        add_to_test_cases(suite)
+        test_cases = discover_test_cases_recursively(suite)
+        for name in get_test_names(test_cases):
+            print(name)
+    elif TEST_IN_SUBPROCESS:
+        suite = unittest.TestLoader().loadTestsFromModule(__main__)
+        test_cases = discover_test_cases_recursively(suite)
         failed_tests = []
         for case in test_cases:
             test_case_full_name = case.id().split('.', 1)[1]
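A worked example of the helpers above (test names hypothetical): `get_test_names` keeps the last two components of a test id, and `chunk_list` deals tests out round-robin so the shards stay roughly balanced:

    # case.id() is like 'test_nn.TestNN.test_relu' -> get_test_names keeps 'TestNN.test_relu'
    names = ['TestNN.test_a', 'TestNN.test_b', 'TestNN.test_c',
             'TestNN.test_d', 'TestNN.test_e']
    shards = [names[i::2] for i in range(2)]  # equivalent to chunk_list(names, 2)
    assert shards == [['TestNN.test_a', 'TestNN.test_c', 'TestNN.test_e'],
                      ['TestNN.test_b', 'TestNN.test_d']]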
@@ -217,10 +237,22 @@ def run_tests(argv=UNITTEST_ARGS):
         assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
             len(failed_tests), '\n\t'.join(failed_tests))
+    elif RUN_PARALLEL > 1:
+        suite = unittest.TestLoader().loadTestsFromModule(__main__)
+        test_cases = discover_test_cases_recursively(suite)
+        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
+        processes = []
+        for i in range(RUN_PARALLEL):
+            command = [sys.executable] + argv + ['--log-suffix=-shard-{}'.format(i + 1)] + test_batches[i]
+            processes.append(subprocess.Popen(command, universal_newlines=True))
+        failed = False
+        for p in processes:
+            failed |= wait_for_process(p) != 0
+        assert not failed, "Some test shards have failed"
     elif TEST_SAVE_XML is not None:
         # import here so that non-CI doesn't need xmlrunner installed
         import xmlrunner
-        test_report_path = TEST_SAVE_XML
+        test_report_path = TEST_SAVE_XML + LOG_SUFFIX
         os.makedirs(test_report_path, exist_ok=True)
         verbose = '--verbose' in argv or '-v' in argv
         if verbose:
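Putting it together, a sketch of what the `RUN_PARALLEL` branch spawns (module and test names hypothetical):

    python test_nn.py --run-parallel 2
      -> python test_nn.py --log-suffix=-shard-1 TestNN.test_a TestNN.test_c TestNN.test_e
      -> python test_nn.py --log-suffix=-shard-2 TestNN.test_b TestNN.test_d

Because `--run-parallel` is a known argument consumed by `parse_known_args`, it does not appear in `remaining` and hence not in the shard commands, so each child has `RUN_PARALLEL == 1` and falls through to a normal sequential run of just its batch of test names.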