Skip to content

Commit

Permalink
Use NamedTuple for packages and columns
Browse files Browse the repository at this point in the history
  • Loading branch information
realshouzy committed Feb 2, 2024
1 parent a698284 commit e20d627
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 38 deletions.
91 changes: 53 additions & 38 deletions pip_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from typing import TYPE_CHECKING, Final, NamedTuple, TextIO

if sys.version_info >= (3, 12): # pragma: >=3.12 cover
from typing import override
from typing import Self, override
else: # pragma: <3.12 cover
from typing_extensions import override
from typing_extensions import Self, override


if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -96,20 +97,6 @@
_PIP_CMD: Final[tuple[str, ...]] = (sys.executable, "-m", "pip")


class _Column(NamedTuple):
title: str
field: str


# nicer headings for the columns in the oudated package table
_COLUMNS: Final[tuple[_Column, ...]] = (
_Column("Package", "name"),
_Column("Version", "version"),
_Column("Latest", "latest_version"),
_Column("Type", "latest_filetype"),
)


def _parse_args() -> tuple[argparse.Namespace, list[str]]:
parser: argparse.ArgumentParser = argparse.ArgumentParser(
description=__doc__,
Expand Down Expand Up @@ -198,7 +185,7 @@ def _filter_forwards(args: list[str], exclude: AbstractSet[str]) -> list[str]:
return result


class StdOutFilter(logging.Filter):
class _StdOutFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno in {logging.DEBUG, logging.INFO}
Expand All @@ -212,7 +199,7 @@ def _setup_logging(*, verbose: bool) -> logging.Logger:
logger: logging.Logger = logging.getLogger(__title__)

stdout_handler: logging.StreamHandler[TextIO] = logging.StreamHandler(sys.stdout)
stdout_handler.addFilter(StdOutFilter())
stdout_handler.addFilter(_StdOutFilter())
stdout_handler.setFormatter(logging.Formatter(format_))
stdout_handler.setLevel(logging.DEBUG)

Expand All @@ -226,7 +213,7 @@ def _setup_logging(*, verbose: bool) -> logging.Logger:
return logger


class InteractiveAsker:
class _InteractiveAsker:
def __init__(self) -> None:
self.cached_answer: str | None = None
self.last_answer: str | None = None
Expand All @@ -252,11 +239,27 @@ def ask(self, prompt: str) -> str:
return answer


_ask_to_install: partial[str] = partial(InteractiveAsker().ask, prompt="Upgrade now?")
_ask_to_install: partial[str] = partial(_InteractiveAsker().ask, prompt="Upgrade now?")


class _Package(NamedTuple):
name: str
version: str
latest_version: str
latest_filetype: str

@classmethod
def from_dct(cls, dct: dict[str, str]) -> Self:
return cls(
dct.get("name", "Unknown"),
dct.get("version", "Unknown"),
dct.get("latest_version", "Unknown"),
dct.get("latest_filetype", "Unknown"),
)


def update_packages(
packages: list[dict[str, str]],
packages: list[_Package],
forwarded: list[str],
*,
continue_on_fail: bool,
Expand All @@ -267,15 +270,15 @@ def update_packages(
if freeze_outdated_packages:
with open("requirements.txt", "w", encoding="utf-8") as f:
for pkg in packages:
f.write(f"{pkg['name']}=={pkg['version']}\n")
f.write(f"{pkg.name}=={pkg.version}\n")

if not continue_on_fail:
upgrade_cmd.extend(pkg["name"] for pkg in packages)
upgrade_cmd.extend(pkg.name for pkg in packages)
subprocess.call(upgrade_cmd, stdout=sys.stdout, stderr=sys.stderr) # nosec
else:
for pkg in packages:
subprocess.call(
[*upgrade_cmd, pkg["name"]],
[*upgrade_cmd, pkg.name],
stdout=sys.stdout,
stderr=sys.stderr,
) # nosec
Expand All @@ -284,7 +287,7 @@ def update_packages(
def _get_outdated_packages(
forwarded: list[str],
exclude: AbstractSet[str],
) -> list[dict[str, str]]:
) -> list[_Package]:
command: list[str] = [
*_PIP_CMD,
"list",
Expand All @@ -294,21 +297,33 @@ def _get_outdated_packages(
*forwarded,
]
output: str = subprocess.check_output(command).decode("utf-8") # nosec
packages: list[dict[str, str]] = json.loads(output)
return (
[pkg for pkg in packages if pkg["name"] not in exclude] if exclude else packages
)
packages: list[_Package] = [_Package.from_dct(pkg) for pkg in json.loads(output)]
return [pkg for pkg in packages if pkg.name not in exclude] if exclude else packages


# Next two functions describe how to collect data for the table.
# Note how they are not concerned with columns widths.


def _extract_column(data: list[dict[str, str]], field: str, title: str) -> list[str]:
return [title, *[item[field] for item in data]]
class _Column(NamedTuple):
title: str
field: str


# nicer headings for the columns in the oudated package table
_COLUMNS: Final[tuple[_Column, ...]] = (
_Column("Package", "name"),
_Column("Version", "version"),
_Column("Latest", "latest_version"),
_Column("Type", "latest_filetype"),
)


def _extract_column(data: list[_Package], field: str, title: str) -> list[str]:
return [title, *[getattr(item, field) for item in data]]


def _extract_table(outdated: list[dict[str, str]]) -> list[list[str]]:
def _extract_table(outdated: list[_Package]) -> list[list[str]]:
return [_extract_column(outdated, field, title) for title, field in _COLUMNS]


Expand Down Expand Up @@ -340,7 +355,7 @@ def main() -> int: # noqa: C901
logger.error("--raw and --interactive cannot be used together")
return 1

outdated: list[dict[str, str]] = _get_outdated_packages(
outdated: list[_Package] = _get_outdated_packages(
list_args,
set(args.exclude),
)
Expand All @@ -365,16 +380,16 @@ def main() -> int: # noqa: C901

if args.raw:
for pkg in outdated:
logger.info("%s==%s", pkg["name"], pkg["latest_version"])
logger.info("%s==%s", pkg.name, pkg.latest_version)
return 0

selected: list[dict[str, str]] = []
selected: list[_Package] = []
for pkg in outdated:
logger.info(
"%s==%s is available (you have %s)",
pkg["name"],
pkg["latest_version"],
pkg["version"],
pkg.name,
pkg.latest_version,
pkg.version,
)
if args.interactive:
answer: str = _ask_to_install()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ ignore = [
"B026",
"S603",
"ANN101",
"ANN102",
"PTH123",
"PLR2004",
"ERA001",
Expand Down

0 comments on commit e20d627

Please sign in to comment.