mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Improve logic for S3 stats gathering. Uses automatic SLOW_TESTS. (#53549)
Summary: This PR: 1. refactors the logic for S3 stats gathering. 2. Renames SLOW_TESTS to TARGET_DET_LIST to disambiguate and removes confusion with slowTest. 3. Detects slow tests (tests with time > 5min) to add to the TARGET_DET_LIST based on results in S3 from the previous nightly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/53549 Test Plan: Set CIRCLE_JOB to your favorite CI job (like `pytorch_linux_bionic_py3_8_gcc9_coverage_test1`). Run `python test/run_test.py --determine-from=<your fave pytorch files>` e.g., `python test/run_test.py --determine-from=test/run_test.py` Reviewed By: mrshenli Differential Revision: D26904478 Pulled By: janeyx99 fbshipit-source-id: 9576b34f4fee09291d60e36ff2631753a3925094
This commit is contained in:
parent
1c9fc38eb2
commit
bcbe07200c
2 changed files with 88 additions and 40 deletions
126
test/run_test.py
126
test/run_test.py
|
|
@ -22,6 +22,7 @@ from typing import Dict, Optional, Tuple, List, Any
|
|||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
import botocore # type: ignore[import]
|
||||
import botocore.exceptions # type: ignore[import]
|
||||
HAVE_BOTO3 = True
|
||||
except ImportError:
|
||||
HAVE_BOTO3 = False
|
||||
|
|
@ -235,8 +236,10 @@ WINDOWS_COVERAGE_BLOCKLIST = [
|
|||
|
||||
|
||||
# These tests are slow enough that it's worth calculating whether the patch
|
||||
# touched any related files first.
|
||||
SLOW_TESTS = [
|
||||
# touched any related files first. This list was manually generated, but for every
|
||||
# run with --determine-from, we use another generated list based on this one and the
|
||||
# previous test stats.
|
||||
TARGET_DET_LIST = [
|
||||
'distributions/test_distributions',
|
||||
'test_nn',
|
||||
'test_autograd',
|
||||
|
|
@ -300,6 +303,10 @@ SLOW_TESTS = [
|
|||
'distributed/pipeline/sync/test_transparency',
|
||||
'distributed/pipeline/sync/test_worker',
|
||||
]
|
||||
|
||||
# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
|
||||
SLOW_TEST_THRESHOLD = 300
|
||||
|
||||
_DEP_MODULES_CACHE: Dict[str, set] = {}
|
||||
|
||||
DISTRIBUTED_TESTS_CONFIG = {}
|
||||
|
|
@ -349,34 +356,42 @@ def print_to_stderr(message):
|
|||
print(message, file=sys.stderr)
|
||||
|
||||
|
||||
# This function returns a list of S3 test time reports. This function can run into errors if HAVE_BOTO3 = False
|
||||
# or the S3 bucket is somehow unavailable. Even though this function goes through ten nightly commits' reports
|
||||
# to find a non-empty report, it is still conceivable (though highly unlikely) for this function to return no reports.
|
||||
def get_test_time_reports_from_S3() -> List[Dict[str, Any]]:
|
||||
try:
|
||||
commit_date_ts = subprocess.check_output(
|
||||
['git', 'show', '-s', '--format=%ct', 'HEAD'],
|
||||
encoding="ascii").strip()
|
||||
commit_date = datetime.fromtimestamp(int(commit_date_ts))
|
||||
day_before_commit = str(commit_date - timedelta(days=1)).split(' ')[0]
|
||||
# something like git rev-list --before="2021-03-04" --max-count=1 --remotes="*origin/nightly"
|
||||
nightly_commit = subprocess.check_output(
|
||||
["git", "rev-list", f"--before={day_before_commit}", "--max-count=1", "--remotes=*origin/nightly"],
|
||||
encoding="ascii").strip()
|
||||
print(f'Using nightly commit: {nightly_commit}')
|
||||
commit_date_ts = subprocess.check_output(
|
||||
['git', 'show', '-s', '--format=%ct', 'HEAD'],
|
||||
encoding="ascii").strip()
|
||||
commit_date = datetime.fromtimestamp(int(commit_date_ts))
|
||||
day_before_commit = str(commit_date - timedelta(days=1)).split(' ')[0]
|
||||
# something like git rev-list --before="2021-03-04" --max-count=10 --remotes="*origin/nightly"
|
||||
nightly_commits = subprocess.check_output(
|
||||
["git", "rev-list", f"--before={day_before_commit}", "--max-count=10", "--remotes=*origin/nightly"],
|
||||
encoding="ascii").splitlines()
|
||||
|
||||
job = os.environ.get("CIRCLE_JOB", "")
|
||||
job_minus_shard_number = job.rstrip('0123456789')
|
||||
job = os.environ.get("CIRCLE_JOB", "")
|
||||
job_minus_shard_number = job.rstrip('0123456789')
|
||||
|
||||
try:
|
||||
s3 = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED))
|
||||
bucket = s3.Bucket(name="ossci-metrics")
|
||||
summaries = bucket.objects.filter(Prefix=f"test_time/{nightly_commit}/{job_minus_shard_number}")
|
||||
except (RuntimeError, RuntimeWarning):
|
||||
print('Failed to read from S3. Proceeding with no reports.')
|
||||
return []
|
||||
|
||||
reports = []
|
||||
for summary in summaries:
|
||||
binary = summary.get()["Body"].read()
|
||||
string = bz2.decompress(binary).decode("utf-8")
|
||||
reports.append(json.loads(string))
|
||||
return reports
|
||||
reports = []
|
||||
commit_index = 0
|
||||
while len(reports) == 0 and commit_index < len(nightly_commits):
|
||||
nightly_commit = nightly_commits[commit_index]
|
||||
print(f'Grabbing reports from nightly commit: {nightly_commit}')
|
||||
summaries = bucket.objects.filter(Prefix=f"test_time/{nightly_commit}/{job_minus_shard_number}")
|
||||
for summary in summaries:
|
||||
binary = summary.get()["Body"].read()
|
||||
string = bz2.decompress(binary).decode("utf-8")
|
||||
reports.append(json.loads(string))
|
||||
commit_index += 1
|
||||
return reports
|
||||
except botocore.exceptions.ClientError as err:
|
||||
print('Error Message: {}'.format(err.response['Error']['Message']))
|
||||
return []
|
||||
|
||||
|
||||
def calculate_job_times(reports: List[Dict[str, Any]]) -> Dict[str, Tuple[float, int]]:
|
||||
|
|
@ -393,17 +408,18 @@ def calculate_job_times(reports: List[Dict[str, Any]]) -> Dict[str, Tuple[float,
|
|||
new_count = curr_count + 1
|
||||
new_avg = (curr_avg * curr_count + test_file['total_seconds']) / new_count
|
||||
jobs_to_times[name] = (new_avg, new_count)
|
||||
|
||||
# if there's 'test_cpp_extensions_aot' entry in jobs_to_times, add 'test_cpp_extensions_aot_ninja'
|
||||
# and 'test_cpp_extensions_aot_no_ninja' duplicate entries to ease future computation since
|
||||
# test_cpp_extensions_aot_no_ninja and test_cpp_extensions_aot_ninja are Python test jobs that
|
||||
# both use the test_cpp_extensions_aot.py file.
|
||||
if 'test_cpp_extensions_aot' in jobs_to_times:
|
||||
jobs_to_times['test_cpp_extensions_aot_ninja'] = jobs_to_times['test_cpp_extensions_aot']
|
||||
jobs_to_times['test_cpp_extensions_aot_no_ninja'] = jobs_to_times['test_cpp_extensions_aot']
|
||||
return jobs_to_times
|
||||
|
||||
|
||||
def calculate_shards(num_shards: int, tests: List[str], job_times: Dict[str, Tuple[float, int]]) -> List[Tuple[float, List[str]]]:
|
||||
# if there's 'test_cpp_extensions_aot' entry in job_times, add 'test_cpp_extensions_aot_ninja'
|
||||
# and 'test_cpp_extensions_aot_no_ninja' duplicate entries to ease future computation since
|
||||
# test_cpp_extensions_aot_no_ninja and test_cpp_extensions_aot_ninja are Python test jobs that
|
||||
# both use the test_cpp_extensions_aot.py file.
|
||||
if 'test_cpp_extensions_aot' in job_times:
|
||||
job_times['test_cpp_extensions_aot_ninja'] = job_times['test_cpp_extensions_aot']
|
||||
job_times['test_cpp_extensions_aot_no_ninja'] = job_times['test_cpp_extensions_aot']
|
||||
filtered_job_times: Dict[str, float] = dict()
|
||||
for test in tests:
|
||||
if test in job_times:
|
||||
|
|
@ -424,21 +440,50 @@ def calculate_shards(num_shards: int, tests: List[str], job_times: Dict[str, Tup
|
|||
return sharded_jobs
|
||||
|
||||
|
||||
def get_shard(which_shard: int, num_shards: int, tests: List[str]) -> List[str]:
|
||||
def pull_job_times_from_S3() -> Dict[str, Tuple[float, int]]:
|
||||
if HAVE_BOTO3:
|
||||
s3_reports = get_test_time_reports_from_S3()
|
||||
else:
|
||||
print('Please install boto3 to enable automatic test sharding.')
|
||||
print('Please install boto3 to enable using S3 test times for automatic sharding and test categorization.')
|
||||
s3_reports = []
|
||||
|
||||
if len(s3_reports) == 0:
|
||||
print('Gathered no reports from S3. Proceeding with default sharding plan.')
|
||||
print('Gathered no reports from S3. Please proceed without them.')
|
||||
return dict()
|
||||
|
||||
return calculate_job_times(s3_reports)
|
||||
|
||||
|
||||
def get_shard(which_shard: int, num_shards: int, tests: List[str]) -> List[str]:
|
||||
jobs_to_times = pull_job_times_from_S3()
|
||||
|
||||
# Got no stats from S3, returning early to save runtime
|
||||
if len(jobs_to_times) == 0:
|
||||
print('Gathered no stats from S3. Proceeding with default sharding plan.')
|
||||
return tests[which_shard - 1 :: num_shards]
|
||||
jobs_to_times = calculate_job_times(s3_reports)
|
||||
|
||||
shards = calculate_shards(num_shards, tests, jobs_to_times)
|
||||
_, tests_from_shard = shards[which_shard - 1]
|
||||
return tests_from_shard
|
||||
|
||||
|
||||
def get_slow_tests_based_on_S3() -> List[str]:
|
||||
jobs_to_times = pull_job_times_from_S3()
|
||||
|
||||
# Got no stats from S3, returning early to save runtime
|
||||
if len(jobs_to_times) == 0:
|
||||
print('Gathered no stats from S3. No new slow tests calculated.')
|
||||
return []
|
||||
|
||||
slow_tests: List[str] = []
|
||||
for test in TESTS:
|
||||
if test in jobs_to_times and test not in TARGET_DET_LIST:
|
||||
test_time, _ = jobs_to_times[test]
|
||||
if test_time > SLOW_TEST_THRESHOLD:
|
||||
slow_tests.append(test)
|
||||
return slow_tests
|
||||
|
||||
|
||||
def get_executable_command(options, allow_pytest, disable_coverage=False):
|
||||
if options.coverage and not disable_coverage:
|
||||
executable = ['coverage', 'run', '--parallel-mode', '--source=torch']
|
||||
|
|
@ -892,10 +937,10 @@ def get_dep_modules(test):
|
|||
return dep_modules
|
||||
|
||||
|
||||
def determine_target(test, touched_files, options):
|
||||
def determine_target(target_det_list, test, touched_files, options):
|
||||
test = parse_test_module(test)
|
||||
# Some tests are faster to execute than to determine.
|
||||
if test not in SLOW_TESTS:
|
||||
if test not in target_det_list:
|
||||
if options.verbose:
|
||||
print_to_stderr(f'Running {test} without determination')
|
||||
return True
|
||||
|
|
@ -975,6 +1020,9 @@ def main():
|
|||
selected_tests = filter(lambda test_name: "jit" in test_name, TESTS)
|
||||
|
||||
if options.determine_from is not None and os.path.exists(options.determine_from):
|
||||
slow_tests = get_slow_tests_based_on_S3()
|
||||
print('Added the following tests to target_det tests as calculated based on S3:')
|
||||
print(slow_tests)
|
||||
with open(options.determine_from, 'r') as fh:
|
||||
touched_files = [
|
||||
os.path.normpath(name.strip()) for name in fh.read().split('\n')
|
||||
|
|
@ -984,7 +1032,7 @@ def main():
|
|||
sys.path.append('test')
|
||||
selected_tests = [
|
||||
test for test in selected_tests
|
||||
if determine_target(test, touched_files, options)
|
||||
if determine_target(TARGET_DET_LIST + slow_tests, test, touched_files, options)
|
||||
]
|
||||
sys.path.remove('test')
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class DeterminationTest(unittest.TestCase):
|
|||
return [
|
||||
test
|
||||
for test in cls.TESTS
|
||||
if run_test.determine_target(test, changed_files, DummyOptions())
|
||||
if run_test.determine_target(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions())
|
||||
]
|
||||
|
||||
def test_config_change_only(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue