Skip to content

Commit

Permalink
feat(traversal): add option to remove start node
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Jan 9, 2023
1 parent 642f2a6 commit 2cdb412
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions arguebuf/services/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

# https://eddmann.com/posts/depth-first-search-and-breadth-first-search-in-python/
def dfs(
start: _Node, connections: t.Callable[[Node], t.AbstractSet[_Node]]
start: _Node,
connections: t.Callable[[Node], t.AbstractSet[_Node]],
include_start: bool = True,
) -> t.List[_Node]:
# Need to use a dict since a set does not preserve order
visited: dict[_Node, None] = {}
Expand All @@ -19,11 +21,16 @@ def dfs(
visited[node] = None
stack.extend(connections(node) - visited.keys())

return list(visited.keys())
if include_start:
return list(visited.keys())

return [node for node in visited if node != start]


def bfs(
start: _Node, connections: t.Callable[[Node], t.AbstractSet[_Node]]
start: _Node,
connections: t.Callable[[Node], t.AbstractSet[_Node]],
include_start: bool = True,
) -> t.List[_Node]:
# Need to use a dict since a set does not preserve order
visited: dict[_Node, None] = {}
Expand All @@ -36,7 +43,10 @@ def bfs(
visited[node] = None
queue.extend(connections(node) - visited.keys())

return list(visited.keys())
if include_start:
return list(visited.keys())

return [node for node in visited if node != start]


def node_distance(
Expand Down

0 comments on commit 2cdb412

Please sign in to comment.