mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fixes #80635.
Tested on pytorch-canary: https://github.com/pytorch/pytorch-canary/pull/115
Steal the identity of the committer of the commit on the orig branch instead of overwriting it with mergebot's identity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80747
Approved by: https://github.com/malfet, https://github.com/seemethere
137 lines
6 KiB
Python
Executable file
137 lines
6 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import re
|
|
from typing import Any
|
|
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
|
from trymerge import gh_post_pr_comment as gh_post_comment, GitHubPR
|
|
|
|
|
|
def parse_args() -> Any:
    """Parse the rebase bot's command line (taken from ``sys.argv``).

    Returns:
        An argparse namespace with:
          * ``dry_run`` -- ``True`` when ``--dry-run`` was given, else ``False``;
          * ``branch``  -- value of ``--branch`` or ``None`` when omitted;
          * ``pr_num``  -- the positional PR number, converted to ``int``.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser("Rebase PR into branch")
    # (positional names/flags, keyword options) for each argument, in the
    # order they are registered.
    argument_specs = (
        (("--dry-run",), {"action": "store_true"}),
        (("--branch",), {"type": str}),
        (("pr_num",), {"type": int}),
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
|
|
|
|
|
|
def rebase_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
    """Rebase a regular (non-ghstack) PR onto a branch and force-push it back.

    The PR head ``pull/<num>/head`` is fetched, rebased onto
    ``refs/remotes/origin/<onto_branch>``, and force-pushed to the PR's head
    branch in the PR's head repository. A comment is then posted on the PR
    saying either that nothing changed ("already up to date") or that the
    rebase succeeded.

    Args:
        pr: the pull request to rebase.
        repo: local checkout used to run git commands.
        onto_branch: short name of the branch to rebase onto (e.g. ``main``).
        dry_run: pass ``--dry-run`` to ``git push`` and forward the flag to
            ``gh_post_comment``. The local fetch and rebase still happen.
    """
    pr_ref = f"pull/{pr.pr_num}/head"
    target_ref = f"refs/remotes/origin/{onto_branch}"
    head_repo = pr.info['headRepository']['nameWithOwner']
    push_url = f"https://github.com/{head_repo}.git"
    push_refspec = f"{pr_ref}:{pr.head_ref()}"

    repo.fetch(pr_ref, pr_ref)
    repo._run_git("rebase", target_ref, pr_ref)

    # Resulting command: git push [--dry-run] -f <url> <refspec>
    dry_flag = ["--dry-run"] if dry_run else []
    push_output = repo._run_git("push", *dry_flag, "-f", push_url, push_refspec)

    # git prints this marker when the remote already has the rebased commits.
    if "Everything up-to-date" in push_output:
        note = f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date"
    else:
        note = (
            f"Successfully rebased `{pr.head_ref()}` onto `{target_ref}`, please pull locally "
            f"before adding more changes (for example, via `git checkout {pr.head_ref()} && "
            "git pull --rebase`)"
        )
    gh_post_comment(pr.org, pr.project, pr.pr_num, note, dry_run=dry_run)
|
|
|
|
|
|
def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
    """Rebase a ghstack PR (and the stack it belongs to) onto a branch.

    The PR's ``.../orig`` ref, derived from its ``.../head`` ref, is fetched,
    rebased onto ``refs/remotes/origin/<onto_branch>``, and re-submitted by
    running ``ghstack``. Every PR that ghstack reports as "Updated" gets a
    comment. If the target PR is reported as "Skipped", a comment says it was
    already up to date.

    Args:
        pr: the ghstack pull request to rebase.
        repo: local checkout used to run git commands.
        onto_branch: short name of the branch to rebase onto (e.g. ``main``).
        dry_run: ghstack cannot be dry-run, so when set nothing is submitted
            and no comments are posted. The local fetch/rebase, the global git
            identity change and the ``.ghstackrc`` write still happen.

    Raises:
        KeyError: if ``GITHUB_TOKEN`` is not set in the environment.
        Exception: if ``ghstack`` exits with a non-zero status. The message
            contains ghstack's stdout followed by its stderr.
    """
    # Make sure this interpreter can run ghstack; install it if it cannot.
    if subprocess.run([sys.executable, "-m", "ghstack", "--help"], capture_output=True).returncode != 0:
        subprocess.run([sys.executable, "-m", "pip", "install", "ghstack"])

    orig_ref = f"{re.sub(r'/head$', '/orig', pr.head_ref())}"
    onto_branch = f"refs/remotes/origin/{onto_branch}"

    repo.fetch(orig_ref, orig_ref)
    repo._run_git("rebase", onto_branch, orig_ref)

    # Steal the identity of the *author* (%an / %ae; rebase preserves authors)
    # of the tip commit on the orig branch, instead of keeping mergebot's
    # identity for whatever ghstack commits next. NOTE: this writes the
    # *global* git config and is done even in dry-run mode.
    email = repo._run_git("log", orig_ref, "--pretty=format:%ae", "-1")
    name = repo._run_git("log", orig_ref, "--pretty=format:%an", "-1")
    repo._run_git("config", "--global", "user.name", name)
    repo._run_git("config", "--global", "user.email", email)

    # Hand the GitHub token to ghstack as OAUTH_TOKEN and write its config to
    # ./.ghstackrc (presumably the config file ghstack picks up -- confirm).
    os.environ["OAUTH_TOKEN"] = os.environ["GITHUB_TOKEN"]
    with open('.ghstackrc', 'w+') as f:
        f.write('[ghstack]\n' +
                "github_url=github.com\n" +
                "github_username=pytorchmergebot\n" +
                "remote_name=origin")

    if dry_run:
        print("Don't know how to dry-run ghstack")
        return

    ghstack_result = subprocess.run(["ghstack"], capture_output=True)
    push_result = ghstack_result.stdout.decode("utf-8")
    print(push_result)
    if ghstack_result.returncode != 0:
        # stderr was captured too (capture_output=True). An uncaught Python
        # exception in ghstack prints its traceback there, so surface it
        # instead of reporting only (often empty) stdout.
        error_output = ghstack_result.stderr.decode("utf-8", errors="replace")
        print(error_output, file=sys.stderr)
        details = push_result
        if error_output.strip():
            details += "\n" + error_output
        raise Exception(f"\n```{details}```")

    # The contents of a successful push result should look like:
    # Summary of changes (ghstack 0.6.0)

    #  - Updated https://github.com/clee2000/random-testing/pull/2
    #  - Updated https://github.com/clee2000/random-testing/pull/1

    # Facebook employees can import your changes by running
    # (on a Facebook machine):

    #     ghimport -s https://github.com/clee2000/random-testing/pull/2

    # If you want to work on this diff stack on another machine:

    #     ghstack checkout https://github.com/clee2000/random-testing/pull/2
    org, project = repo.gh_owner_and_name()
    for line in push_result.splitlines():
        if "Updated" in line:
            # The PR number is the last path component of the PR URL.
            pr_num = int(line.split("/")[-1])
            if pr_num != pr.pr_num:
                gh_post_comment(pr.org, pr.project, pr_num,
                                f"Rebased `{orig_ref}` onto `{onto_branch}` because #{pr.pr_num} was rebased, "
                                "please pull locally before adding more changes (for example, via `ghstack " +
                                f"checkout https://github.com/{org}/{project}/pull/{pr_num}`)", dry_run=dry_run)
            else:
                gh_post_comment(pr.org, pr.project, pr_num,
                                f"Successfully rebased `{orig_ref}` onto `{onto_branch}`, please pull locally " +
                                "before adding more changes (for example, via `ghstack " +
                                f"checkout https://github.com/{org}/{project}/pull/{pr.pr_num}`)", dry_run=dry_run)

    if f"Skipped https://github.com/{org}/{project}/pull/{pr.pr_num}" in push_result:
        gh_post_comment(pr.org, pr.project, pr.pr_num,
                        f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date", dry_run=dry_run)
|
|
|
|
|
|
def main() -> None:
    """Rebase the PR given on the command line and report progress as PR comments.

    Steps:
      1. Post a "started" comment, linking to the CI run when ``GH_RUN_URL``
         is set.
      2. If the PR is closed, post a comment and stop.
      3. Otherwise dispatch to the ghstack or the regular rebase path.

    An exception raised while rebasing is turned into a "Rebase failed" PR
    comment and is not re-raised, so the process still exits normally.
    """
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name(), debug=True)
    org, project = repo.gh_owner_and_name()

    pr = GitHubPR(org, project, args.pr_num)
    onto_branch = args.branch or pr.default_branch()

    # GH_RUN_URL is optional (e.g. when run locally). Only link to it when it
    # is set; previously the start message would embed "[here](None)".
    run_url = os.getenv("GH_RUN_URL")

    msg = "@pytorchbot successfully started a rebase job."
    if run_url:
        msg += f" Check the current status [here]({run_url})"
    gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)

    if pr.is_closed():
        gh_post_comment(org, project, args.pr_num, f"PR #{args.pr_num} is closed, won't rebase", dry_run=args.dry_run)
        return

    try:
        if pr.is_ghstack_pr():
            rebase_ghstack_onto(pr, repo, onto_branch, dry_run=args.dry_run)
        else:
            rebase_onto(pr, repo, onto_branch, dry_run=args.dry_run)
    except Exception as e:
        # Report any failure on the PR instead of crashing the job.
        msg = f"Rebase failed due to {e}"
        if run_url:
            msg += f"\nRaised by {run_url}"
        gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
|
|
|
|
|
|
# Script entry point: run the bot only when executed directly, never on import.
if __name__ == "__main__":
    main()
|