Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to return edge indices from endpoints #1055

Merged
merged 8 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rustworkx/digraph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class PyDiGraph(Generic[S, T]):
def copy(self) -> PyDiGraph[S, T]: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_indices_from_endpoints(self, int, int) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyDiGraph[S, T]: ...
Expand Down
1 change: 1 addition & 0 deletions rustworkx/graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class PyGraph(Generic[S, T]):
def degree(self, node: int, /) -> int: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_indices_from_endpoints(self, int, int) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyGraph[S, T]: ...
Expand Down
19 changes: 19 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,25 @@ impl PyDiGraph {
}
}

/// Return a list of indices of all directed edges between specified nodes
///
/// :returns: A list of all the edge indices connecting the specified start and end node
/// :rtype: EdgeIndices
#[pyo3(text_signature = "(self)")]
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices {
let node_a_index = NodeIndex::new(node_a);
let node_b_index = NodeIndex::new(node_b);

EdgeIndices {
edges: self
.graph
.edges_directed(node_a_index, petgraph::Direction::Outgoing)
.filter(|edge| edge.target() == node_b_index)
.map(|edge| edge.id().index())
.collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
19 changes: 19 additions & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,25 @@ impl PyGraph {
}
}

/// Return a list of indices of all edges between specified nodes
///
/// :returns: A list of all the edge indices connecting the specified start and end node
/// :rtype: EdgeIndices
#[pyo3(text_signature = "(self)")]
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices {
let node_a_index = NodeIndex::new(node_a);
let node_b_index = NodeIndex::new(node_b);

EdgeIndices {
edges: self
.graph
.edges_directed(node_a_index, petgraph::Direction::Outgoing)
.filter(|edge| edge.target() == node_b_index)
.map(|edge| edge.id().index())
.collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
21 changes: 21 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,27 @@ def test_weighted_edge_list_empty(self):
dag = rustworkx.PyDiGraph()
self.assertEqual([], dag.weighted_edge_list())

def test_edge_indices_from_endpoints(self):
dag = rustworkx.PyDiGraph()
dag.add_nodes_from(list(range(4)))
edge_list = [
(0, 1, None),
(1, 2, None),
(0, 2, None),
(2, 3, None),
(0, 3, None),
(0, 2, None),
]
dag.add_edges_from(edge_list)
indices = dag.edge_indices_from_endpoints(0, 0)
self.assertEqual(indices, [])
indices = dag.edge_indices_from_endpoints(0, 1)
self.assertEqual(indices, [0])
indices = dag.edge_indices_from_endpoints(0, 2)
self.assertEqual(set(indices), {2, 5})



def test_extend_from_edge_list(self):
dag = rustworkx.PyDAG()
edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]
Expand Down
20 changes: 20 additions & 0 deletions tests/rustworkx_tests/graph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,26 @@ def test_weighted_edge_list_empty(self):
graph = rustworkx.PyGraph()
self.assertEqual([], graph.weighted_edge_list())

def test_edge_indices_from_endpoints(self):
dag = rustworkx.PyGraph()
dag.add_nodes_from(list(range(4)))
edge_list = [
(0, 1, None),
(1, 2, None),
(0, 2, None),
(2, 3, None),
(0, 3, None),
(0, 2, None),
(2, 0, None),
]
dag.add_edges_from(edge_list)
indices = dag.edge_indices_from_endpoints(0, 0)
self.assertEqual(indices, [])
indices = dag.edge_indices_from_endpoints(0, 1)
self.assertEqual(set(indices), {0})
indices = dag.edge_indices_from_endpoints(0, 2)
self.assertEqual(set(indices), {2, 5, 6})

def test_extend_from_edge_list(self):
graph = rustworkx.PyGraph()
edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]
Expand Down
Loading