pytorch/.github/scripts/pytest_cache.py
Catherine Lee dab272eed8 [td] Consistent pytest cache (#113804)
Move the pytest cache downloading into the build step and store it in additional ci files so that it stays consistent during sharding.

Only build env is taken into account now instead of also test config since we might not have the test config during build time, making it less specific, but I also think this might be better since tests are likely to fail across the same test config (I also think it might be worth not even looking at build env, but that's a different topic)

Each cache upload should only include information from the current run.  Do not merge current cache with downloaded cache during upload (shouldn't matter anyways since the downloaded cache won't exist at the time)

From what I can tell of the S3 retention policy, pytest cache files will be deleted after 30 days (cc @ZainRizvi to confirm), so we never have to worry about space or pulling old versions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113804
Approved by: https://github.com/ZainRizvi
2023-11-17 23:45:47 +00:00

114 lines
3.2 KiB
Python

import argparse
import sys
from pathlib import Path
from pytest_caching_utils import (
download_pytest_cache,
GithubRepo,
PRIdentifier,
upload_pytest_cache,
)
TEMP_DIR = "./tmp" # a backup location in case one isn't provided
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the pytest-cache S3 sync script.

    Exactly one of --upload / --download must be given. The --sha,
    --test_config, and --shard flags are only required in upload mode,
    since they describe the run whose cache is being uploaded.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        # fixed typo: was "Upload this job's the pytest cache to S3"
        description="Upload this job's pytest cache to S3"
    )

    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--upload", action="store_true", help="Upload the pytest cache to S3"
    )
    mode.add_argument(
        "--download",
        action="store_true",
        help="Download the pytest cache from S3, merging it with any local cache",
    )

    parser.add_argument(
        "--cache_dir",
        required=True,
        help="Path to the folder pytest uses for its cache",
    )
    parser.add_argument("--pr_identifier", required=True, help="A unique PR identifier")
    parser.add_argument(
        "--job_identifier",
        required=True,
        help="A unique job identifier that should be the same for all runs of job",
    )

    # These three are only meaningful (and therefore only required) when
    # uploading; scan argv once instead of three times.
    is_upload = "--upload" in sys.argv
    parser.add_argument(
        "--sha", required=is_upload, help="SHA of the commit"
    )  # Only required for upload
    parser.add_argument(
        "--test_config", required=is_upload, help="The test config"
    )  # Only required for upload
    parser.add_argument(
        "--shard", required=is_upload, help="The shard id"
    )  # Only required for upload

    parser.add_argument(
        "--repo",
        required=False,
        help="The github repository we're running in, in the format 'owner/repo-name'",
    )
    parser.add_argument(
        "--temp_dir", required=False, help="Directory to store temp files"
    )
    parser.add_argument(
        "--bucket", required=False, help="The S3 bucket to upload the cache to"
    )

    args = parser.parse_args()
    return args
def main() -> None:
    """Entry point: upload or download this job's pytest cache from/to S3.

    Exactly one of the upload/download modes is active (enforced by the
    argument parser). Upload is skipped with a message when the local
    cache directory does not exist.
    """
    args = parse_args()

    pr_identifier = PRIdentifier(args.pr_identifier)
    print(f"PR identifier for `{args.pr_identifier}` is `{pr_identifier}`")

    # NOTE(review): --repo is optional, so this may receive None —
    # presumably GithubRepo.from_string falls back to a default; verify.
    repo = GithubRepo.from_string(args.repo)
    cache_dir = Path(args.cache_dir)
    # Fall back to the local ./tmp backup location when no temp dir given.
    temp_dir = Path(args.temp_dir) if args.temp_dir else Path(TEMP_DIR)

    if args.upload:
        print(f"Uploading cache with args {args}")

        # Nothing to upload if pytest never created a cache.
        if not cache_dir.exists():
            print(f"The pytest cache dir `{cache_dir}` does not exist. Skipping upload")
            return

        upload_pytest_cache(
            pr_identifier=pr_identifier,
            repo=repo,
            job_identifier=args.job_identifier,
            sha=args.sha,
            test_config=args.test_config,
            shard=args.shard,
            cache_dir=cache_dir,
            bucket=args.bucket,
            temp_dir=temp_dir,
        )

    if args.download:
        print(f"Downloading cache with args {args}")
        download_pytest_cache(
            pr_identifier=pr_identifier,
            repo=repo,
            job_identifier=args.job_identifier,
            dest_cache_dir=cache_dir,
            bucket=args.bucket,
            temp_dir=temp_dir,
        )


if __name__ == "__main__":
    main()