diff --git a/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml b/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml new file mode 100644 index 000000000..2a7082b48 --- /dev/null +++ b/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added method :meth:`~rustworkx.PyGraph.edge_indices_from_endpoints` which returns the indices of all edges + between the specified endpoints. For :class:`~rustworkx.PyDiGraph` there is a corresponding method that returns the + directed edges. diff --git a/rustworkx/digraph.pyi b/rustworkx/digraph.pyi index ae470846f..9648ca62e 100644 --- a/rustworkx/digraph.pyi +++ b/rustworkx/digraph.pyi @@ -68,6 +68,7 @@ class PyDiGraph(Generic[S, T]): def copy(self) -> PyDiGraph[S, T]: ... def edge_index_map(self) -> EdgeIndexMap[T]: ... def edge_indices(self) -> EdgeIndices: ... + def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ... def edge_list(self) -> EdgeList: ... def edges(self) -> list[T]: ... def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyDiGraph[S, T]: ... diff --git a/rustworkx/graph.pyi b/rustworkx/graph.pyi index d710f6a1d..956c39c8c 100644 --- a/rustworkx/graph.pyi +++ b/rustworkx/graph.pyi @@ -66,6 +66,7 @@ class PyGraph(Generic[S, T]): def degree(self, node: int, /) -> int: ... def edge_index_map(self) -> EdgeIndexMap[T]: ... def edge_indices(self) -> EdgeIndices: ... + def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ... def edge_list(self) -> EdgeList: ... def edges(self) -> list[T]: ... def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyGraph[S, T]: ... diff --git a/src/digraph.rs b/src/digraph.rs index 2d8a97d18..1afb1ba98 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -547,6 +547,24 @@ impl PyDiGraph { } } + /// Return a list of indices of all directed edges between specified nodes + /// + /// :returns: A list of all the edge indices connecting the specified start and end node + /// :rtype: EdgeIndices + pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices { + let node_a_index = NodeIndex::new(node_a); + let node_b_index = NodeIndex::new(node_b); + + EdgeIndices { + edges: self + .graph + .edges_directed(node_a_index, petgraph::Direction::Outgoing) + .filter(|edge| edge.target() == node_b_index) + .map(|edge| edge.id().index()) + .collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/src/graph.rs b/src/graph.rs index 45d8902a7..0ad81c25b 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -422,6 +422,24 @@ impl PyGraph { } } + /// Return a list of indices of all edges between specified nodes + /// + /// :returns: A list of all the edge indices connecting the specified start and end node + /// :rtype: EdgeIndices + pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices { + let node_a_index = NodeIndex::new(node_a); + let node_b_index = NodeIndex::new(node_b); + + EdgeIndices { + edges: self + .graph + .edges_directed(node_a_index, petgraph::Direction::Outgoing) + .filter(|edge| edge.target() == node_b_index) + .map(|edge| edge.id().index()) + .collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/tests/rustworkx_tests/digraph/test_edges.py b/tests/rustworkx_tests/digraph/test_edges.py index ef0af66dd..9ec06d31a 100644 --- a/tests/rustworkx_tests/digraph/test_edges.py +++ b/tests/rustworkx_tests/digraph/test_edges.py @@ -370,6 +370,25 @@ def test_weighted_edge_list_empty(self): dag = rustworkx.PyDiGraph() self.assertEqual([], dag.weighted_edge_list()) + def test_edge_indices_from_endpoints(self): + dag = rustworkx.PyDiGraph() + dag.add_nodes_from(list(range(4))) + edge_list = [ + (0, 1, None), + (1, 2, None), + (0, 2, None), + (2, 3, None), + (0, 3, None), + (0, 2, None), + ] + dag.add_edges_from(edge_list) + indices = dag.edge_indices_from_endpoints(0, 0) + self.assertEqual(indices, []) + indices = dag.edge_indices_from_endpoints(0, 1) + self.assertEqual(indices, [0]) + indices = dag.edge_indices_from_endpoints(0, 2) + self.assertEqual(set(indices), {2, 5}) + def test_extend_from_edge_list(self): dag = rustworkx.PyDAG() edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)] diff --git a/tests/rustworkx_tests/graph/test_edges.py b/tests/rustworkx_tests/graph/test_edges.py index 04f24af1a..628a9788b 100644 --- a/tests/rustworkx_tests/graph/test_edges.py +++ b/tests/rustworkx_tests/graph/test_edges.py @@ -331,6 +331,26 @@ def test_weighted_edge_list_empty(self): graph = rustworkx.PyGraph() self.assertEqual([], graph.weighted_edge_list()) + def test_edge_indices_from_endpoints(self): + dag = rustworkx.PyGraph() + dag.add_nodes_from(list(range(4))) + edge_list = [ + (0, 1, None), + (1, 2, None), + (0, 2, None), + (2, 3, None), + (0, 3, None), + (0, 2, None), + (2, 0, None), + ] + dag.add_edges_from(edge_list) + indices = dag.edge_indices_from_endpoints(0, 0) + self.assertEqual(indices, []) + indices = dag.edge_indices_from_endpoints(0, 1) + self.assertEqual(set(indices), {0}) + indices = dag.edge_indices_from_endpoints(0, 2) + self.assertEqual(set(indices), {2, 5, 6}) + def test_extend_from_edge_list(self): graph = rustworkx.PyGraph() edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]