Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming committed Aug 1, 2024
1 parent 569084f commit ac8ed26
Show file tree
Hide file tree
Showing 15 changed files with 49 additions and 124 deletions.
7 changes: 2 additions & 5 deletions examples/reporter_demo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import namedtuple

import resolvelib
from packaging.specifiers import SpecifierSet
from packaging.version import Version

import resolvelib

index = """
first 1.0.0
second == 1.0.0
Expand Down Expand Up @@ -53,9 +52,7 @@ def read_spec(lines):
candidates[latest] = set()
else:
if latest is None:
raise RuntimeError(
"Spec has dependencies before first candidate"
)
raise RuntimeError("Spec has dependencies before first candidate")
name, specifier = splitstrip(line, 2)
specifier = SpecifierSet(specifier)
candidates[latest].add(Requirement(name, specifier))
Expand Down
20 changes: 5 additions & 15 deletions examples/visualization/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def _get_subgraph(self, name, *, must_exist_already=True):
if subgraph is None:
if must_exist_already:
existing = [s.name for s in self.graph.subgraphs_iter()]
raise RuntimeError(
f"Graph for {name} not found. Existing: {existing}"
)
raise RuntimeError(f"Graph for {name} not found. Existing: {existing}")
else:
subgraph = self.graph.add_subgraph(name=c_name, label=name)

Expand Down Expand Up @@ -151,9 +149,7 @@ def adding_requirement(self, req, parent):
# We're seeing the parent candidate (which is being "evaluated"), so
# color all "active" requirements pointing to the it.
# TODO: How does this interact with revisited candidates?
for parent_req in self._active_requirements[
canonicalize_name(parent.name)
]:
for parent_req in self._active_requirements[canonicalize_name(parent.name)]:
self._ensure_edge(parent_req, to=parent, color="#80CC80")

def backtracking(self, candidate, internal=False):
Expand All @@ -175,9 +171,7 @@ def backtracking(self, candidate, internal=False):

# Trim "active" requirements to remove anything not relevant now.
for requirement in self._dependencies[candidate]:
active = self._active_requirements[
canonicalize_name(requirement.name)
]
active = self._active_requirements[canonicalize_name(requirement.name)]
active[requirement] -= 1
if not active[requirement]:
del active[requirement]
Expand All @@ -194,12 +188,8 @@ def pinning(self, candidate):
node.attr.update(color="#80CC80")

# Requirement -> Candidate edges, from this candidate.
for req in self._active_requirements[
canonicalize_name(candidate.name)
]:
self._ensure_edge(
req, to=candidate, arrowhead="vee", color="#80CC80"
)
for req in self._active_requirements[canonicalize_name(candidate.name)]:
self._ensure_edge(req, to=candidate, arrowhead="vee", color="#80CC80")

# Candidate -> Requirement edges, from this candidate.
for edge in self.graph.out_edges_iter([node_name]):
Expand Down
4 changes: 1 addition & 3 deletions examples/visualization/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def process_arguments(function, args):
to_convert, _, args = args.partition(", ")
value = int(to_convert)
elif arg_type == "requirement":
match = re.match(
r"^<Requirement\('?([\w\-\._~]+)(.*?)'?\)>(.*)", args
)
match = re.match(r"^<Requirement\('?([\w\-\._~]+)(.*?)'?\)>(.*)", args)
assert match, repr(args)
name, spec, args = match.groups()
value = Requirement(name, spec)
Expand Down
4 changes: 1 addition & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
@nox.session
def lint(session):
session.install(".[lint, test]")

session.run("black", "--check", ".")
session.run("isort", ".")
session.run("ruff", "format", "--check", ".")
session.run("ruff", "check", ".")
session.run("mypy", "src", "tests")

Expand Down
15 changes: 3 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ Homepage = "https://github.com/sarugaku/resolvelib"

[project.optional-dependencies]
lint = [
"black==23.12.1",
"ruff",
"isort",
"mypy",
"types-requests",
]
Expand Down Expand Up @@ -57,16 +55,6 @@ version = {attr = "resolvelib.__version__"}
[tool.distutils.bdist_wheel]
universal = true


[tool.black]
line-length = 79
include = '^/(docs|examples|src|tasks|tests)/.+\.py$'

[tool.isort]
profile = "black"
line_length = 79
multi_line_output = 3

[tool.towncrier]
package = 'resolvelib'
package_dir = 'src'
Expand Down Expand Up @@ -108,6 +96,9 @@ exclude = [
"*.pyi"
]

[tool.ruff.lint.isort]
known-first-party = ["resolvelib"]

[tool.mypy]
warn_unused_configs = true

Expand Down
5 changes: 2 additions & 3 deletions src/resolvelib/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from typing import Any, Protocol

class Preference(Protocol):
def __lt__(self, __other: Any) -> bool:
...
def __lt__(self, __other: Any) -> bool: ...


class AbstractProvider(Generic[RT, CT, KT]):
Expand Down Expand Up @@ -88,7 +87,7 @@ def find_matches(
identifier: KT,
requirements: Mapping[KT, Iterator[RT]],
incompatibilities: Mapping[KT, Iterator[CT]],
) -> Matches:
) -> Matches[CT]:
"""Find all possible candidates that satisfy the given constraints.
:param identifier: An identifier as returned by ``identify()``. All
Expand Down
4 changes: 1 addition & 3 deletions src/resolvelib/reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def resolving_conflicts(
:param causes: The information on the collision that caused the backtracking.
"""

def rejecting_candidate(
self, criterion: Criterion[RT, CT], candidate: CT
) -> None:
def rejecting_candidate(self, criterion: Criterion[RT, CT], candidate: CT) -> None:
"""Called when rejecting a candidate during backtracking."""

def pinning(self, candidate: CT) -> None:
Expand Down
4 changes: 1 addition & 3 deletions src/resolvelib/resolvers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(
self.provider = provider
self.reporter = reporter

def resolve(
self, requirements: Iterable[RT], **kwargs: Any
) -> Result[RT, CT, KT]:
def resolve(self, requirements: Iterable[RT], **kwargs: Any) -> Result[RT, CT, KT]:
"""Take a collection of constraints, spit out the resolution result.
This returns a representation of the final resolution state, with one
Expand Down
24 changes: 7 additions & 17 deletions src/resolvelib/resolvers/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ def _is_current_pin_satisfying(
for r in criterion.iter_requirement()
)

def _get_updated_criteria(
self, candidate: CT
) -> dict[KT, Criterion[RT, CT]]:
def _get_updated_criteria(self, candidate: CT) -> dict[KT, Criterion[RT, CT]]:
criteria = self.state.criteria.copy()
for requirement in self._p.get_dependencies(candidate=candidate):
self._add_to_criteria(criteria, requirement, parent=candidate)
Expand Down Expand Up @@ -245,7 +243,7 @@ def _attempt_to_pin_criterion(self, name: KT) -> list[Criterion[RT, CT]]:

# Put newly-pinned candidate at the end. This is essential because
# backtracking looks at this mapping to get the last pin.
self.state.mapping.pop(name, None) # type: ignore[arg-type]
self.state.mapping.pop(name, None)
self.state.mapping[name] = candidate

return []
Expand Down Expand Up @@ -347,8 +345,7 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
# If the current dependencies and the incompatible dependencies
# are overlapping then we have found a cause of the incompatibility
current_dependencies = {
self._p.identify(d)
for d in self._p.get_dependencies(candidate)
self._p.identify(d) for d in self._p.get_dependencies(candidate)
}
if not current_dependencies.isdisjoint(incompatible_deps):
break
Expand All @@ -360,8 +357,7 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
break

incompatibilities_from_broken = [
(k, list(v.incompatibilities))
for k, v in broken_state.criteria.items()
(k, list(v.incompatibilities)) for k, v in broken_state.criteria.items()
]

# Also mark the newly known incompatibility.
Expand All @@ -384,13 +380,9 @@ def _extract_causes(
self, criteron: list[Criterion[RT, CT]]
) -> list[RequirementInformation[RT, CT]]:
"""Extract causes from list of criterion and deduplicate"""
return list(
{id(i): i for c in criteron for i in c.information}.values()
)
return list({id(i): i for c in criteron for i in c.information}.values())

def resolve(
self, requirements: Iterable[RT], max_rounds: int
) -> State[RT, CT, KT]:
def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT, KT]:
if self._states:
raise RuntimeError("already resolved")

Expand Down Expand Up @@ -430,9 +422,7 @@ def resolve(
return self.state

# keep track of satisfied names to calculate diff after pinning
satisfied_names = set(self.state.criteria.keys()) - set(
unsatisfied_names
)
satisfied_names = set(self.state.criteria.keys()) - set(unsatisfied_names)

# Choose the most preferred unpinned criterion to try.
name = min(unsatisfied_names, key=self._get_preference)
Expand Down
27 changes: 10 additions & 17 deletions src/resolvelib/structs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import itertools
from abc import ABCMeta
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Callable,
Collection,
Container,
Generic,
Iterable,
Iterator,
Expand Down Expand Up @@ -150,7 +148,7 @@ def __len__(self) -> int:
return len(self._mapping) + more


class _FactoryIterableView(Generic[RT]):
class _FactoryIterableView(Iterable[RT]):
"""Wrap an iterator factory returned by `find_matches()`.
Calling `iter()` on this class would invoke the underlying iterator
Expand All @@ -174,14 +172,12 @@ def __bool__(self) -> bool:
return True

def __iter__(self) -> Iterator[RT]:
iterable = (
self._factory() if self._iterable is None else self._iterable
)
iterable = self._factory() if self._iterable is None else self._iterable
self._iterable, current = itertools.tee(iterable)
return current


class _SequenceIterableView(Generic[RT]):
class _SequenceIterableView(Iterable[RT]):
"""Wrap an iterable returned by find_matches().
This is essentially just a proxy to the underlying sequence that provides
Expand All @@ -201,19 +197,16 @@ def __iter__(self) -> Iterator[RT]:
return iter(self._sequence)


class IterableView(Container[CT], Iterator[CT], metaclass=ABCMeta):
pass


def build_iter_view(
matches: Iterable[CT] | Callable[[], Iterable[CT]]
) -> IterableView[CT]:
def build_iter_view(matches: Matches[CT]) -> Iterable[CT]:
"""Build an iterable view from the value returned by `find_matches()`."""
if callable(matches):
return _FactoryIterableView(matches) # type: ignore[return-value]
return _FactoryIterableView(matches)
if not isinstance(matches, Sequence):
matches = list(matches)
return _SequenceIterableView(matches) # type: ignore[return-value]
return _SequenceIterableView(matches)


IterableView = Iterable


class Criterion(Generic[RT, CT]):
Expand All @@ -238,7 +231,7 @@ class Criterion(Generic[RT, CT]):

def __init__(
self,
candidates: IterableView[CT],
candidates: Iterable[CT],
information: Collection[RequirementInformation[RT, CT]],
incompatibilities: Collection[CT],
) -> None:
Expand Down
9 changes: 3 additions & 6 deletions tests/functional/cocoapods/test_resolvers_cocoapods.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ def __init__(self, filename):
for key, spec in case_data["requested"].items()
]
self.pinned_versions = {
entry["name"]: Version(entry["version"])
for entry in case_data["base"]
entry["name"]: Version(entry["version"]) for entry in case_data["base"]
}
self.expected_resolution = dict(_iter_resolved(case_data["resolved"]))
self.expected_conflicts = set(case_data["conflicts"])
Expand All @@ -193,8 +192,7 @@ def _iter_matches(self, name, requirements, incompatibilities):
for entry in data:
version = Version(entry["version"])
if any(
not _version_in_specset(version, r.spec)
for r in requirements[name]
not _version_in_specset(version, r.spec) for r in requirements[name]
):
continue
if version in bad_versions:
Expand Down Expand Up @@ -253,8 +251,7 @@ def _format_conflicts(exc):

def _format_resolution(result):
return {
identifier: candidate.ver
for identifier, candidate in result.mapping.items()
identifier: candidate.ver for identifier, candidate in result.mapping.items()
}


Expand Down
4 changes: 1 addition & 3 deletions tests/functional/python/py2index.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ def iter_package_entries(self, name: str) -> Iterator[PackageEntry]:
dependencies: list[str] = data.get_all("Requires-Dist", [])
yield PackageEntry(version, dependencies)

def process_package_entry(
self, name: str, entry: PackageEntry
) -> set[str] | None:
def process_package_entry(self, name: str, entry: PackageEntry) -> set[str] | None:
more = set()
for dep in entry.dependencies:
try:
Expand Down
15 changes: 4 additions & 11 deletions tests/functional/python/test_resolvers_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@ def __init__(self, filename):
case_data = json.load(f)

index_name = os.path.normpath(
os.path.join(
filename, "..", "..", "index", case_data["index"] + ".json"
),
os.path.join(filename, "..", "..", "index", case_data["index"] + ".json"),
)
with open(index_name) as f:
self.index = json.load(f)

self.root_requirements = [
packaging.requirements.Requirement(r)
for r in case_data["requested"]
packaging.requirements.Requirement(r) for r in case_data["requested"]
]

if "resolved" in case_data:
Expand Down Expand Up @@ -182,17 +179,13 @@ def test_resolver(provider, reporter):
if provider.expected_unvisited:
visited_versions = defaultdict(set)
for visited_candidate in reporter.visited:
visited_versions[visited_candidate.name].add(
str(visited_candidate.version)
)
visited_versions[visited_candidate.name].add(str(visited_candidate.version))

for name, versions in provider.expected_unvisited.items():
if name not in visited_versions:
continue

unexpected_versions = set(versions).intersection(
visited_versions[name]
)
unexpected_versions = set(versions).intersection(visited_versions[name])
assert (
not unexpected_versions
), f"Unexpcted versions visited {name}: {', '.join(unexpected_versions)}"
Loading

0 comments on commit ac8ed26

Please sign in to comment.