pytorch/.github/scripts/pytest_cache.py
Zain Rizvi 96f46316c9 Preserve PyTest Cache across job runs (#100522)
Preserves the PyTest cache from one job run to the next.  In a later PR, this will be used to change the order in which we actually run those tests

The process is:
1. Before running tests, check S3 to see if there is an uploaded cache from any shard of the current job
2. If there are, download them all and merge their contents. Put the merged cache in the default .pytest_cache folder
3. After running the tests, merge the now-current .pytest_cache folder with the cache previously downloaded for the current shard. This will make the merged cache contain all tests that have ever failed for the given PR in the current shard
4. Upload the resulting cache file back to S3

The S3 folder has a retention policy of 30 days, after which the uploaded cache files will get auto-deleted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100522
Approved by: https://github.com/huydhn
2023-05-10 18:37:28 +00:00

106 lines
2.9 KiB
Python

import argparse
import sys
from pathlib import Path
from pytest_caching_utils import (
download_pytest_cache,
GithubRepo,
PRIdentifier,
upload_pytest_cache,
)
TEMP_DIR = "./tmp" # a backup location in case one isn't provided
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Upload this job's the pytest cache to S3"
)
mode = parser.add_mutually_exclusive_group(required=True)
mode.add_argument(
"--upload", action="store_true", help="Upload the pytest cache to S3"
)
mode.add_argument(
"--download",
action="store_true",
help="Download the pytest cache from S3, merging it with any local cache",
)
parser.add_argument(
"--cache_dir",
required=True,
help="Path to the folder pytest uses for its cache",
)
parser.add_argument("--pr_identifier", required=True, help="A unique PR identifier")
parser.add_argument(
"--job_identifier",
required=True,
help="A unique job identifier that should be the same for all runs of job",
)
parser.add_argument(
"--shard", required="--upload" in sys.argv, help="The shard id"
) # Only required for upload
parser.add_argument(
"--repo",
required=False,
help="The github repository we're running in, in the format 'owner/repo-name'",
)
parser.add_argument(
"--temp_dir", required=False, help="Directory to store temp files"
)
parser.add_argument(
"--bucket", required=False, help="The S3 bucket to upload the cache to"
)
args = parser.parse_args()
return args
def main() -> None:
args = parse_args()
pr_identifier = PRIdentifier(args.pr_identifier)
print(f"PR identifier for `{args.pr_identifier}` is `{pr_identifier}`")
repo = GithubRepo.from_string(args.repo)
cache_dir = Path(args.cache_dir)
if args.temp_dir:
temp_dir = Path(args.temp_dir)
else:
temp_dir = Path(TEMP_DIR)
if args.upload:
print(f"Uploading cache with args {args}")
# verify the cache dir exists
if not cache_dir.exists():
print(f"The pytest cache dir `{cache_dir}` does not exist. Skipping upload")
return
upload_pytest_cache(
pr_identifier=pr_identifier,
repo=repo,
job_identifier=args.job_identifier,
shard=args.shard,
cache_dir=cache_dir,
bucket=args.bucket,
temp_dir=temp_dir,
)
if args.download:
print(f"Downloading cache with args {args}")
download_pytest_cache(
pr_identifier=pr_identifier,
repo=repo,
job_identifier=args.job_identifier,
dest_cache_dir=cache_dir,
bucket=args.bucket,
temp_dir=temp_dir,
)
if __name__ == "__main__":
main()