diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index cc07f35c29c..5596e36ea54 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -74,6 +74,7 @@ def mock_parse_args(revert: bool = False, self.comment_id = 0 self.on_mandatory = False self.on_green = False + self.land_checks = False self.reason = 'this is for testing' return Object() @@ -90,6 +91,7 @@ def mock_merge(pr_num: int, repo: GitRepo, comment_id: Optional[int] = None, mandatory_only: bool = False, on_green: bool = False, + land_checks: bool = False, timeout_minutes: int = 400, stale_pr_days: int = 3) -> None: pass @@ -273,6 +275,7 @@ class TestGitHubPR(TestCase): force=True, comment_id=mock.ANY, on_green=False, + land_checks=False, mandatory_only=False) @mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info()) @@ -286,6 +289,7 @@ class TestGitHubPR(TestCase): force=False, comment_id=mock.ANY, on_green=False, + land_checks=False, mandatory_only=False) if __name__ == "__main__": diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index db7910a54c1..0c406d5e61e 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -10,7 +10,7 @@ from datetime import datetime from dataclasses import dataclass from urllib.request import urlopen, Request from urllib.error import HTTPError -from typing import cast, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Iterable, cast, Any, Callable, Dict, List, Optional, Tuple, Union from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo from functools import lru_cache from warnings import warn @@ -312,7 +312,6 @@ RE_REVERT_CMD = re.compile(r"@pytorch(merge|)bot\s+revert\s+this") RE_REVERT_CMD_CLI = re.compile(r"@pytorch(merge|)bot\s+revert\s+(-m.*-c.*|-c.*-m.*)") RE_DIFF_REV = re.compile(r'^Differential Revision:.+?(D[0-9]+)', re.MULTILINE) - def _fetch_url(url: str, *, headers: Optional[Dict[str, str]] = None, data: Optional[Dict[str, Any]] = None, @@ -332,7 +331,6 @@ def _fetch_url(url: str, *, print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}") raise - def fetch_json(url: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: @@ -341,6 +339,13 @@ def fetch_json(url: str, url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()) return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load)) +def fetch_json_dict(url: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None) -> Dict[str, Any] : + headers = {'Accept': 'application/vnd.github.v3+json'} + if params is not None and len(params) > 0: + url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()) + return cast(Dict[str, Any], _fetch_url(url, headers=headers, data=data, reader=json.load)) def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]: if dry_run: @@ -395,6 +400,7 @@ def parse_args() -> Any: parser.add_argument("--dry-run", action="store_true") parser.add_argument("--on-green", action="store_true") parser.add_argument("--on-mandatory", action="store_true") + parser.add_argument("--land-checks", action="store_true") parser.add_argument("--revert", action="store_true") parser.add_argument("--force", action="store_true") parser.add_argument("--comment-id", type=int) @@ -733,11 +739,28 @@ class GitHubPR: msg += f"Approved by: {approved_by_urls}\n" return msg - def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False, comment_id: Optional[int] = None) -> None: + def merge_into(self, repo: GitRepo, *, + force: bool = False, + dry_run: bool = False, + comment_id: Optional[int] = None) -> None: # Raises exception if matching rule is not found find_matching_merge_rule(self, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id)) - if repo.current_branch() != self.default_branch(): - repo.checkout(self.default_branch()) + self.merge_changes(repo, force, comment_id) + + repo.push(self.default_branch(), dry_run) + gh_post_pr_comment(self.org, self.project, self.pr_num, + f"@{self.get_pr_creator_login()} your PR has been successfully merged.", dry_run) + if not dry_run: + gh_add_labels(self.org, self.project, self.pr_num, ["merged"]) + + def merge_changes(self, + repo: GitRepo, + force: bool = False, + comment_id: Optional[int] = None, + branch: Optional[str] = None) -> None: + branch_to_merge_into = self.default_branch() if branch is None else branch + if repo.current_branch() != branch_to_merge_into: + repo.checkout(branch_to_merge_into) if not self.is_ghstack_pr(): msg = self.gen_commit_message() pr_branch_name = f"__pull-request-{self.pr_num}__init__" @@ -747,9 +770,24 @@ class GitHubPR: else: self.merge_ghstack_into(repo, force, comment_id=comment_id) - repo.push(self.default_branch(), dry_run) - if not dry_run: - gh_add_labels(self.org, self.project, self.pr_num, ["merged"]) + def create_land_time_check_branch(self, + repo: GitRepo, + branch: str, + force: bool = False, + comment_id: Optional[int] = None,) -> str: + self.merge_changes(repo, branch=branch, force=force, comment_id=comment_id) + land_check_branch = f'landchecks/{self.pr_num}' + try: + repo._run_git('branch', "-D", land_check_branch) + except Exception: + pass + repo._run_git('checkout', "-b", land_check_branch) + repo._run_git('push', '-u', 'origin', land_check_branch, '--force') + commit = repo.get_commit('HEAD').commit_hash + gh_post_pr_comment(self.org, self.project, self.pr_num, + 'Successfully started land time checks.' + + f' See progress here: https://hud.pytorch.org/{self.org}/{self.project}/commit/{commit}') + return commit class MandatoryChecksMissingError(Exception): @@ -838,21 +876,10 @@ def find_matching_merge_rule(pr: GitHubPR, reject_reason = (f"Matched rule {rule_name}, but PR #{pr.pr_num} was not reviewed yet by any of: " + f"{', '.join(list(rule_approvers_set)[:5])}{', ...' if len(rule_approvers_set) > 5 else ''}") continue - if rule.mandatory_checks_name is not None: - pending_checks: List[Tuple[str, Optional[str]]] = [] - failed_checks: List[Tuple[str, Optional[str]]] = [] - checks = pr.get_checkrun_conclusions() - # HACK: We don't want to skip CLA check, even when forced - for checkname in filter(lambda x: force is False or "CLA Check" in x, rule.mandatory_checks_name): - if checkname not in checks: - pending_checks.append((checkname, None)) - elif checks[checkname][0] is None: - pending_checks.append((checkname, checks[checkname][1])) - elif checks[checkname][0] != 'SUCCESS': - failed_checks.append((checkname, checks[checkname][1])) - - def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: - return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) + mandatory_checks = rule.mandatory_checks_name if rule.mandatory_checks_name is not None else [] + checks = pr.get_checkrun_conclusions() + required_checks = filter(lambda x: force is False or "CLA Check" in x, mandatory_checks) + [pending_checks, failed_checks] = categorize_checks(checks, required_checks) if len(failed_checks) > 0: if reject_reason_score < 30000: @@ -874,6 +901,9 @@ def find_matching_merge_rule(pr: GitHubPR, raise RuntimeError(reject_reason) +def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: + return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) + def pr_get_checks_with_lambda(pr: GitHubPR, status_check: Callable[[Optional[str]], bool]) -> List[Tuple[str, str]]: checks = pr.get_checkrun_conclusions() return [(name, status[1]) for name, status in checks.items() if status_check(status[0])] @@ -938,7 +968,6 @@ def try_revert(repo: GitRepo, pr: GitHubPR, *, def prefix_with_github_url(suffix_str: str) -> str: return f"https://github.com/{suffix_str}" - def check_for_sev(org: str, project: str, force: bool) -> None: if force: return @@ -959,6 +988,38 @@ def check_for_sev(org: str, project: str, force: bool) -> None: ) return +def fetch_check_run_conclusions(repo: GitRepo, commit: str) -> Dict[str, Tuple[str, str]]: + [owner, name] = repo.gh_owner_and_name() + checks = fetch_json_dict(f'https://api.github.com/repos/{owner}/{name}/commits/{commit}/check-runs') + check_run_conclusions = {} + if len(checks) == 0: + raise MandatoryChecksMissingError("Refusing to merge as land check(s) are not yet run") + for check_run in checks['check_runs']: + check_run_conclusions[check_run['name']] = (check_run['conclusion'], + check_run['html_url']) + return check_run_conclusions + +def validate_land_time_checks(repo: GitRepo, commit: str) -> None: + checks = fetch_check_run_conclusions(repo, commit) + [pending_checks, failed_checks] = categorize_checks(checks, checks) + + if len(failed_checks) > 0: + raise RuntimeError(f"Failed to merge; some land checks failed: {checks_to_str(failed_checks)}") + if len(pending_checks) > 0: + raise MandatoryChecksMissingError(f"Refusing to merge as land check(s) {checks_to_str(pending_checks)} are not yet run") + +def categorize_checks(check_runs: Dict[str, Tuple[str, str]], + required_checks: Iterable[str]) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]: + pending_checks: List[Tuple[str, Optional[str]]] = [] + failed_checks: List[Tuple[str, Optional[str]]] = [] + for checkname in required_checks: + if checkname not in check_runs: + pending_checks.append((checkname, None)) + elif check_runs[checkname][0] is None: + pending_checks.append((checkname, check_runs[checkname][1])) + elif check_runs[checkname][0].upper() != 'SUCCESS' and check_runs[checkname][0].upper() != 'SKIPPED': + failed_checks.append((checkname, check_runs[checkname][1])) + return (pending_checks, failed_checks) def merge(pr_num: int, repo: GitRepo, dry_run: bool = False, @@ -966,6 +1027,7 @@ def merge(pr_num: int, repo: GitRepo, comment_id: Optional[int] = None, mandatory_only: bool = False, on_green: bool = False, + land_checks: bool = False, timeout_minutes: int = 400, stale_pr_days: int = 3) -> None: repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) @@ -979,6 +1041,9 @@ def merge(pr_num: int, repo: GitRepo, if (datetime.utcnow() - pr.last_pushed_at()).days > stale_pr_days: raise RuntimeError("This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again.") + if land_checks: + commit = pr.create_land_time_check_branch(repo, 'viable/strict', force=force, comment_id=comment_id) + start_time = time.time() last_exception = '' elapsed_time = 0.0 @@ -994,12 +1059,16 @@ def merge(pr_num: int, repo: GitRepo, find_matching_merge_rule(pr, repo) pending = pr_get_pending_checks(pr) failing = pr_get_failed_checks(pr) + if (not mandatory_only and on_green) and len(failing) > 0: raise RuntimeError(f"{len(failing)} additional jobs have failed, first few of them are: " + ' ,'.join(f"[{x[0]}]({x[1]})" for x in failing[:5])) if (not mandatory_only and on_green) and len(pending) > 0: raise MandatoryChecksMissingError(f"Still waiting for {len(pending)} additional jobs to finish, " + f"first few of them are: {' ,'.join(x[0] for x in pending[:5])}") + if land_checks: + validate_land_time_checks(repo, commit) + return pr.merge_into(repo, dry_run=dry_run, force=force, comment_id=comment_id) except MandatoryChecksMissingError as ex: last_exception = str(ex) @@ -1052,11 +1121,11 @@ def main() -> None: force=args.force, comment_id=args.comment_id, on_green=args.on_green, - mandatory_only=args.on_mandatory) + mandatory_only=args.on_mandatory, + land_checks=args.land_checks) except Exception as e: handle_exception(e) - if __name__ == "__main__": main() diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6417eb26a86..84dcad0f44a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,7 @@ on: - master - main - release/* + - landchecks/* workflow_dispatch: jobs: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2cb956c107f..298a48fde2d 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -7,6 +7,7 @@ on: - master - main - release/* + - landchecks/* workflow_dispatch: concurrency: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 1b0f9fec9e2..8adfba9985d 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -6,6 +6,7 @@ on: - master - main - release/* + - landchecks/* tags: - ciflow/trunk/* workflow_dispatch: diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 99cea22e3f2..8db7b0c97c5 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -31,6 +31,7 @@ jobs: GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} FORCE: ${{ github.event.client_payload.force}} ON_GREEN: ${{ github.event.client_payload.on_green}} + LAND_CHECKS: ${{ github.event.client_payload.land_checks }} COMMENT_ID: ${{ github.event.client_payload.comment_id }} run: | set -ex @@ -42,6 +43,8 @@ jobs: fi elif [ -n "${ON_GREEN}" ]; then python3 .github/scripts/trymerge.py --on-green "${PR_NUM}" + elif [ -n "${LAND_CHECKS}" ]; then + python3 .github/scripts/trymerge.py --land-checks "${PR_NUM}" elif [ -n "${COMMENT_ID}" ]; then python3 .github/scripts/trymerge.py --comment-id "${COMMENT_ID}" "${PR_NUM}" else