From 22ad54178a38ea7163423ffe522599b9e94617c2 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Mon, 30 Oct 2023 00:48:55 +0100 Subject: [PATCH 1/8] Add graph_all_shortest_paths() --- .../src/shortest_path/all_shortest_paths.rs | 205 ++++++++++++++++++ rustworkx-core/src/shortest_path/mod.rs | 2 + src/lib.rs | 2 + src/shortest_path/mod.rs | 49 ++++- .../graph/test_all_shortest_paths.py | 78 +++++++ 5 files changed, 335 insertions(+), 1 deletion(-) create mode 100644 rustworkx-core/src/shortest_path/all_shortest_paths.rs create mode 100644 tests/rustworkx_tests/graph/test_all_shortest_paths.py diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs new file mode 100644 index 000000000..e73284d3f --- /dev/null +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -0,0 +1,205 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// This module was originally copied and forked from the upstream petgraph +// repository, specifically: +// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs +// this was necessary to modify the error handling to allow python callables +// to be use for the input functions for edge_cost and return any exceptions +// raised in Python instead of panicking + +use std::hash::Hash; + +use ahash::{HashSet, HashSetExt}; +use petgraph::algo::Measure; +use petgraph::visit::{EdgeRef, IntoEdgesDirected, NodeIndexable, Visitable}; +use petgraph::Direction::Incoming; + +use super::dijkstra; +use crate::dictmap::*; + +/// Dijkstra-based all shortest paths algorithm. +/// +/// Compute every single shortest path from `start` to `goal`. +/// +/// The graph should be [`Visitable`] and implement [`IntoEdges`]. The function +/// `edge_cost` should return the cost for a particular edge, which is used +/// to compute path costs. Edge costs must be non-negative. +/// +/// +/// Returns a [`HashSet`] which contains all possible shortest paths. Each path +/// is a Vec of node indices of the path starting with `start` and ending `goal`. +/// # Example +/// ```rust +/// use rustworkx_core::petgraph::Graph; +/// use rustworkx_core::petgraph::prelude::*; +/// use rustworkx_core::dictmap::DictMap; +/// use rustworkx_core::shortest_path::all_shortest_paths; +/// use rustworkx_core::Result; +/// use ahash::HashSet; +/// +/// let mut graph : Graph<(), (), Directed>= Graph::new(); +/// let a = graph.add_node(()); // node with no weight +/// let b = graph.add_node(()); +/// let c = graph.add_node(()); +/// let d = graph.add_node(()); +/// let e = graph.add_node(()); +/// let f = graph.add_node(()); +/// let g = graph.add_node(()); +/// // z will be in another connected component +/// let z = graph.add_node(()); +/// +/// graph.extend_with_edges(&[ +/// (a, b), +/// (a, c), +/// (b, d), +/// (b, f), +/// (c, d), +/// (d, e), +/// (f, e), +/// (e, g) +/// ]); +/// // a ----> b ----> f +/// // | | | +/// // v v v +/// // c ----> d ----> e ----> g +/// +/// let expected_res: HashSet>= [ +/// vec![a, b, f, e, g], +/// vec![a, b, d, e, g], +/// vec![a, c, d, e, g] +/// ].into_iter().collect(); +/// let res: Result>> = all_shortest_paths( +/// &graph, a, g, |_| Ok(1) +/// ); +/// assert_eq!(res.unwrap(), expected_res) +/// ``` +pub fn all_shortest_paths( + graph: G, + start: G::NodeId, + goal: G::NodeId, + mut edge_cost: F, +) -> Result>, E> +where + G: IntoEdgesDirected + Visitable + NodeIndexable, + G::NodeId: Eq + Hash, + F: FnMut(G::EdgeRef) -> Result, + K: Measure + Copy, +{ + let scores: DictMap = dijkstra(&graph, start, Some(goal), &mut edge_cost, None)?; + if !scores.contains_key(&goal) { + return Ok(HashSet::default()); + } + let mut paths = HashSet::new(); + let mut queue = vec![(goal, vec![goal])]; + while let Some((curr, curr_path)) = queue.pop() { + let curr_dist = *scores.get(&curr).unwrap(); + for edge in graph.edges_directed(curr, Incoming) { + let next_dist = match scores.get(&edge.source()) { + Some(x) => *x, + None => continue, + }; + if curr_dist == next_dist + edge_cost(edge)? { + let mut new_path = curr_path.clone(); + new_path.insert(0, edge.source()); + if edge.source() == start { + paths.insert(new_path.clone()); + continue; + } + queue.push((edge.source(), new_path)); + } + } + } + return Ok(paths); +} + +#[cfg(test)] +mod tests { + use crate::shortest_path::all_shortest_paths; + use crate::Result; + use ahash::HashSet; + use petgraph::prelude::*; + use petgraph::Graph; + + #[test] + fn test_all_shortest_paths() { + let mut g = Graph::new_undirected(); + let a = g.add_node("A"); + let b = g.add_node("B"); + let c = g.add_node("C"); + let d = g.add_node("D"); + let e = g.add_node("E"); + let f = g.add_node("F"); + g.add_edge(a, b, 7); + g.add_edge(c, a, 9); + g.add_edge(a, d, 11); + g.add_edge(b, c, 10); + g.add_edge(d, c, 2); + g.add_edge(d, e, 9); + g.add_edge(b, f, 15); + g.add_edge(c, f, 11); + g.add_edge(e, f, 6); + + let start = a; + let goal = e; + let paths: Result>> = + all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); + + // a --> d --> e (11 + 9) + // a --> c --> d --> e (9 + 2 + 9) + let expected_paths: HashSet> = + [vec![a, d, e], vec![a, c, d, e]].into_iter().collect(); + assert_eq!(paths.unwrap(), expected_paths); + } + + #[test] + fn test_all_paths_no_path() { + let mut g: Graph<&str, (), Undirected> = Graph::new_undirected(); + let a = g.add_node("A"); + let b = g.add_node("B"); + + let start = a; + let goal = b; + let paths: Result>> = all_shortest_paths(&g, start, goal, |_| Ok(1)); + + let expected_paths: HashSet> = HashSet::default(); + assert_eq!(paths.unwrap(), expected_paths); + } + + #[test] + fn test_all_shortest_paths_nearly_fully_connected() { + let mut g = Graph::new_undirected(); + let num_nodes = 100; + let nodes: Vec = (0..num_nodes).map(|_| g.add_node(1)).collect(); + for n1 in nodes.iter() { + for n2 in nodes.iter() { + if n1 != n2 { + g.update_edge(*n1, *n2, 1); + } + } + } + let start = nodes[0]; + let goal = nodes[1]; + + let paths: Result>> = + all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); + assert_eq!(paths.unwrap().len(), 1); + + let edge = g.edges_connecting(start, goal).next().unwrap(); + g.remove_edge(edge.id()); + + let paths: Result>> = + all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); + + assert_eq!(paths.unwrap().len(), num_nodes - 2); + } +} diff --git a/rustworkx-core/src/shortest_path/mod.rs b/rustworkx-core/src/shortest_path/mod.rs index c150f7f4b..6203d8fba 100644 --- a/rustworkx-core/src/shortest_path/mod.rs +++ b/rustworkx-core/src/shortest_path/mod.rs @@ -15,11 +15,13 @@ //! This module contains functions for various algorithms that compute the //! shortest path of a graph. +mod all_shortest_paths; mod astar; mod bellman_ford; mod dijkstra; mod k_shortest_path; +pub use all_shortest_paths::all_shortest_paths; pub use astar::astar; pub use bellman_ford::{bellman_ford, negative_cycle_finder}; pub use dijkstra::dijkstra; diff --git a/src/lib.rs b/src/lib.rs index e078c9c35..665c2f819 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -404,6 +404,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_all_simple_paths))?; m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_paths))?; m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_paths))?; + m.add_wrapped(wrap_pyfunction!(graph_all_shortest_paths))?; + // m.add_wrapped(wrap_pyfunction!(digraph_all_shortest_paths))?; m.add_wrapped(wrap_pyfunction!(graph_has_path))?; m.add_wrapped(wrap_pyfunction!(digraph_has_path))?; m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_path_lengths))?; diff --git a/src/shortest_path/mod.rs b/src/shortest_path/mod.rs index 0ced80339..15b5601b2 100644 --- a/src/shortest_path/mod.rs +++ b/src/shortest_path/mod.rs @@ -21,6 +21,7 @@ use std::convert::TryFrom; use crate::{digraph, edge_weights_from_callable, graph, CostFn, NegativeCycle, NoPathFound}; +use ahash::HashSet; use pyo3::prelude::*; use pyo3::Python; @@ -35,7 +36,7 @@ use numpy::IntoPyArray; use rustworkx_core::dictmap::*; use rustworkx_core::shortest_path::{ - astar, bellman_ford, dijkstra, k_shortest_path, negative_cycle_finder, + all_shortest_paths, astar, bellman_ford, dijkstra, k_shortest_path, negative_cycle_finder, }; use crate::iterators::{ @@ -109,6 +110,52 @@ pub fn graph_dijkstra_shortest_paths( }) } +/// Find all shortest paths between two nodes +/// +/// This function will generate all possible shortest paths from a source node to a +/// target using Dijkstra's algorithm. +/// +/// :param PyGraph graph: +/// :param int source: The node index to find paths from +/// :param int target: A target to find paths to +/// :param weight_fn: An optional weight function for an edge. It will accept +/// a single argument, the edge's weight object and will return a float which +/// will be used to represent the weight/cost of the edge +/// :param float default_weight: If ``weight_fn`` isn't specified this optional +/// float value will be used for the weight/cost of each edge. +/// +/// :return: List of paths. Each paths are lists of node indices, +/// starting at ``source`` and ending at ``target``. +/// :rtype: list +/// :raises ValueError: when an edge weight with NaN or negative value +/// is provided. +#[pyfunction] +#[pyo3( + signature=(graph, source, target, weight_fn=None, default_weight=1.0), + text_signature = "(graph, source, target, /, weight_fn=None, default_weight=1.0)" +)] +pub fn graph_all_shortest_paths( + py: Python, + graph: &graph::PyGraph, + source: usize, + target: usize, + weight_fn: Option, + default_weight: f64, +) -> PyResult>> { + let start = NodeIndex::new(source); + let goal = NodeIndex::new(target); + + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + + let paths = (all_shortest_paths(&graph.graph, start, goal, |e| cost_fn.call(py, e.weight())) + as PyResult>>)?; + + Ok(paths + .iter() + .map(|v| v.iter().map(|v| v.index()).collect()) + .collect()) +} + /// Check if a graph has a path between source and target nodes /// /// :param PyGraph graph: diff --git a/tests/rustworkx_tests/graph/test_all_shortest_paths.py b/tests/rustworkx_tests/graph/test_all_shortest_paths.py new file mode 100644 index 000000000..50327bfa6 --- /dev/null +++ b/tests/rustworkx_tests/graph/test_all_shortest_paths.py @@ -0,0 +1,78 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import rustworkx + + +class TestDijkstraGraph(unittest.TestCase): + def setUp(self): + self.graph = rustworkx.PyGraph() + self.a = self.graph.add_node("A") + self.b = self.graph.add_node("B") + self.c = self.graph.add_node("C") + self.d = self.graph.add_node("D") + self.e = self.graph.add_node("E") + self.f = self.graph.add_node("F") + self.graph.add_edge(self.a, self.b, 7) + self.graph.add_edge(self.c, self.a, 9) + self.graph.add_edge(self.a, self.d, 14) + self.graph.add_edge(self.b, self.c, 10) + self.graph.add_edge(self.d, self.c, 2) + self.graph.add_edge(self.d, self.e, 9) + self.graph.add_edge(self.b, self.f, 15) + self.graph.add_edge(self.c, self.f, 11) + self.graph.add_edge(self.e, self.f, 6) + + def test_all_shortest_paths_single(self): + paths = rustworkx.graph_all_shortest_paths( + self.graph, self.a, self.e, lambda x: float(x) + ) + # a -> d -> e = 23 + # a -> c -> d -> e = 20 + expected = [[self.a, self.c, self.d, self.e]] + self.assertEqual(expected, paths) + + def test_all_shortest_paths(self): + self.graph.update_edge(self.a, self.d, 11) + + paths = rustworkx.graph_all_shortest_paths( + self.graph, self.a, self.e, lambda x: float(x) + ) + # a -> d -> e = 20 + # a -> c -> d -> e = 20 + expected = [[self.a, self.d, self.e], [self.a, self.c, self.d, self.e]] + self.assertEqual(len(paths), 2) + self.assertIn(expected[0], paths) + self.assertIn(expected[1], paths) + + def test_all_shortest_paths_with_no_path(self): + g = rustworkx.PyGraph() + a = g.add_node("A") + b = g.add_node("B") + paths = rustworkx.graph_all_shortest_paths( + g, a, b, lambda x: float(x)) + expected = [] + self.assertEqual(expected, paths) + + def test_all_shortest_paths_with_invalid_weights(self): + graph = rustworkx.generators.path_graph(2) + for invalid_weight in [float("nan"), -1]: + with self.subTest(invalid_weight=invalid_weight): + with self.assertRaises(ValueError): + rustworkx.graph_all_shortest_paths( + graph, + source=0, + target=1, + weight_fn=lambda _: invalid_weight, + ) From 5e5007327435b21d70695e81aef6aed88ba4aac9 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Mon, 30 Oct 2023 12:06:10 +0100 Subject: [PATCH 2/8] Clippy --- rustworkx-core/src/shortest_path/all_shortest_paths.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs index e73284d3f..35721f3ed 100644 --- a/rustworkx-core/src/shortest_path/all_shortest_paths.rs +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -119,7 +119,7 @@ where } } } - return Ok(paths); + Ok(paths) } #[cfg(test)] From db10d7006168b7d5404e4bc4b293dcc7c7fbfa92 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Mon, 30 Oct 2023 17:29:14 +0100 Subject: [PATCH 3/8] tox -eblack --- .../rustworkx_tests/graph/test_all_shortest_paths.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/rustworkx_tests/graph/test_all_shortest_paths.py b/tests/rustworkx_tests/graph/test_all_shortest_paths.py index 50327bfa6..049a60b0c 100644 --- a/tests/rustworkx_tests/graph/test_all_shortest_paths.py +++ b/tests/rustworkx_tests/graph/test_all_shortest_paths.py @@ -35,9 +35,7 @@ def setUp(self): self.graph.add_edge(self.e, self.f, 6) def test_all_shortest_paths_single(self): - paths = rustworkx.graph_all_shortest_paths( - self.graph, self.a, self.e, lambda x: float(x) - ) + paths = rustworkx.graph_all_shortest_paths(self.graph, self.a, self.e, lambda x: float(x)) # a -> d -> e = 23 # a -> c -> d -> e = 20 expected = [[self.a, self.c, self.d, self.e]] @@ -46,9 +44,7 @@ def test_all_shortest_paths_single(self): def test_all_shortest_paths(self): self.graph.update_edge(self.a, self.d, 11) - paths = rustworkx.graph_all_shortest_paths( - self.graph, self.a, self.e, lambda x: float(x) - ) + paths = rustworkx.graph_all_shortest_paths(self.graph, self.a, self.e, lambda x: float(x)) # a -> d -> e = 20 # a -> c -> d -> e = 20 expected = [[self.a, self.d, self.e], [self.a, self.c, self.d, self.e]] @@ -60,8 +56,7 @@ def test_all_shortest_paths_with_no_path(self): g = rustworkx.PyGraph() a = g.add_node("A") b = g.add_node("B") - paths = rustworkx.graph_all_shortest_paths( - g, a, b, lambda x: float(x)) + paths = rustworkx.graph_all_shortest_paths(g, a, b, lambda x: float(x)) expected = [] self.assertEqual(expected, paths) From 422f13be3105f0533ca9d09d3abc5bec42c25842 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Tue, 31 Oct 2023 14:28:43 +0100 Subject: [PATCH 4/8] Remove unneeded clone --- rustworkx-core/src/shortest_path/all_shortest_paths.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs index 35721f3ed..57f6a01ed 100644 --- a/rustworkx-core/src/shortest_path/all_shortest_paths.rs +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -112,7 +112,7 @@ where let mut new_path = curr_path.clone(); new_path.insert(0, edge.source()); if edge.source() == start { - paths.insert(new_path.clone()); + paths.insert(new_path); continue; } queue.push((edge.source(), new_path)); From d03435c1d3bf5ff25cd323a782b9b840ae494173 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Wed, 1 Nov 2023 14:01:05 +0100 Subject: [PATCH 5/8] Use vecdeque --- rustworkx-core/src/shortest_path/all_shortest_paths.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs index 57f6a01ed..9f943ceb0 100644 --- a/rustworkx-core/src/shortest_path/all_shortest_paths.rs +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -17,6 +17,7 @@ // to be use for the input functions for edge_cost and return any exceptions // raised in Python instead of panicking +use std::collections::VecDeque; use std::hash::Hash; use ahash::{HashSet, HashSetExt}; @@ -100,7 +101,8 @@ where return Ok(HashSet::default()); } let mut paths = HashSet::new(); - let mut queue = vec![(goal, vec![goal])]; + let path = VecDeque::from([goal]); + let mut queue = vec![(goal, path)]; while let Some((curr, curr_path)) = queue.pop() { let curr_dist = *scores.get(&curr).unwrap(); for edge in graph.edges_directed(curr, Incoming) { @@ -110,9 +112,9 @@ where }; if curr_dist == next_dist + edge_cost(edge)? { let mut new_path = curr_path.clone(); - new_path.insert(0, edge.source()); + new_path.push_front(edge.source()); if edge.source() == start { - paths.insert(new_path); + paths.insert(new_path.into()); continue; } queue.push((edge.source(), new_path)); From 28e469376bb376d17b80b7077d584085db99f741 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Sat, 25 Nov 2023 18:21:41 +0100 Subject: [PATCH 6/8] Support 0 weight edges --- .../src/shortest_path/all_shortest_paths.rs | 92 +++++++++++++++---- src/lib.rs | 2 +- src/shortest_path/mod.rs | 59 +++++++++++- .../graph/test_all_shortest_paths.py | 21 ++++- 4 files changed, 151 insertions(+), 23 deletions(-) diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs index 9f943ceb0..1c64c4548 100644 --- a/rustworkx-core/src/shortest_path/all_shortest_paths.rs +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -20,7 +20,6 @@ use std::collections::VecDeque; use std::hash::Hash; -use ahash::{HashSet, HashSetExt}; use petgraph::algo::Measure; use petgraph::visit::{EdgeRef, IntoEdgesDirected, NodeIndexable, Visitable}; use petgraph::Direction::Incoming; @@ -32,12 +31,12 @@ use crate::dictmap::*; /// /// Compute every single shortest path from `start` to `goal`. /// -/// The graph should be [`Visitable`] and implement [`IntoEdges`]. The function +/// The graph should be [`Visitable`] and implement [`IntoEdgesDirected`]. The function /// `edge_cost` should return the cost for a particular edge, which is used /// to compute path costs. Edge costs must be non-negative. /// /// -/// Returns a [`HashSet`] which contains all possible shortest paths. Each path +/// Returns a [`Vec`] which contains all possible shortest paths. Each path /// is a Vec of node indices of the path starting with `start` and ending `goal`. /// # Example /// ```rust @@ -74,12 +73,12 @@ use crate::dictmap::*; /// // v v v /// // c ----> d ----> e ----> g /// -/// let expected_res: HashSet>= [ -/// vec![a, b, f, e, g], +/// let expected_res: Vec>= [ /// vec![a, b, d, e, g], -/// vec![a, c, d, e, g] +/// vec![a, c, d, e, g], +/// vec![a, b, f, e, g], /// ].into_iter().collect(); -/// let res: Result>> = all_shortest_paths( +/// let res: Result>> = all_shortest_paths( /// &graph, a, g, |_| Ok(1) /// ); /// assert_eq!(res.unwrap(), expected_res) @@ -89,23 +88,27 @@ pub fn all_shortest_paths( start: G::NodeId, goal: G::NodeId, mut edge_cost: F, -) -> Result>, E> +) -> Result>, E> where G: IntoEdgesDirected + Visitable + NodeIndexable, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> Result, K: Measure + Copy, { - let scores: DictMap = dijkstra(&graph, start, Some(goal), &mut edge_cost, None)?; + let scores: DictMap = dijkstra(&graph, start, None, &mut edge_cost, None)?; if !scores.contains_key(&goal) { - return Ok(HashSet::default()); + return Ok(vec![]); } - let mut paths = HashSet::new(); + let mut paths = vec![]; let path = VecDeque::from([goal]); let mut queue = vec![(goal, path)]; while let Some((curr, curr_path)) = queue.pop() { let curr_dist = *scores.get(&curr).unwrap(); for edge in graph.edges_directed(curr, Incoming) { + // Only simple paths + if curr_path.contains(&edge.source()) { + continue; + } let next_dist = match scores.get(&edge.source()) { Some(x) => *x, None => continue, @@ -114,7 +117,7 @@ where let mut new_path = curr_path.clone(); new_path.push_front(edge.source()); if edge.source() == start { - paths.insert(new_path.into()); + paths.push(new_path.into()); continue; } queue.push((edge.source(), new_path)); @@ -128,7 +131,6 @@ where mod tests { use crate::shortest_path::all_shortest_paths; use crate::Result; - use ahash::HashSet; use petgraph::prelude::*; use petgraph::Graph; @@ -153,12 +155,12 @@ mod tests { let start = a; let goal = e; - let paths: Result>> = + let paths: Result>> = all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); // a --> d --> e (11 + 9) // a --> c --> d --> e (9 + 2 + 9) - let expected_paths: HashSet> = + let expected_paths: Vec> = [vec![a, d, e], vec![a, c, d, e]].into_iter().collect(); assert_eq!(paths.unwrap(), expected_paths); } @@ -171,12 +173,64 @@ mod tests { let start = a; let goal = b; - let paths: Result>> = all_shortest_paths(&g, start, goal, |_| Ok(1)); + let paths: Result>> = all_shortest_paths(&g, start, goal, |_| Ok(1)); - let expected_paths: HashSet> = HashSet::default(); + let expected_paths: Vec> = vec![]; assert_eq!(paths.unwrap(), expected_paths); } + #[test] + fn test_all_paths_0_weight() { + let mut g = Graph::new_undirected(); + let a = g.add_node("A"); + let b = g.add_node("B"); + let c = g.add_node("C"); + let d = g.add_node("D"); + let e = g.add_node("E"); + let f = g.add_node("F"); + + g.add_edge(a, b, 1); + g.add_edge(b, f, 2); + g.add_edge(a, c, 2); + g.add_edge(c, d, 1); + g.add_edge(d, e, 0); + g.add_edge(e, f, 0); + + let start = a; + let goal = f; + + let paths: Result>> = + all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); + + assert_eq!(paths.unwrap().len(), 2); + } + + #[test] + fn test_all_paths_0_weight_cycles() { + let mut g = Graph::new_undirected(); + let a = g.add_node("A"); + let b = g.add_node("B"); + let c = g.add_node("C"); + let d = g.add_node("D"); + let e = g.add_node("E"); + let f = g.add_node("F"); + + g.add_edge(a, b, 1); + g.add_edge(b, c, 0); + g.add_edge(c, f, 1); + g.add_edge(b, d, 0); + g.add_edge(d, e, 0); + g.add_edge(e, c, 0); + + let start = a; + let goal = f; + + let paths: Result>> = + dbg!(all_shortest_paths(&g, start, goal, |e| Ok(*e.weight()))); + + assert_eq!(paths.unwrap().len(), 2); + } + #[test] fn test_all_shortest_paths_nearly_fully_connected() { let mut g = Graph::new_undirected(); @@ -192,14 +246,14 @@ mod tests { let start = nodes[0]; let goal = nodes[1]; - let paths: Result>> = + let paths: Result>> = all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); assert_eq!(paths.unwrap().len(), 1); let edge = g.edges_connecting(start, goal).next().unwrap(); g.remove_edge(edge.id()); - let paths: Result>> = + let paths: Result>> = all_shortest_paths(&g, start, goal, |e| Ok(*e.weight())); assert_eq!(paths.unwrap().len(), num_nodes - 2); diff --git a/src/lib.rs b/src/lib.rs index 665c2f819..7c2d81dd9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -405,7 +405,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_paths))?; m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_paths))?; m.add_wrapped(wrap_pyfunction!(graph_all_shortest_paths))?; - // m.add_wrapped(wrap_pyfunction!(digraph_all_shortest_paths))?; + m.add_wrapped(wrap_pyfunction!(digraph_all_shortest_paths))?; m.add_wrapped(wrap_pyfunction!(graph_has_path))?; m.add_wrapped(wrap_pyfunction!(digraph_has_path))?; m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_path_lengths))?; diff --git a/src/shortest_path/mod.rs b/src/shortest_path/mod.rs index 15b5601b2..d0d4eaa6d 100644 --- a/src/shortest_path/mod.rs +++ b/src/shortest_path/mod.rs @@ -21,7 +21,6 @@ use std::convert::TryFrom; use crate::{digraph, edge_weights_from_callable, graph, CostFn, NegativeCycle, NoPathFound}; -use ahash::HashSet; use pyo3::prelude::*; use pyo3::Python; @@ -148,7 +147,7 @@ pub fn graph_all_shortest_paths( let cost_fn = CostFn::try_from((weight_fn, default_weight))?; let paths = (all_shortest_paths(&graph.graph, start, goal, |e| cost_fn.call(py, e.weight())) - as PyResult>>)?; + as PyResult>>)?; Ok(paths .iter() @@ -257,6 +256,62 @@ pub fn digraph_dijkstra_shortest_paths( }) } +/// Find all shortest paths between two nodes +/// +/// This function will generate all possible shortest paths from a source node to a +/// target using Dijkstra's algorithm. +/// +/// :param PyDiGraph graph: +/// :param int source: The node index to find paths from +/// :param int target: A target to find paths to +/// :param weight_fn: An optional weight function for an edge. It will accept +/// a single argument, the edge's weight object and will return a float which +/// will be used to represent the weight/cost of the edge +/// :param float default_weight: If ``weight_fn`` isn't specified this optional +/// float value will be used for the weight/cost of each edge. +/// +/// :return: List of paths. Each paths are lists of node indices, +/// starting at ``source`` and ending at ``target``. +/// :rtype: list +/// :raises ValueError: when an edge weight with NaN or negative value +/// is provided. +#[pyfunction] +#[pyo3( + signature=(graph, source, target, weight_fn=None, default_weight=1.0, as_undirected=false), + text_signature = "(graph, source, target, /, weight_fn=None, default_weight=1.0, as_undirected=False)" +)] +pub fn digraph_all_shortest_paths( + py: Python, + graph: &digraph::PyDiGraph, + source: usize, + target: usize, + weight_fn: Option, + default_weight: f64, + as_undirected: bool, +) -> PyResult>> { + let start = NodeIndex::new(source); + let goal = NodeIndex::new(target); + + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + + let paths = if as_undirected { + (all_shortest_paths( + &graph.to_undirected(py, true, None)?.graph, + start, + goal, + |e| cost_fn.call(py, e.weight()), + ) as PyResult>>)? + } else { + (all_shortest_paths(&graph.graph, start, goal, |e| cost_fn.call(py, e.weight())) + as PyResult>>)? + }; + + Ok(paths + .iter() + .map(|v| v.iter().map(|v| v.index()).collect()) + .collect()) +} + /// Check if a digraph has a path between source and target nodes /// /// :param PyDiGraph graph: diff --git a/tests/rustworkx_tests/graph/test_all_shortest_paths.py b/tests/rustworkx_tests/graph/test_all_shortest_paths.py index 049a60b0c..e86486a15 100644 --- a/tests/rustworkx_tests/graph/test_all_shortest_paths.py +++ b/tests/rustworkx_tests/graph/test_all_shortest_paths.py @@ -15,7 +15,7 @@ import rustworkx -class TestDijkstraGraph(unittest.TestCase): +class TestGraphAllShortestPaths(unittest.TestCase): def setUp(self): self.graph = rustworkx.PyGraph() self.a = self.graph.add_node("A") @@ -71,3 +71,22 @@ def test_all_shortest_paths_with_invalid_weights(self): target=1, weight_fn=lambda _: invalid_weight, ) + + def test_all_shortest_paths_graph_with_digraph_input(self): + g = rustworkx.PyDAG() + g.add_node(0) + g.add_node(1) + with self.assertRaises(TypeError): + rustworkx.graph_all_shortest_paths(g, 0, 1, lambda x: x) + + def test_all_shortest_paths_digraph(self): + g = rustworkx.PyDAG() + g.add_node(0) + g.add_node(1) + g.add_edge(0, 1, 1) + paths_directed = rustworkx.digraph_all_shortest_paths(g, 1, 0, lambda x: x) + self.assertEqual([], paths_directed) + + paths_undirected = rustworkx.digraph_all_shortest_paths(g, 1, 0, lambda x: x, as_undirected=True) + self.assertEqual([[1, 0]], paths_undirected) + From a48db3a4bcb465a22527727b441fc325e3705f23 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Sat, 25 Nov 2023 18:33:04 +0100 Subject: [PATCH 7/8] Update docs --- .../notes/add-all-shortest-paths-52506ad9c5156726.yaml | 6 ++++++ rustworkx-core/src/shortest_path/all_shortest_paths.rs | 7 ------- 2 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/add-all-shortest-paths-52506ad9c5156726.yaml diff --git a/releasenotes/notes/add-all-shortest-paths-52506ad9c5156726.yaml b/releasenotes/notes/add-all-shortest-paths-52506ad9c5156726.yaml new file mode 100644 index 000000000..2d54e626f --- /dev/null +++ b/releasenotes/notes/add-all-shortest-paths-52506ad9c5156726.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added new functions :func:`~rustworkx.graph_all_shortest_paths` and + :func:`~rustworkx.digraph_all_shortest_paths` that finds every + simple shortest path in a (di)graph. diff --git a/rustworkx-core/src/shortest_path/all_shortest_paths.rs b/rustworkx-core/src/shortest_path/all_shortest_paths.rs index 1c64c4548..292a48f24 100644 --- a/rustworkx-core/src/shortest_path/all_shortest_paths.rs +++ b/rustworkx-core/src/shortest_path/all_shortest_paths.rs @@ -10,13 +10,6 @@ // License for the specific language governing permissions and limitations // under the License. -// This module was originally copied and forked from the upstream petgraph -// repository, specifically: -// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs -// this was necessary to modify the error handling to allow python callables -// to be use for the input functions for edge_cost and return any exceptions -// raised in Python instead of panicking - use std::collections::VecDeque; use std::hash::Hash; From 23d63c90517f173916482c1a65cecf75354a1218 Mon Sep 17 00:00:00 2001 From: Lucas Van Mol Date: Sat, 25 Nov 2023 18:41:25 +0100 Subject: [PATCH 8/8] fmt --- tests/rustworkx_tests/graph/test_all_shortest_paths.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/rustworkx_tests/graph/test_all_shortest_paths.py b/tests/rustworkx_tests/graph/test_all_shortest_paths.py index e86486a15..79668e1e0 100644 --- a/tests/rustworkx_tests/graph/test_all_shortest_paths.py +++ b/tests/rustworkx_tests/graph/test_all_shortest_paths.py @@ -71,7 +71,7 @@ def test_all_shortest_paths_with_invalid_weights(self): target=1, weight_fn=lambda _: invalid_weight, ) - + def test_all_shortest_paths_graph_with_digraph_input(self): g = rustworkx.PyDAG() g.add_node(0) @@ -87,6 +87,7 @@ def test_all_shortest_paths_digraph(self): paths_directed = rustworkx.digraph_all_shortest_paths(g, 1, 0, lambda x: x) self.assertEqual([], paths_directed) - paths_undirected = rustworkx.digraph_all_shortest_paths(g, 1, 0, lambda x: x, as_undirected=True) + paths_undirected = rustworkx.digraph_all_shortest_paths( + g, 1, 0, lambda x: x, as_undirected=True + ) self.assertEqual([[1, 0]], paths_undirected) -