diff --git a/.github/assistant.py b/.github/assistant.py index 174ed2117a4..694a6cf7e13 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -20,7 +20,6 @@ import fire from packaging.version import parse -from pkg_resources import parse_requirements _REQUEST_TIMEOUT = 10 _PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) @@ -58,6 +57,9 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None: >>> AssistantCLI.set_min_torch_by_python("../requirements/base.txt") """ + # ToDo: `pkg_resources` is deprecates and shall be updated + from pkg_resources import parse_requirements + py_ver = f"{sys.version_info.major}.{sys.version_info.minor}" if py_ver not in LUT_PYTHON_TORCH: return @@ -114,9 +116,14 @@ def changed_domains( """Determine what domains were changed in particular PR.""" import github + # define some edge case return cases + _return_all = "unittests" if not as_list else ["torchmetrics"] + _return_empty = [] if as_list else "" + + # early return if no PR number if not pr: - return "unittests" - gh = github.Github() + return _return_all + gh = github.Github(login_or_token=auth_token) pr = gh.get_repo("Lightning-AI/torchmetrics").get_pull(pr) files = [f.filename for f in pr.get_files()] @@ -124,19 +131,27 @@ def changed_domains( files = [fn for fn in files if not fn.startswith("tests/integrations")] if not files: logging.debug("Only integrations was changed so not reason for deep testing...") - return "" + return _return_empty # filter only docs files - files_ = [fn for fn in files if fn.startswith("docs")] - if len(files) == len(files_): + files_docs = [fn for fn in files if fn.startswith("docs")] + if len(files) == len(files_docs): logging.debug("Only docs was changed so not reason for deep testing...") - return "" + return _return_empty + # files in requirements folder + files_req = [fn for fn in files if fn.startswith("requirements")] + req_domains = [fn.split("/")[1] for fn in files_req] + # cleaning up determining domains + req_domains = [req.replace(".txt", "").replace("_test", "") for req in req_domains if not req.endswith("_")] + # if you touch base, you need to run everything + if "base" in req_domains: + return _return_all # filter only package files and skip inits _is_in_test = lambda fn: fn.startswith("tests") _filter_pkg = lambda fn: _is_in_test(fn) or (fn.startswith("src/torchmetrics") and "__init__.py" not in fn) files_pkg = [fn for fn in files if _filter_pkg(fn)] if not files_pkg: - return "unittests" + return _return_all # parse domains def _crop_path(fname: str, paths: list[str]) -> str: @@ -151,15 +166,18 @@ def _crop_path(fname: str, paths: list[str]) -> str: tm_modules = [md for md in tm_modules if md not in general_sub_pkgs] if len(files_pkg) > len(tm_modules): logging.debug("Some more files was changed -> rather test everything...") - return "unittests" - # keep only unique - if as_list: - return list(tm_modules) - tm_modules = [f"unittests/{md}" for md in set(tm_modules)] - not_exists = [p for p in tm_modules if os.path.exists(p)] + return _return_all + + # compose the final list with requirements and touched modules + test_modules = set(tm_modules + list(req_domains)) + if as_list: # keep only unique + return list(test_modules) + + test_modules = [f"unittests/{md}" for md in set(test_modules)] + not_exists = [p for p in test_modules if os.path.exists(p)] if not_exists: raise ValueError(f"Missing following paths: {not_exists}") - return " ".join(tm_modules) + return " ".join(test_modules) if __name__ == "__main__": diff --git a/.github/workflows/_focus-diff.yml b/.github/workflows/_focus-diff.yml index 0e68eac2406..5c875fcbfab 100644 --- a/.github/workflows/_focus-diff.yml +++ b/.github/workflows/_focus-diff.yml @@ -27,7 +27,6 @@ jobs: set -e echo $PR_NUMBER pip install -q -U packaging fire pyGithub pyopenssl - # python .github/assistant.py changed-domains $PR_NUMBER echo "focus=$(python .github/assistant.py changed-domains $PR_NUMBER)" >> $GITHUB_OUTPUT - run: echo "${{ steps.diff-domains.outputs.focus }}"