Skip to content

Commit

Permalink
get_pr_number_from_workflow_run > get_branch_from_workflow_run + find…
Browse files Browse the repository at this point in the history
…_pr_for_branch
  • Loading branch information
ewjoachim committed Sep 3, 2023
1 parent 550444d commit 1680c53
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 65 deletions.
65 changes: 35 additions & 30 deletions coverage_comment/github.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import functools
import io
import json
import pathlib
Expand Down Expand Up @@ -73,47 +72,53 @@ def download_artifact(
raise NoArtifact(f"File named {filename} not found in artifact {artifact_name}")


def get_pr_number_from_workflow_run(
def get_branch_from_workflow_run(
github: github_client.GitHub, repository: str, run_id: int
) -> int:
# It's quite horrendous to access the PR number from a workflow run,
# especially when it's not the "pull_request" workflow run itself but a
# "workflow_run" workflow run that runs after the "pull_request" workflow
# run.
#
# 1. We need the user to give us access to the "pull_request" workflow run
# id. That's why we request to be sent the following as input:
# GITHUB_PR_RUN_ID: ${{ github.event.workflow_run.id }}
# 2. From that run, we get the corresponding branch, and the owner of the branch
# 3. We list open PRs that have that branch as head branch. There should be only
# one.
# 4. If there's no open PRs, we look at all PRs. We take the most recently
# updated one

) -> tuple[str, str]:
repo_path = github.repos(repository)
run = repo_path.actions.runs(run_id).get()
branch = run.head_branch
repo_name = run.head_repository.full_name
full_branch = f"{repo_name}:{branch}"
get_prs = functools.partial(
repo_path.pulls.get,
head=full_branch,
sort="updated",
direction="desc",
)
owner = run.head_repository.owner.login
return owner, branch


def find_pr_for_branch(
github: github_client.GitHub, repository: str, owner: str, branch: str
) -> int:
# The full branch is in the form of "owner:branch" as specified in
# https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#list-pull-requests
# but it seems to also work with "owner/repo:branch"

full_branch = f"{owner}:{branch}"

common_kwargs = {"head": full_branch, "sort": "updated", "direction": "desc"}
try:
return next(iter(pr.number for pr in get_prs(state="open")))
return next(
iter(
pr.number
for pr in github.repos(repository).pulls.get(
state="open", **common_kwargs
)
)
)
except StopIteration:
pass
log.info(f"No open PR found for branch {full_branch}, defaulting to all PRs")
log.info(f"No open PR found for branch {branch}, defaulting to all PRs")

try:
return next(iter(pr.number for pr in get_prs(state="all")))
return next(
iter(
pr.number
for pr in github.repos(repository).pulls.get(
state="all", **common_kwargs
)
)
)
except StopIteration:
raise CannotDeterminePR(f"No open PR found for branch {full_branch}")
raise CannotDeterminePR(f"No open PR found for branch {branch}")


def get_my_login(github: github_client.GitHub):
def get_my_login(github: github_client.GitHub) -> str:
try:
response = github.user.get()
except github_client.Forbidden:
Expand Down
78 changes: 43 additions & 35 deletions tests/integration/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,24 @@ def test_download_artifact__no_file(gh, session, zip_bytes):
)


def test_get_pr_number_from_workflow_run(gh, session):
def test_get_branch_from_workflow_run(gh, session):
json = {
"head_branch": "other",
"head_repository": {"full_name": "someone/repo-name"},
"head_repository": {"owner": {"login": "someone"}},
}
session.register("GET", "/repos/foo/bar/actions/runs/123")(json=json)

owner, branch = github.get_branch_from_workflow_run(
github=gh, repository="foo/bar", run_id=123
)

assert owner == "someone"
assert branch == "other"


def test_find_pr_for_branch(gh, session):
params = {
"head": "someone/repo-name:other",
"head": "someone:other",
"sort": "updated",
"direction": "desc",
"state": "open",
Expand All @@ -123,58 +133,56 @@ def test_get_pr_number_from_workflow_run(gh, session):
json=[{"number": 456}]
)

result = github.get_pr_number_from_workflow_run(
github=gh, repository="foo/bar", run_id=123
result = github.find_pr_for_branch(
github=gh, repository="foo/bar", owner="someone", branch="other"
)

assert result == 456


def test_get_pr_number_from_workflow_run__no_open_pr(gh, session):
json = {
"head_branch": "other",
"head_repository": {"full_name": "someone/repo-name"},
}
session.register("GET", "/repos/foo/bar/actions/runs/123")(json=json)
def test_find_pr_for_branch__no_open_pr(gh, session):
params = {
"head": "someone/repo-name:other",
"head": "someone:other",
"sort": "updated",
"direction": "desc",
}
session.register("GET", "/repos/foo/bar/pulls", params=params | {"state": "open"})(
json=[]
)
session.register("GET", "/repos/foo/bar/pulls", params=params | {"state": "all"})(
json=[{"number": 456}]
)
session.register(
"GET",
"/repos/foo/bar/pulls",
params=params | {"state": "open"},
)(json=[])
session.register(
"GET",
"/repos/foo/bar/pulls",
params=params | {"state": "all"},
)(json=[{"number": 456}])

result = github.get_pr_number_from_workflow_run(
github=gh, repository="foo/bar", run_id=123
result = github.find_pr_for_branch(
github=gh, repository="foo/bar", owner="someone", branch="other"
)

assert result == 456


def test_get_pr_number_from_workflow_run__no_pr(gh, session):
json = {
"head_branch": "other",
"head_repository": {"full_name": "someone/repo-name"},
}
session.register("GET", "/repos/foo/bar/actions/runs/123")(json=json)
def test_find_pr_for_branch__no_pr(gh, session):
params = {
"head": "someone/repo-name:other",
"head": "someone:other",
"sort": "updated",
"direction": "desc",
}
session.register("GET", "/repos/foo/bar/pulls", params=params | {"state": "open"})(
json=[]
)
session.register("GET", "/repos/foo/bar/pulls", params=params | {"state": "all"})(
json=[]
)
session.register(
"GET",
"/repos/foo/bar/pulls",
params=params | {"state": "open"},
)(json=[])
session.register(
"GET",
"/repos/foo/bar/pulls",
params=params | {"state": "all"},
)(json=[])
with pytest.raises(github.CannotDeterminePR):
github.get_pr_number_from_workflow_run(
github=gh, repository="foo/bar", run_id=123
github.find_pr_for_branch(
github=gh, repository="foo/bar", owner="someone", branch="other"
)


Expand Down

0 comments on commit 1680c53

Please sign in to comment.