Skip to content

Commit

Permalink
pass total cost into successors
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhuang committed Oct 11, 2023
1 parent 9668da5 commit 4bc2fcd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
32 changes: 14 additions & 18 deletions src/directed/dijkstra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::reverse_path;
use crate::FxIndexMap;
use indexmap::map::Entry::{Occupied, Vacant};
use num_traits::Zero;
use rustc_hash::FxHashMap;
use rustc_hash::{FxHashMap, FxHashSet};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
Expand Down Expand Up @@ -320,7 +320,7 @@ impl<K: Ord> Ord for SmallestHolder<K> {
/// Struct returned by [`dijkstra_reach`](crate::directed::dijkstra::dijkstra_reach).
pub struct DijkstraReachable<N, C, FN> {
to_see: BinaryHeap<SmallestHolder<C>>,
to_see_counts: FxHashMap<usize, usize>,
seen: FxHashSet<usize>,
parents: FxIndexMap<N, (usize, C)>,
total_costs: FxHashMap<N, C>,
successors: FN,
Expand All @@ -341,26 +341,27 @@ pub struct DijkstraReachableItem<N, C> {
impl<N, C, FN, IN> Iterator for DijkstraReachable<N, C, FN>
where
N: Eq + Hash + Clone,
C: Zero + Ord + Copy,
FN: FnMut(&N) -> IN,
C: Zero + Ord + Copy + Hash,
FN: FnMut(&N, C) -> IN,
IN: IntoIterator<Item = (N, C)>,
{
type Item = DijkstraReachableItem<N, C>;

fn next(&mut self) -> Option<Self::Item> {
while let Some(SmallestHolder { cost, index }) = self.to_see.pop() {
if !self.seen.insert(index) {
continue;
}
let item;
let count = self.to_see_counts.get_mut(&index).unwrap();
*count -= 1;
let count = *count;
let successors = {
let (node, (parent_index, _)) = self.parents.get_index(index).unwrap();
let total_cost = self.total_costs[node];
item = Some(DijkstraReachableItem {
node: node.clone(),
parent: self.parents.get_index(*parent_index).map(|x| x.0.clone()),
total_cost: self.total_costs[node],
total_cost,
});
(self.successors)(node)
(self.successors)(node, total_cost)
};
for (successor, move_cost) in successors {
let new_cost = cost + move_cost;
Expand All @@ -386,12 +387,8 @@ where
cost: new_cost,
index: n,
});
*self.to_see_counts.entry(n).or_insert(0) += 1;
}
if count == 0 {
self.to_see_counts.remove(&index);
return item;
}
return item;
}

None
Expand All @@ -404,7 +401,7 @@ pub fn dijkstra_reach<N, C, FN, IN>(start: &N, successors: FN) -> DijkstraReacha
where
N: Eq + Hash + Clone,
C: Zero + Ord + Copy,
FN: FnMut(&N) -> IN,
FN: FnMut(&N, C) -> IN,
IN: IntoIterator<Item = (N, C)>,
{
let mut to_see = BinaryHeap::new();
Expand All @@ -419,12 +416,11 @@ where
let mut total_costs = FxHashMap::default();
total_costs.insert(start.clone(), Zero::zero());

let mut to_see_counts = FxHashMap::default();
to_see_counts.insert(0, 1);
let seen = FxHashSet::default();

DijkstraReachable {
to_see,
to_see_counts,
seen,
parents,
total_costs,
successors,
Expand Down
16 changes: 13 additions & 3 deletions tests/dijkstra-all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn partial_paths() {

#[test]
fn dijkstra_reach_numbers() {
let reach = dijkstra_reach(&0, |prev| vec![(prev + 1, 1), (prev * 2, *prev)])
let reach = dijkstra_reach(&0, |prev, _| vec![(prev + 1, 1), (prev * 2, *prev)])
.take_while(|x| x.total_cost < 100)
.collect_vec();
// the total cost should equal to the node's value, since the starting node is 0 and the cost to reach a successor node is equal to the increase in the node's value
Expand All @@ -132,7 +132,13 @@ fn dijkstra_reach_graph() {
graph.insert("B", vec![("C", 2)]);
graph.insert("C", vec![]);

let reach = dijkstra_reach(&"A", |prev| graph[prev].clone()).collect_vec();
let mut costs = HashMap::new();

let reach = dijkstra_reach(&"A", |prev, cost| {
costs.insert(*prev, cost);
graph[prev].clone()
})
.collect_vec();

// need to make sure that a node won't be returned twice when a better path is found after the first candidate
assert!(
Expand All @@ -154,5 +160,9 @@ fn dijkstra_reach_graph() {
total_cost: 4,
},
]
)
);

for item in reach {
assert!(item.total_cost == costs[item.node]);
}
}

0 comments on commit 4bc2fcd

Please sign in to comment.