Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial argument to topological sorters #1128

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions releasenotes/notes/initial-topo-cd502f3140500f93.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
:func:`.lexicographical_topological_sort` and :class:`.TopologicalSorter` now both accept an
``initial`` keyword argument, which can be used to limit the returned topological orderings to
be only over the nodes that are dominated by the ``initial`` set. This can provide performance
improvements by removing the need for a search over all graph nodes to determine the initial set
of nodes with zero in degree; this is particularly relevant to :class:`.TopologicalSorter`,
where the user may terminate the search after only examining part of the order.
3 changes: 3 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from .visit import BFSVisitor, DFSVisitor, DijkstraVisitor
from typing import (
TypeVar,
Callable,
Iterable,
Iterator,
final,
Sequence,
Expand Down Expand Up @@ -272,6 +273,7 @@ def lexicographical_topological_sort(
key: Callable[[_S], str],
*,
reverse: bool = ...,
initial: Iterable[int] | None = ...,
) -> list[_S]: ...
def transitive_reduction(graph: PyDiGraph, /) -> tuple[PyDiGraph, dict[int, int]]: ...
def layers(
Expand All @@ -289,6 +291,7 @@ class TopologicalSorter:
check_cycle: bool,
*,
reverse: bool = ...,
initial: Iterable[int] | None = ...,
) -> None: ...
def is_active(self) -> bool: ...
def get_ready(self) -> list[int]: ...
Expand Down
50 changes: 43 additions & 7 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,24 @@ pub fn layers(
/// topological order that would have been found if all the edges in the
/// graph were reversed. This does not affect the comparisons from the
/// ``key``.
/// :param Iterable[int] initial: If given, the initial node indices to start the topological
/// ordering from. If not given, the topological ordering will certainly contain every node in
/// the graph. If given, only the ``initial`` nodes and nodes that are dominated by the
/// ``initial`` set will be in the ordering. Notably, any node that has a natural in degree of
/// zero will not be in the output ordering if ``initial`` is given and the zero-in-degree node
/// is not in it. It is a :exc:`ValueError` to give an ``initial`` set where the nodes have even
/// a partial topological order between themselves.
///
/// :returns: A list of node's data lexicographically topologically sorted.
/// :rtype: list
#[pyfunction]
#[pyo3(signature = (dag, /, key, *, reverse=false))]
#[pyo3(signature = (dag, /, key, *, reverse=false, initial=None))]
pub fn lexicographical_topological_sort(
py: Python,
dag: &digraph::PyDiGraph,
key: PyObject,
reverse: bool,
initial: Option<&Bound<PyAny>>,
) -> PyResult<PyObject> {
let key_callable = |a: &PyObject| -> PyResult<PyObject> {
let res = key.call1(py, (a,))?;
Expand All @@ -387,10 +395,6 @@ pub fn lexicographical_topological_sort(
// HashMap of node_index indegree
let node_count = dag.node_count();
let (in_dir, out_dir) = traversal_directions(reverse);
let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
for node in dag.graph.node_indices() {
in_degree_map.insert(node, dag.graph.edges_directed(node, in_dir).count());
}

#[derive(Clone, Eq, PartialEq)]
struct State {
Expand All @@ -416,6 +420,31 @@ pub fn lexicographical_topological_sort(
Some(self.cmp(other))
}
}

let mut in_degree_map: HashMap<NodeIndex, usize> = HashMap::with_capacity(node_count);
if let Some(initial) = initial {
// In this case, we don't iterate through all the nodes in the graph, and most nodes aren't
// in `in_degree_map`; we'll fill in the relevant edge counts lazily.
for maybe_index in initial.iter()? {
let node = NodeIndex::new(maybe_index?.extract::<usize>()?);
if dag.graph.contains_node(node) {
// It's not necessarily actually zero, but we treat it as if it is. If the node is
// reachable from another we visit during the iteration, then there was a defined
// topological order between the `initial` set, and we'll throw an error.
in_degree_map.insert(node, 0);
} else {
return Err(PyValueError::new_err(format!(
"node index {} is not in this graph",
node.index()
)));
}
}
} else {
for node in dag.graph.node_indices() {
in_degree_map.insert(node, dag.graph.edges_directed(node, in_dir).count());
}
}

let mut zero_indegree = BinaryHeap::with_capacity(node_count);
for (node, degree) in in_degree_map.iter() {
if *degree == 0 {
Expand All @@ -431,16 +460,23 @@ pub fn lexicographical_topological_sort(
while let Some(State { node, .. }) = zero_indegree.pop() {
let neighbors = dag.graph.neighbors_directed(node, out_dir);
for child in neighbors {
let child_degree = in_degree_map.get_mut(&child).unwrap();
*child_degree -= 1;
let child_degree = in_degree_map
.entry(child)
.or_insert_with(|| dag.graph.edges_directed(child, in_dir).count());
if *child_degree == 0 {
return Err(PyValueError::new_err(
"at least one initial node is reachable from another",
));
} else if *child_degree == 1 {
let map_key_raw = key_callable(&dag.graph[child])?;
let map_key: String = map_key_raw.extract(py)?;
zero_indegree.push(State {
key: map_key,
node: child,
});
in_degree_map.remove(&child);
} else {
*child_degree -= 1;
}
}
out_list.push(&dag.graph[node])
Expand Down
66 changes: 58 additions & 8 deletions src/toposort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ enum NodeState {
/// ``True``, the ordering will be a reversed topological ordering; that is, a topological
/// order if all the edges had their directions flipped, such that the first nodes returned are
/// the ones that have only incoming edges in the DAG.
/// :param Iterable[int] initial: If given, the initial node indices to start the topological
/// ordering from. If not given, the topological ordering will certainly contain every node in
/// the graph. If given, only the ``initial`` nodes and nodes that are dominated by the
/// ``initial`` set will be in the ordering. Notably, the first return from :meth:`get_ready`
/// will be the same set of values as ``initial``, and any node that has a natural in
/// degree of zero will not be in the output ordering if ``initial`` is given and the
/// zero-in-degree node is not in it.
///
/// It is a :exc:`ValueError` to give an ``initial`` set where the nodes have even a partial
/// topological order between themselves, though this might not appear until some call
/// to :meth:`done`.
#[pyclass(module = "rustworkx")]
pub struct TopologicalSorter {
dag: Py<PyDiGraph>,
Expand All @@ -79,8 +90,14 @@ pub struct TopologicalSorter {
#[pymethods]
impl TopologicalSorter {
#[new]
#[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false))]
fn new(py: Python, dag: Py<PyDiGraph>, check_cycle: bool, reverse: bool) -> PyResult<Self> {
#[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false, initial=None))]
fn new(
py: Python,
dag: Py<PyDiGraph>,
check_cycle: bool,
reverse: bool,
initial: Option<&Bound<PyAny>>,
) -> PyResult<Self> {
{
let dag = &dag.borrow(py);
if !dag.check_cycle && check_cycle && !is_directed_acyclic_graph(dag) {
Expand All @@ -89,9 +106,32 @@ impl TopologicalSorter {
}

let (in_dir, _) = traversal_directions(reverse);
let ready_nodes = {
let mut predecessor_count = HashMap::new();
let ready_nodes = if let Some(initial) = initial {
let dag = &dag.borrow(py);
initial
.iter()?
.map(|maybe_index| {
let node = NodeIndex::new(maybe_index?.extract::<usize>()?);
// If we're using an initial set, it's possible that the user gave us an
// initial list with topological ordering defined between the nodes. With this
// online sorter we detect that with a lag (it'll happen in a later call to
// `done`), but we'll see it as an attempt to reduce a predecessor count below
// this initial zero.
predecessor_count.insert(node, 0);
dag.graph
.contains_node(node)
.then_some(node)
.ok_or_else(|| {
PyValueError::new_err(format!(
"node index {} is not in this graph",
node.index()
))
})
})
.collect::<PyResult<Vec<_>>>()?
} else {
let dag = &dag.borrow(py);

dag.graph
.node_identifiers()
.filter(|node| dag.graph.neighbors_directed(*node, in_dir).next().is_none())
Expand All @@ -101,7 +141,7 @@ impl TopologicalSorter {
Ok(TopologicalSorter {
dag,
ready_nodes,
predecessor_count: HashMap::new(),
predecessor_count,
node2state: HashMap::new(),
num_passed_out: 0,
num_finished: 0,
Expand Down Expand Up @@ -144,11 +184,15 @@ impl TopologicalSorter {
/// This method unblocks any successor of each node in *nodes* for being returned
/// in the future by a call to "get_ready".
///
/// :param list nodes: A list of node indices to marks as done.
/// :param list nodes: A list of node indices to mark as done.
///
/// :raises `ValueError`: If any node in *nodes* has already been marked as
/// processed by a previous call to this method or node has not yet been returned
/// by "get_ready".
/// :raises `ValueError`: If one of the given ``initial`` nodes is a direct successor of one
/// of the nodes given to :meth:`done`. This can only happen if the ``initial`` nodes had
/// even a partial topological ordering amongst themselves, which is not a valid
/// starting input.
fn done(&mut self, py: Python, nodes: Vec<usize>) -> PyResult<()> {
let dag = &self.dag.borrow(py);
let (in_dir, out_dir) = traversal_directions(self.reverse);
Expand Down Expand Up @@ -176,10 +220,16 @@ impl TopologicalSorter {
for succ in dag.graph.neighbors_directed(node, out_dir) {
match self.predecessor_count.entry(succ) {
Entry::Occupied(mut entry) => {
*entry.get_mut() -= 1;
if *entry.get() == 0 {
let in_degree = entry.get_mut();
if *in_degree == 0 {
return Err(PyValueError::new_err(
"at least one initial node is reachable from another",
));
} else if *in_degree == 1 {
self.ready_nodes.push(succ);
entry.remove_entry();
} else {
*in_degree -= 1;
}
}
Entry::Vacant(entry) => {
Expand Down
108 changes: 108 additions & 0 deletions tests/digraph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,114 @@ def test_lexicographical_topo_sort_reverse(self):
rustworkx.lexicographical_topological_sort(dag, lambda x: x, reverse=False), expected
)

def test_lexicographical_topo_sort_initial(self):
# Check that `initial` restricts the forward ordering to the given nodes plus the nodes that
# become ready purely through them.
dag = rustworkx.PyDiGraph()
dag.add_nodes_from(range(9))
# Node 0 is the only node without incoming edges and node 8 the only one without outgoing
# edges; nodes 4 and 8 each have two predecessors.
dag.add_edges_from_no_data(
[
(0, 1),
(0, 2),
(1, 3),
(2, 4),
(3, 4),
(4, 5),
(5, 6),
(4, 7),
(6, 8),
(7, 8),
]
)
# Last three nodes, nothing reachable except nodes that will be returned.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, initial=[6, 7]),
[6, 7, 8],
)
# Setting `initial` to the set of root nodes should return the same as not setting it.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, initial=[0]),
rustworkx.lexicographical_topological_sort(dag, str),
)
# Node 8 is reachable from 7, but isn't dominated by it (its other predecessor, 6, is never
# visited), so shouldn't be returned.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, initial=[7]),
[7],
)

# Putting the `initial` in unsorted order should not affect the return order.
dag = rustworkx.PyDiGraph()
# Deliberately break id:weight correspondence.
dag.add_nodes_from(range(5)[::-1])
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, initial=[2, 4, 3, 0, 1]),
rustworkx.lexicographical_topological_sort(dag, str),
)

def test_lexicographical_topo_sort_initial_reverse(self):
# Mirror of the forward `initial` test with `reverse=True`: the ordering is the one that
# would be found if every edge were flipped, so it walks from node 8 towards node 0.
dag = rustworkx.PyDiGraph()
dag.add_nodes_from(range(9))
dag.add_edges_from_no_data(
[
(0, 1),
(0, 2),
(1, 3),
(2, 4),
(3, 4),
(4, 5),
(5, 6),
(4, 7),
(6, 8),
(7, 8),
]
)
# First three nodes: with the edges flipped, 1 and 2 are the only predecessors of 0, and
# nothing is reachable except nodes that will be returned.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, reverse=True, initial=[1, 2]),
[1, 2, 0],
)
# Setting `initial` to the set of root nodes of the reversed graph (node 8, the only node with
# no outgoing edges) should return the same as not setting it.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, reverse=True, initial=[8]),
rustworkx.lexicographical_topological_sort(dag, str, reverse=True),
)
# Node 0 is reachable from 1, but isn't dominated by it, so shouldn't be returned.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, str, reverse=True, initial=[1]),
[1],
)

# Putting the `initial` in unsorted order should not affect the return order.
dag = rustworkx.PyDiGraph()
# Deliberately break id:weight correspondence.
dag.add_nodes_from(range(5)[::-1])
self.assertEqual(
rustworkx.lexicographical_topological_sort(
dag, str, reverse=True, initial=[2, 4, 3, 0, 1]
),
rustworkx.lexicographical_topological_sort(dag, str, reverse=True),
)

def test_lexicographical_topo_sort_initial_natural_zero(self):
# Nodes that naturally have zero in-degree must be omitted when they are not in `initial`.
dag = rustworkx.PyDiGraph()
dag.add_nodes_from(range(5))
# There are no edges in this graph, so a natural topological ordering allows everything in
# the first pass. If `initial` is given, though, the loose zero-degree nodes are not dominated
# by the given nodes, so should not be returned.
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, key=str, initial=[0, 3]),
[0, 3],
)
self.assertEqual(
rustworkx.lexicographical_topological_sort(dag, key=str, reverse=True, initial=[0, 3]),
[0, 3],
)

def test_lexicographical_topo_sort_initial_invalid(self):
# An `initial` set whose members are ordered relative to one another must raise `ValueError`.
# NOTE(review): assumes `directed_path_graph(5)` links consecutive indices (0 -> 1 -> ... -> 4),
# so each pair below has one node reachable from the other -- confirm against the generator.
dag = rustworkx.generators.directed_path_graph(5)
with self.assertRaisesRegex(ValueError, "initial node is reachable from another"):
rustworkx.lexicographical_topological_sort(dag, str, initial=[0, 1])
# Same check in the reversed direction.
with self.assertRaisesRegex(ValueError, "initial node is reachable from another"):
rustworkx.lexicographical_topological_sort(dag, str, reverse=True, initial=[3, 4])

def test_lexicographical_topo_sort_qiskit(self):
dag = rustworkx.PyDAG()
# inputs
Expand Down
Loading
Loading