Use default build env and test config for test times (#107325)

Redo of #107312

Pairs with https://github.com/pytorch/test-infra/pull/4476

If build env and test config combo cannot be found in the test times, use default.  Then we don't have to go manually change the test-times.json a new job is added or we update the jobs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107325
Approved by: https://github.com/huydhn
This commit is contained in:
Catherine Lee 2023-08-21 18:39:55 +00:00 committed by PyTorch MergeBot
parent ad07a4bc56
commit 3b2c5d47c0
2 changed files with 23 additions and 28 deletions

View file

@ -1418,21 +1418,30 @@ def get_selected_tests(options) -> List[ShardedTest]:
def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
with open(path) as f:
test_file_times = cast(Dict[str, Any], json.load(f))
else:
test_file_times = {}
test_config = os.environ.get("TEST_CONFIG")
if test_config not in test_file_times:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
path = os.path.join(str(REPO_ROOT), file)
if not os.path.exists(path):
print("::warning:: Failed to find test times file. Using round robin sharding.")
return {}
with open(path) as f:
test_times_file = cast(Dict[str, Any], json.load(f))
build_environment = os.environ.get("BUILD_ENVIRONMENT")
test_config = os.environ.get("TEST_CONFIG")
if test_config in test_times_file.get(build_environment, {}):
print("Found test times from artifacts")
return test_times_file[build_environment][test_config]
elif test_config in test_times_file["default"]:
print(
f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
f" and {test_config} test config. Using default build env and {test_config} test config instead."
)
return test_times_file["default"][test_config]
else:
print("Found test time stats from artifacts")
return test_file_times[test_config]
print(
f"::warning:: Gathered no stats from artifacts for build env {build_environment} build env"
f" and {test_config} test config. Using default build env and default test config instead."
)
return test_times_file["default"]["default"]
def do_sharding(

View file

@ -76,22 +76,8 @@ def get_slow_tests(
def get_test_times(dirpath: str, filename: str) -> Dict[str, Dict[str, float]]:
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json"
build_environment = os.environ.get("BUILD_ENVIRONMENT")
if build_environment is None:
test_times = fetch_and_cache(dirpath, filename, url, lambda x: x)
raise RuntimeError(
f"BUILD_ENVIRONMENT is not defined, available keys are {test_times.keys()}"
)
def process_response(the_response: Dict[str, Any]) -> Any:
if build_environment not in the_response:
raise RuntimeError(
f"{build_environment} not found, available envs are: {the_response.keys()}"
)
return the_response[build_environment]
try:
return fetch_and_cache(dirpath, filename, url, process_response)
return fetch_and_cache(dirpath, filename, url, lambda x: x)
except Exception:
print("Couldn't download test times...")
return {}