From 3b2c5d47c0c5a81737a6200484f5c11f8ecadb03 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 21 Aug 2023 18:39:55 +0000 Subject: [PATCH] Use default build env and test config for test times (#107325) Redo of #107312 Pairs with https://github.com/pytorch/test-infra/pull/4476 If build env and test config combo cannot be found in the test times, use default. Then we don't have to go manually change the test-times.json when a new job is added or we update the jobs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107325 Approved by: https://github.com/huydhn --- test/run_test.py | 35 ++++++++++++++++++++------------ tools/stats/import_test_stats.py | 16 +-------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 1931367710e..49dac330c3e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1418,21 +1418,30 @@ def get_selected_tests(options) -> List[ShardedTest]: def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]: # Download previous test times to make sharding decisions - path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE) - if os.path.exists(path): - with open(path) as f: - test_file_times = cast(Dict[str, Any], json.load(f)) - else: - test_file_times = {} - test_config = os.environ.get("TEST_CONFIG") - if test_config not in test_file_times: - print( - "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan." - ) + path = os.path.join(str(REPO_ROOT), file) + if not os.path.exists(path): + print("::warning:: Failed to find test times file. 
Using round robin sharding.") return {} + + with open(path) as f: + test_times_file = cast(Dict[str, Any], json.load(f)) + build_environment = os.environ.get("BUILD_ENVIRONMENT") + test_config = os.environ.get("TEST_CONFIG") + if test_config in test_times_file.get(build_environment, {}): + print("Found test times from artifacts") + return test_times_file[build_environment][test_config] + elif test_config in test_times_file["default"]: + print( + f"::warning:: Gathered no stats from artifacts for {build_environment} build env" + f" and {test_config} test config. Using default build env and {test_config} test config instead." + ) + return test_times_file["default"][test_config] else: - print("Found test time stats from artifacts") - return test_file_times[test_config] + print( + f"::warning:: Gathered no stats from artifacts for build env {build_environment} build env" + f" and {test_config} test config. Using default build env and default test config instead." + ) + return test_times_file["default"]["default"] def do_sharding( diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index a0c01905807..b7617eb765f 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -76,22 +76,8 @@ def get_slow_tests( def get_test_times(dirpath: str, filename: str) -> Dict[str, Dict[str, float]]: url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json" - build_environment = os.environ.get("BUILD_ENVIRONMENT") - if build_environment is None: - test_times = fetch_and_cache(dirpath, filename, url, lambda x: x) - raise RuntimeError( - f"BUILD_ENVIRONMENT is not defined, available keys are {test_times.keys()}" - ) - - def process_response(the_response: Dict[str, Any]) -> Any: - if build_environment not in the_response: - raise RuntimeError( - f"{build_environment} not found, available envs are: {the_response.keys()}" - ) - return the_response[build_environment] - try: - return 
fetch_and_cache(dirpath, filename, url, process_response) + return fetch_and_cache(dirpath, filename, url, lambda x: x) except Exception: print("Couldn't download test times...") return {}