diff --git a/src/poetry/puzzle/solver.py b/src/poetry/puzzle/solver.py index 32bd327a7be..1fe47fe4350 100644 --- a/src/poetry/puzzle/solver.py +++ b/src/poetry/puzzle/solver.py @@ -1,6 +1,5 @@ from __future__ import annotations -import enum import time from collections import defaultdict @@ -140,11 +139,9 @@ def _solve(self, use_latest: list[str] = None) -> tuple[list[Package], list[int] except SolveFailure as e: raise SolverProblemError(e) - # NOTE passing explicit empty array for seen to reset between invocations during - # update + install cycle results = dict( depth_first_search( - PackageNode(self._package, packages, seen=[]), aggregate_package_nodes + PackageNode(self._package, packages), aggregate_package_nodes ) ) @@ -194,27 +191,19 @@ def __str__(self) -> str: return str(self.id) -class VisitedState(enum.Enum): - Unvisited = 0 - PartiallyVisited = 1 - Visited = 2 - - def depth_first_search( source: PackageNode, aggregator: Callable ) -> list[tuple[Package, int]]: back_edges: dict[DFSNodeID, list[PackageNode]] = defaultdict(list) - visited: dict[DFSNodeID, VisitedState] = {} + visited: set[DFSNodeID] = set() topo_sorted_nodes: list[PackageNode] = [] dfs_visit(source, back_edges, visited, topo_sorted_nodes) # Combine the nodes by name combined_nodes = defaultdict(list) - name_children = defaultdict(list) for node in topo_sorted_nodes: node.visit(back_edges[node.id]) - name_children[node.name].extend(node.reachable()) combined_nodes[node.name].append(node) combined_topo_sorted_nodes = [ @@ -223,34 +212,23 @@ def depth_first_search( if node.name in combined_nodes ] - return [ - aggregator(nodes, name_children[nodes[0].name]) - for nodes in combined_topo_sorted_nodes - ] + return [aggregator(nodes) for nodes in combined_topo_sorted_nodes] def dfs_visit( node: PackageNode, back_edges: dict[DFSNodeID, list[PackageNode]], - visited: dict[DFSNodeID, VisitedState], + visited: set[DFSNodeID], sorted_nodes: list[PackageNode], -) -> bool: - if visited.get(node.id, VisitedState.Unvisited) == VisitedState.Visited: - return True - if visited.get(node.id, VisitedState.Unvisited) == VisitedState.PartiallyVisited: - # We have a circular dependency. - # Since the dependencies are resolved we can - # simply skip it because we already have it - return True - - visited[node.id] = VisitedState.PartiallyVisited +) -> None: + if node.id in visited: + return + visited.add(node.id) + for neighbor in node.reachable(): back_edges[neighbor.id].append(node) - if not dfs_visit(neighbor, back_edges, visited, sorted_nodes): - return False - visited[node.id] = VisitedState.Visited + dfs_visit(neighbor, back_edges, visited, sorted_nodes) sorted_nodes.insert(0, node) - return True class PackageNode(DFSNode): @@ -258,7 +236,6 @@ def __init__( self, package: Package, packages: list[Package], - seen: list[Package], previous: PackageNode | None = None, previous_dep: None | ( @@ -279,7 +256,6 @@ def __init__( ) -> None: self.package = package self.packages = packages - self.seen = seen self.previous = previous self.previous_dep = previous_dep @@ -306,12 +282,6 @@ def __init__( def reachable(self) -> list[PackageNode]: children: list[PackageNode] = [] - # skip already traversed packages - if self.package in self.seen: - return [] - else: - self.seen.append(self.package) - if ( self.dep and self.previous_dep @@ -348,7 +318,6 @@ def reachable(self) -> list[PackageNode]: PackageNode( pkg, self.packages, - self.seen, self, dependency, self.dep or dependency, @@ -369,19 +338,15 @@ def visit(self, parents: list[PackageNode]) -> None: ) -def aggregate_package_nodes( - nodes: list[PackageNode], children: list[PackageNode] -) -> tuple[Package, int]: +def aggregate_package_nodes(nodes: list[PackageNode]) -> tuple[Package, int]: package = nodes[0].package depth = max(node.depth for node in nodes) groups: list[str] = [] for node in nodes: groups.extend(node.groups) - category = ( - "main" if any("default" in node.groups for node in children + nodes) else "dev" - ) - optional = all(node.optional for node in children + nodes) + category = "main" if any("default" in node.groups for node in nodes) else "dev" + optional = all(node.optional for node in nodes) for node in nodes: node.depth = depth node.category = category