Skip to content

Commit

Permalink
ci: fix determine PR's diff (#2894)
Browse files Browse the repository at this point in the history
* fix determine PR's diff
* linter

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Jan 6, 2025
1 parent ec3fbf2 commit d9697b6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
48 changes: 33 additions & 15 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import fire
from packaging.version import parse
from pkg_resources import parse_requirements

_REQUEST_TIMEOUT = 10
_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))
Expand Down Expand Up @@ -58,6 +57,9 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None:
>>> AssistantCLI.set_min_torch_by_python("../requirements/base.txt")
"""
# ToDo: `pkg_resources` is deprecates and shall be updated
from pkg_resources import parse_requirements

py_ver = f"{sys.version_info.major}.{sys.version_info.minor}"
if py_ver not in LUT_PYTHON_TORCH:
return
Expand Down Expand Up @@ -114,29 +116,42 @@ def changed_domains(
"""Determine what domains were changed in particular PR."""
import github

# define some edge case return cases
_return_all = "unittests" if not as_list else ["torchmetrics"]
_return_empty = [] if as_list else ""

# early return if no PR number
if not pr:
return "unittests"
gh = github.Github()
return _return_all
gh = github.Github(login_or_token=auth_token)
pr = gh.get_repo("Lightning-AI/torchmetrics").get_pull(pr)
files = [f.filename for f in pr.get_files()]

# filter out all integrations as they run in separate suit
files = [fn for fn in files if not fn.startswith("tests/integrations")]
if not files:
logging.debug("Only integrations was changed so not reason for deep testing...")
return ""
return _return_empty
# filter only docs files
files_ = [fn for fn in files if fn.startswith("docs")]
if len(files) == len(files_):
files_docs = [fn for fn in files if fn.startswith("docs")]
if len(files) == len(files_docs):
logging.debug("Only docs was changed so not reason for deep testing...")
return ""
return _return_empty
# files in requirements folder
files_req = [fn for fn in files if fn.startswith("requirements")]
req_domains = [fn.split("/")[1] for fn in files_req]
# cleaning up determining domains
req_domains = [req.replace(".txt", "").replace("_test", "") for req in req_domains if not req.endswith("_")]
# if you touch base, you need to run everything
if "base" in req_domains:
return _return_all

# filter only package files and skip inits
_is_in_test = lambda fn: fn.startswith("tests")
_filter_pkg = lambda fn: _is_in_test(fn) or (fn.startswith("src/torchmetrics") and "__init__.py" not in fn)
files_pkg = [fn for fn in files if _filter_pkg(fn)]
if not files_pkg:
return "unittests"
return _return_all

# parse domains
def _crop_path(fname: str, paths: list[str]) -> str:
Expand All @@ -151,15 +166,18 @@ def _crop_path(fname: str, paths: list[str]) -> str:
tm_modules = [md for md in tm_modules if md not in general_sub_pkgs]
if len(files_pkg) > len(tm_modules):
logging.debug("Some more files was changed -> rather test everything...")
return "unittests"
# keep only unique
if as_list:
return list(tm_modules)
tm_modules = [f"unittests/{md}" for md in set(tm_modules)]
not_exists = [p for p in tm_modules if os.path.exists(p)]
return _return_all

# compose the final list with requirements and touched modules
test_modules = set(tm_modules + list(req_domains))
if as_list: # keep only unique
return list(test_modules)

test_modules = [f"unittests/{md}" for md in set(test_modules)]
not_exists = [p for p in test_modules if os.path.exists(p)]
if not_exists:
raise ValueError(f"Missing following paths: {not_exists}")
return " ".join(tm_modules)
return " ".join(test_modules)


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/_focus-diff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
set -e
echo $PR_NUMBER
pip install -q -U packaging fire pyGithub pyopenssl
# python .github/assistant.py changed-domains $PR_NUMBER
echo "focus=$(python .github/assistant.py changed-domains $PR_NUMBER)" >> $GITHUB_OUTPUT
- run: echo "${{ steps.diff-domains.outputs.focus }}"

0 comments on commit d9697b6

Please sign in to comment.