Skip to content

Commit

Permalink
Fix fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
ElePT committed May 29, 2024
1 parent b77b269 commit fef1982
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 34 deletions.
6 changes: 3 additions & 3 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use petgraph::algo;
use petgraph::data::DataMap;
use petgraph::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected,
IntoNodeIdentifiers, NodeIndexable, NodeCount, Visitable,
IntoNodeIdentifiers, NodeCount, NodeIndexable, Visitable,
};
use petgraph::Directed;

Expand Down Expand Up @@ -320,7 +320,7 @@ where
#[derive(Debug)]
pub enum CollectBicolorError<E: Error> {
DAGHasCycle,
CallableError(E)
CallableError(E),
}

impl<E: Error> Error for CollectBicolorError<E> {}
Expand Down Expand Up @@ -765,4 +765,4 @@ mod test_lexicographical_topological_sort {
let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial));
assert_eq!(result, Ok(Some(vec![nodes[7]])));
}
}
}
4 changes: 2 additions & 2 deletions rustworkx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ pub mod centrality;
/// Module for coloring algorithms.
pub mod coloring;
pub mod connectivity;
/// Module for algorithms that work on DAGs.
pub mod dag_algo;
pub mod generators;
pub mod graph_ext;
pub mod line_graph;
/// Module for algorithms that work on DAGs.
pub mod dag_algo;
/// Module for maximum weight matching algorithms.
pub mod max_weight_matching;
pub mod planar;
Expand Down
65 changes: 36 additions & 29 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ use rustworkx_core::dictmap::InitWithHasher;
use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};

use rustworkx_core::dag_algo::{CollectBicolorError, lexicographical_topological_sort as core_lexico_topo_sort};
use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort;
use rustworkx_core::dag_algo::longest_path as core_longest_path;
use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs;
use rustworkx_core::dag_algo::{
collect_bicolor_runs as core_collect_bicolor_runs, CollectBicolorError,
};
use rustworkx_core::traversal::dfs_edges;

use pyo3::exceptions::PyValueError;
Expand Down Expand Up @@ -606,11 +608,13 @@ pub fn collect_runs(

/// Define custom error conversion logic for collect_bicolor_runs.
fn convert_error(err: CollectBicolorError<PyErr>) -> PyErr {
// Note that we cannot implement From<CollectBicolorError<PyErr>> for PyErr
// because nor PyErr nor CollectBicolorError are defined in this crate,
// so we use .map_err(convert_error) to convert a CollectBicolorError to PyErr instead.
// Note that we cannot implement From<CollectBicolorError<PyErr>> for PyErr
// because nor PyErr nor CollectBicolorError are defined in this crate,
// so we use .map_err(convert_error) to convert a CollectBicolorError to PyErr instead.
match err {
CollectBicolorError::DAGHasCycle => PyErr::new::<DAGHasCycle, _>("Sort encountered a cycle"),
CollectBicolorError::DAGHasCycle => {
PyErr::new::<DAGHasCycle, _>("Sort encountered a cycle")
}
CollectBicolorError::CallableError(err) => err,
}
}
Expand Down Expand Up @@ -645,42 +649,45 @@ pub fn collect_bicolor_runs(
filter_fn: PyObject,
color_fn: PyObject,
) -> PyResult<Vec<Vec<PyObject>>> {

let dag = &graph.graph;

// Wrap filter_fn to return Result<Option<bool>, CollectBicolorError<E>>
let filter_fn_wrapper =
|node: &PyObject| -> Result<Option<bool>, CollectBicolorError<PyErr>> {
match filter_fn.call1(py, (node,)) {
Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError),
Err(err) => Err(CollectBicolorError::CallableError(err)),
let filter_fn_wrapper = |node: &PyObject| -> Result<Option<bool>, CollectBicolorError<PyErr>> {
match filter_fn.call1(py, (node,)) {
Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError),
Err(err) => Err(CollectBicolorError::CallableError(err)),
}
};

// Wrap color_fn to return Result<Option<usize>, CollectBicolorError<E>>
let color_fn_wrapper =
|edge: &PyObject| -> Result<Option<usize>, CollectBicolorError<PyErr>> {
match color_fn.call1(py, (edge,)) {
Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError),
Err(err) => Err(CollectBicolorError::CallableError(err)),
let color_fn_wrapper = |edge: &PyObject| -> Result<Option<usize>, CollectBicolorError<PyErr>> {
match color_fn.call1(py, (edge,)) {
Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError),
Err(err) => Err(CollectBicolorError::CallableError(err)),
}
};

// Map CollectBicolorError to PyErr using custom convert_error function
let block_list =
core_collect_bicolor_runs::<&StablePyGraph<Directed>, _, _, (), PyErr>(
dag,
filter_fn_wrapper,
color_fn_wrapper
).map_err(convert_error)?;
let block_list = core_collect_bicolor_runs::<&StablePyGraph<Directed>, _, _, (), PyErr>(
dag,
filter_fn_wrapper,
color_fn_wrapper,
)
.map_err(convert_error)?;

// Convert the result list from Vec<Vec<NodeId>> to Vec<Vec<PyObject>>
let py_block_list: Vec<Vec<PyObject>> = block_list.into_iter().map(|index_list| {
index_list.into_iter().map(|node_index| {
let node_weight = dag.node_weight(node_index).expect("Invalid NodeId");
node_weight.into_py(py)
}).collect()
}).collect();
let py_block_list: Vec<Vec<PyObject>> = block_list
.into_iter()
.map(|index_list| {
index_list
.into_iter()
.map(|node_index| {
let node_weight = dag.node_weight(node_index).expect("Invalid NodeId");
node_weight.into_py(py)
})
.collect()
})
.collect();

Ok(py_block_list)
}
Expand Down

0 comments on commit fef1982

Please sign in to comment.