Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run PR activities on push on non-default branch #248

Merged
merged 9 commits into from
Sep 13, 2023
28 changes: 28 additions & 0 deletions coverage_comment/activity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
This module is responsible for identifying what the action should be doing
based on the github event type and repository.

The code in main should be as straightforward as possible, we're offloading
the branching logic to this module.
"""


class ActivityNotFound(Exception):
    """Raised when no activity is defined for the triggering GitHub event."""


def find_activity(
    event_name: str,
    is_default_branch: bool,
) -> str:
    """Find the activity to perform for a GitHub event.

    Args:
        event_name: Name of the GitHub event that triggered the run.
        is_default_branch: Whether the run operates on the repository's
            default branch.

    Returns:
        ``"post_comment"`` for ``workflow_run`` events,
        ``"save_coverage_data_files"`` for a ``push`` on the default branch,
        and ``"process_pr"`` for ``pull_request`` events and for a ``push``
        on any other branch.

    Raises:
        ActivityNotFound: If ``event_name`` is none of ``workflow_run``,
            ``pull_request`` or ``push``.
    """
    if event_name == "workflow_run":
        return "post_comment"

    # A push on the default branch stores coverage data; a push on any other
    # branch falls through and is handled like a pull request.
    if event_name == "push" and is_default_branch:
        return "save_coverage_data_files"

    if event_name not in {"pull_request", "push"}:
        raise ActivityNotFound

    return "process_pr"
65 changes: 35 additions & 30 deletions coverage_comment/github.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import functools
import io
import json
import pathlib
Expand Down Expand Up @@ -73,47 +72,53 @@ def download_artifact(
raise NoArtifact(f"File named {filename} not found in artifact {artifact_name}")


def get_branch_from_workflow_run(
    github: github_client.GitHub, repository: str, run_id: int
) -> tuple[str, str]:
    """Return the owner and head branch of a workflow run.

    Args:
        github: Client used to query the GitHub API.
        repository: Repository (``owner/name``) in which the run is looked up.
        run_id: Identifier of the workflow run to inspect.

    Returns:
        An ``(owner, branch)`` tuple: the login of the owner of the run's
        head repository, and the name of the run's head branch.
    """
    repo_path = github.repos(repository)
    run = repo_path.actions.runs(run_id).get()
    branch = run.head_branch
    # The owner is read from the run's head repository rather than from
    # `repository`.
    owner = run.head_repository.owner.login
    return owner, branch


def find_pr_for_branch(
    github: github_client.GitHub, repository: str, owner: str, branch: str
) -> int:
    """Return the number of the most recently updated PR for a head branch.

    Open PRs are searched first; if there is none, PRs in any state are
    considered.

    Args:
        github: Client used to query the GitHub API.
        repository: Repository (``owner/name``) whose PRs are listed.
        owner: Owner of the head branch.
        branch: Name of the head branch.

    Returns:
        The number of the first PR returned when sorting by last update,
        most recent first.

    Raises:
        CannotDeterminePR: If no PR has ``owner:branch`` as its head.
    """
    # The full branch is in the form of "owner:branch" as specified in
    # https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#list-pull-requests
    # but it seems to also work with "owner/repo:branch"
    full_branch = f"{owner}:{branch}"

    common_kwargs = {"head": full_branch, "sort": "updated", "direction": "desc"}

    def first_pr_number(state: str) -> int | None:
        # Only the first result matters: it is the most recently updated PR.
        prs = github.repos(repository).pulls.get(state=state, **common_kwargs)
        return next((pr.number for pr in prs), None)

    pr_number = first_pr_number("open")
    if pr_number is not None:
        return pr_number

    log.info(f"No open PR found for branch {branch}, defaulting to all PRs")

    pr_number = first_pr_number("all")
    if pr_number is not None:
        return pr_number

    raise CannotDeterminePR(f"No open PR found for branch {branch}")


def get_my_login(github: github_client.GitHub):
def get_my_login(github: github_client.GitHub) -> str:
try:
response = github.user.get()
except github_client.Forbidden:
Expand Down
158 changes: 94 additions & 64 deletions coverage_comment/main.py
kieferro marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import httpx

from coverage_comment import activity as activity_module
from coverage_comment import annotations, comment_file, communication
from coverage_comment import coverage as coverage_module
from coverage_comment import (
Expand Down Expand Up @@ -60,69 +61,79 @@ def action(
git: subprocess.Git,
) -> int:
log.debug(f"Operating on {config.GITHUB_REF}")

gh = github_client.GitHub(session=github_session)
event_name = config.GITHUB_EVENT_NAME
if event_name not in {"pull_request", "push", "workflow_run"}:
repo_info = github.get_repository_info(
github=gh, repository=config.GITHUB_REPOSITORY
)
try:
activity = activity_module.find_activity(
event_name=event_name,
is_default_branch=repo_info.is_default_branch(ref=config.GITHUB_REF),
)
except activity_module.ActivityNotFound:
log.error(
'This action has only been designed to work for "pull_request", "branch" '
'This action has only been designed to work for "pull_request", "push" '
f'or "workflow_run" actions, not "{event_name}". Because there are security '
"implications. If you have a different usecase, please open an issue, "
"we'll be glad to add compatibility."
)
return 1

if event_name in {"pull_request", "push"}:
raw_coverage, coverage = coverage_module.get_coverage_info(
merge=config.MERGE_COVERAGE_FILES, coverage_path=config.COVERAGE_PATH
if activity == "save_coverage_data_files":
return save_coverage_data_files(
config=config,
git=git,
http_session=http_session,
repo_info=repo_info,
)

elif activity == "process_pr":
return process_pr(
config=config,
gh=gh,
repo_info=repo_info,
)
if event_name == "pull_request":
diff_coverage = coverage_module.get_diff_coverage_info(
base_ref=config.GITHUB_BASE_REF, coverage_path=config.COVERAGE_PATH
)
if config.ANNOTATE_MISSING_LINES:
annotations.create_pr_annotations(
annotation_type=config.ANNOTATION_TYPE, diff_coverage=diff_coverage
)
return generate_comment(
config=config,
coverage=coverage,
diff_coverage=diff_coverage,
github_session=github_session,
)
else:
# event_name == "push"
return save_coverage_data_files(
config=config,
coverage=coverage,
raw_coverage_data=raw_coverage,
github_session=github_session,
git=git,
http_session=http_session,
)

else:
# event_name == "workflow_run"
# activity == "post_comment":
return post_comment(
config=config,
github_session=github_session,
gh=gh,
)


def generate_comment(
def process_pr(
config: settings.Config,
coverage: coverage_module.Coverage,
diff_coverage: coverage_module.DiffCoverage,
github_session: httpx.Client,
gh: github_client.GitHub,
repo_info: github.RepositoryInfo,
) -> int:
log.info("Generating comment for PR")

gh = github_client.GitHub(session=github_session)

previous_coverage_data_file = storage.get_datafile_contents(
github=gh,
repository=config.GITHUB_REPOSITORY,
branch=config.COVERAGE_DATA_BRANCH,
_, coverage = coverage_module.get_coverage_info(
merge=config.MERGE_COVERAGE_FILES,
coverage_path=config.COVERAGE_PATH,
)
base_ref = config.GITHUB_BASE_REF or repo_info.default_branch
diff_coverage = coverage_module.get_diff_coverage_info(
base_ref=base_ref, coverage_path=config.COVERAGE_PATH
)

# It only really makes sense to display a comparison with the previous
# coverage if the PR target is the branch in which the coverage data is
# stored, e.g. the default branch.
# In the case we're running on a branch without a PR yet, we can't know
# if it's going to target the default branch, so we display it.
previous_coverage_data_file = None
pr_targets_default_branch = base_ref == repo_info.default_branch

if pr_targets_default_branch:
previous_coverage_data_file = storage.get_datafile_contents(
github=gh,
repository=config.GITHUB_REPOSITORY,
branch=config.COVERAGE_DATA_BRANCH,
)

previous_coverage = None
if previous_coverage_data_file:
previous_coverage = files.parse_datafile(contents=previous_coverage_data_file)
Expand All @@ -134,6 +145,7 @@ def generate_comment(
previous_coverage_rate=previous_coverage,
base_template=template.read_template_file("comment.md.j2"),
custom_template=config.COMMENT_TEMPLATE,
pr_targets_default_branch=pr_targets_default_branch,
)
except template.MissingMarker:
log.error(
Expand All @@ -152,21 +164,39 @@ def generate_comment(
)
return 1

assert config.GITHUB_PR_NUMBER

github.add_job_summary(
content=comment, github_step_summary=config.GITHUB_STEP_SUMMARY
)
pr_number: int | None = config.GITHUB_PR_NUMBER
if pr_number is None:
# If we don't have a PR number, we're launched from a push event,
# so we need to find the PR number from the branch name
assert config.GITHUB_BRANCH_NAME
try:
pr_number = github.find_pr_for_branch(
github=gh,
# A push event cannot be initiated from a forked repository
repository=config.GITHUB_REPOSITORY,
owner=config.GITHUB_REPOSITORY.split("/")[0],
branch=config.GITHUB_BRANCH_NAME,
)
except github.CannotDeterminePR:
pr_number = None
ewjoachim marked this conversation as resolved.
Show resolved Hide resolved

if pr_number is not None and config.ANNOTATE_MISSING_LINES:
annotations.create_pr_annotations(
annotation_type=config.ANNOTATION_TYPE, diff_coverage=diff_coverage
)

try:
if config.FORCE_WORKFLOW_RUN:
if config.FORCE_WORKFLOW_RUN or not pr_number:
raise github.CannotPostComment

github.post_comment(
github=gh,
me=github.get_my_login(github=gh),
repository=config.GITHUB_REPOSITORY,
pr_number=config.GITHUB_PR_NUMBER,
pr_number=pr_number,
contents=comment,
marker=template.MARKER,
)
Expand All @@ -193,21 +223,29 @@ def generate_comment(
return 0


def post_comment(config: settings.Config, github_session: httpx.Client) -> int:
def post_comment(
config: settings.Config,
gh: github_client.GitHub,
) -> int:
log.info("Posting comment to PR")

if not config.GITHUB_PR_RUN_ID:
log.error("Missing input GITHUB_PR_RUN_ID. Please consult the documentation.")
return 1

gh = github_client.GitHub(session=github_session)
me = github.get_my_login(github=gh)
log.info(f"Search for PR associated with run id {config.GITHUB_PR_RUN_ID}")
owner, branch = github.get_branch_from_workflow_run(
github=gh,
run_id=config.GITHUB_PR_RUN_ID,
repository=config.GITHUB_REPOSITORY,
)
try:
pr_number = github.get_pr_number_from_workflow_run(
pr_number = github.find_pr_for_branch(
github=gh,
run_id=config.GITHUB_PR_RUN_ID,
repository=config.GITHUB_REPOSITORY,
owner=owner,
branch=branch,
)
except github.CannotDeterminePR:
log.error(
Expand Down Expand Up @@ -250,25 +288,17 @@ def post_comment(config: settings.Config, github_session: httpx.Client) -> int:

def save_coverage_data_files(
config: settings.Config,
coverage: coverage_module.Coverage,
raw_coverage_data: dict,
github_session: httpx.Client,
git: subprocess.Git,
http_session: httpx.Client,
repo_info: github.RepositoryInfo,
) -> int:
gh = github_client.GitHub(session=github_session)
repo_info = github.get_repository_info(
github=gh,
repository=config.GITHUB_REPOSITORY,
)
is_default_branch = repo_info.is_default_branch(ref=config.GITHUB_REF)
log.debug(f"On default branch: {is_default_branch}")
log.info("Computing coverage files & badge")

if not is_default_branch:
log.info("Skipping badge save as we're not on the default branch")
return 0
raw_coverage_data, coverage = coverage_module.get_coverage_info(
merge=config.MERGE_COVERAGE_FILES,
coverage_path=config.COVERAGE_PATH,
)

log.info("Computing coverage files & badge")
operations: list[files.Operation] = files.compute_files(
line_rate=coverage.info.percent_covered,
raw_coverage_data=raw_coverage_data,
Expand Down
13 changes: 13 additions & 0 deletions coverage_comment/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@ def str_to_bool(value: str) -> bool:
class Config:
"""This object defines the environment variables"""

# A branch name, not a fully-formed ref. For example, `main`.
GITHUB_BASE_REF: str
GITHUB_TOKEN: str = dataclasses.field(repr=False)
GITHUB_REPOSITORY: str
# > The ref given is fully-formed, meaning that for branches the format is
# > `refs/heads/<branch_name>`, for pull requests it is
# > `refs/pull/<pr_number>/merge`, and for tags it is `refs/tags/<tag_name>`.
# > For example, `refs/heads/feature-branch-1`.
# (from https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables )
GITHUB_REF: str
GITHUB_EVENT_NAME: str
GITHUB_PR_RUN_ID: int | None
Expand Down Expand Up @@ -119,6 +125,13 @@ def GITHUB_PR_NUMBER(self) -> int | None:
return int(self.GITHUB_REF.split("/")[2])
return None

@property
def GITHUB_BRANCH_NAME(self) -> str | None:
# "refs/head/my_branch_name"
if "heads" in self.GITHUB_REF:
return self.GITHUB_REF.split("/", 2)[2]
return None

# We need to type environ as a MutableMapping because that's what
# os.environ is, and just saying `dict[str, str]` is not enough to make
# mypy happy
Expand Down
Loading
Loading