Switched order_heap to be a quaternary heap
dewert99 committed Mar 18, 2024
1 parent 002f9aa commit dbe576e
Showing 2 changed files with 118 additions and 37 deletions.
20 changes: 17 additions & 3 deletions src/batsat/src/core.rs
@@ -2165,7 +2165,7 @@ impl VarState {
for (_, x) in self.activity.iter_mut() {
*x *= scale;
}
for (_, x) in self.order_heap_data.heap.iter_mut() {
for (_, x) in self.order_heap_data.heap_mut().iter_mut() {
*x *= scale
}
self.var_inc *= scale;
@@ -2410,10 +2410,24 @@ impl PartialEq for Watcher {
}
impl Eq for Watcher {}

impl<'a> VarOrder<'a> {
fn check_activity(&self, var: Var) -> f32 {
if var == Var::UNDEF {
0.0
} else {
self.activity[var]
}
}
}

impl<'a> Comparator<(Var, f32)> for VarOrder<'a> {
fn max_value(&self) -> (Var, f32) {
(Var::UNDEF, 0.0)
}

fn cmp(&self, lhs: &(Var, f32), rhs: &(Var, f32)) -> cmp::Ordering {
debug_assert_eq!(self.activity[rhs.0], rhs.1);
debug_assert_eq!(self.activity[lhs.0], lhs.1);
debug_assert_eq!(self.check_activity(rhs.0), rhs.1);
debug_assert_eq!(self.check_activity(lhs.0), lhs.1);
PartialOrd::partial_cmp(&rhs.1, &lhs.1)
.expect("NaN activity")
.then(lhs.0.cmp(&rhs.0))
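
For context on the new `max_value` method: the heap is now padded with sentinel entries, and `(Var::UNDEF, 0.0)` should never look more attractive than a live variable under the descending-activity ordering above. The standalone sketch below is illustrative only (the `ByActivity` struct, `u32` keys, and the trimmed-down trait are assumptions, not code from this patch); it shows why an activity of 0.0 works as that sentinel.

use std::cmp::Ordering;

// Hypothetical stand-in for the solver's comparator: the heap is a min-heap,
// and max_value must return an element that never compares less than any
// real entry.
trait Comparator<T> {
    fn max_value(&self) -> T;
    fn cmp(&self, lhs: &T, rhs: &T) -> Ordering;
    fn lt(&self, lhs: &T, rhs: &T) -> bool {
        self.cmp(lhs, rhs) == Ordering::Less
    }
}

// Toy activity ordering: higher activity sorts first (i.e. is "smaller"),
// mirroring the reversed comparison in VarOrder::cmp.
struct ByActivity;

impl Comparator<(u32, f32)> for ByActivity {
    fn max_value(&self) -> (u32, f32) {
        // Sentinel in the spirit of (Var::UNDEF, 0.0): a dummy key with the
        // lowest possible activity.
        (u32::MAX, 0.0)
    }
    fn cmp(&self, lhs: &(u32, f32), rhs: &(u32, f32)) -> Ordering {
        rhs.1
            .partial_cmp(&lhs.1)
            .expect("NaN activity")
            .then(lhs.0.cmp(&rhs.0))
    }
}

fn main() {
    let c = ByActivity;
    let real = (7, 3.5); // a live variable with positive activity
    let pad = c.max_value(); // a padding slot
    // The sentinel never wins against a real entry, so padded slots behave
    // like children that are present but never chosen.
    assert!(c.lt(&real, &pad));
    assert!(!c.lt(&pad, &real));
}
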
135 changes: 101 additions & 34 deletions src/batsat/src/heap.rs
@@ -1,16 +1,20 @@
use crate::intmap::{AsIndex, IntMap};
use std::{cmp, ops};
use std::fmt::Debug;
use std::{cmp, mem, ops};

/// Quaternary Heap
#[derive(Debug, Clone)]
pub struct HeapData<K: AsIndex, V> {
pub(crate) heap: Vec<(K, V)>,
heap: Box<[(K, V)]>,
next_slot: usize,
indices: IntMap<K, i32>,
}

impl<K: AsIndex, V> Default for HeapData<K, V> {
fn default() -> Self {
Self {
heap: Vec::new(),
heap: Box::new([]),
next_slot: 0,
indices: IntMap::new(),
}
}
@@ -24,7 +28,7 @@ impl<K: AsIndex, V> HeapData<K, V> {
self.heap.len()
}
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
self.next_slot <= ROOT as usize
}
pub fn in_heap(&self, k: K) -> bool {
self.indices.has(k) && self.indices[k] >= 0
@@ -33,6 +37,15 @@ impl<K: AsIndex, V> HeapData<K, V> {
pub fn promote<Comp: Comparator<(K, V)>>(&mut self, comp: Comp) -> Heap<K, V, Comp> {
Heap { data: self, comp }
}

/// Raw mutable access to all the elements of the heap
pub(crate) fn heap_mut(&mut self) -> &mut [(K, V)] {
if self.next_slot == 0 {
&mut []
} else {
&mut self.heap[ROOT as usize..self.next_slot]
}
}
}

impl<K: AsIndex, V> ops::Index<usize> for HeapData<K, V> {
@@ -43,6 +56,7 @@ impl<K: AsIndex, V> ops::Index<usize> for HeapData<K, V> {
}

pub trait Comparator<T: ?Sized> {
fn max_value(&self) -> T;
fn cmp(&self, lhs: &T, rhs: &T) -> cmp::Ordering;
fn max(&self, lhs: T, rhs: T) -> T
where
@@ -109,11 +123,39 @@ impl<'a, K: AsIndex + 'a, V: 'a, Comp> ops::DerefMut for Heap<'a, K, V, Comp> {
}

impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K, V, Comp> {
// ensure size is always a multiple of 4
#[cold]
#[inline(never)]
fn heap_reserve(&mut self) {
debug_assert_eq!(self.next_slot, self.data.len());
if self.next_slot == 0 {
self.next_slot = ROOT as usize;
// Enough space for the root and 4 children
self.heap = vec![self.comp.max_value(); 8].into_boxed_slice();
} else {
let new_size = self.next_slot << 2;
let mut heap = mem::replace(&mut self.heap, Box::new([])).into_vec();
heap.resize(new_size, self.comp.max_value());
self.heap = heap.into_boxed_slice();
}
}

#[inline]
fn heap_push(&mut self, k: K, v: V) -> u32 {
if self.next_slot >= self.heap.len() {
self.heap_reserve();
assert!(self.next_slot < self.heap.len());
}
let slot = self.next_slot;
self.heap[slot] = (k, v);
self.next_slot += 1;
slot as u32
}
fn percolate_up(&mut self, mut i: u32) {
let x = self.heap[i as usize];
let mut p = parent_index(i);

while i != 0 && self.comp.lt(&x, &self.heap[p as usize]) {
while i != ROOT && self.comp.lt(&x, &self.heap[p as usize]) {
self.heap[i as usize] = self.heap[p as usize];
let tmp = self.heap[p as usize];
self.indices[tmp.0] = i as i32;
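
As a hedged aside on the sizing policy in `heap_reserve`/`heap_push` above (this standalone simulation is illustrative and not part of the patch): the allocation starts at 8 slots and then grows to `next_slot << 2`, so rounding the live length up to the next multiple of 4, as `percolate_down` does below, always stays within the allocation.

fn main() {
    const ROOT: usize = 3;
    let mut capacity = 0usize; // stands in for heap.len()
    let mut next_slot = 0usize;
    for _push in 0..1_000 {
        if next_slot >= capacity {
            // Mirrors heap_reserve: 8 slots initially (root at index 3 plus
            // its four children), otherwise quadruple.
            if next_slot == 0 {
                next_slot = ROOT;
                capacity = 8;
            } else {
                capacity = next_slot << 2;
            }
        }
        next_slot += 1; // heap_push takes the next free slot
        // percolate_down rounds the live length up to a multiple of 4 and
        // asserts it is in bounds; the growth policy keeps that true.
        let rounded = (next_slot + 3) & !3;
        assert!(rounded <= capacity);
    }
    println!("capacity after 1_000 pushes: {capacity}");
}
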
@@ -124,24 +166,40 @@ impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K,
self.indices[x.0] = i as i32;
}

#[inline]
fn bundle(&self, i: u32) -> (u32, (K, V)) {
(i, self.heap[i as usize])
}

#[inline]
fn min(&self, x: (u32, (K, V)), y: (u32, (K, V))) -> (u32, (K, V)) {
if self.comp.lt(&x.1, &y.1) {
x
} else {
y
}
}

fn percolate_down(&mut self, mut i: u32) {
let x = self.heap[i as usize];
while (left_index(i) as usize) < self.heap.len() {
let child = if (right_index(i) as usize) < self.heap.len()
&& self.comp.lt(
&self.heap[right_index(i) as usize],
&self.heap[left_index(i) as usize],
) {
right_index(i)
} else {
left_index(i)
};
if !self.comp.lt(&self.heap[child as usize], &x) {
let len = (self.next_slot + 3) & (usize::MAX - 3); // round up to nearest multiple of 4
// since the heap is padded with maximum values we can pretend that these are part of the
// heap but never swap with them
assert!(len <= self.heap.len()); // hopefully this lets us eliminate bounds checks
while (right_index(i) as usize) < len {
let left_index = left_index(i);
let b0 = self.bundle(left_index);
let b1 = self.bundle(left_index + 1);
let b2 = self.bundle(left_index + 2);
let b3 = self.bundle(left_index + 3);
let b01 = self.min(b0, b1);
let b23 = self.min(b2, b3);
let (child, min) = self.min(b01, b23);
if !self.comp.lt(&min, &x) {
break;
}
self.heap[i as usize] = self.heap[child as usize];
let tmp = self.heap[i as usize];
self.indices[tmp.0] = i as i32;
self.heap[i as usize] = min;
self.indices[min.0] = i as i32;
i = child;
}
self.heap[i as usize] = x;
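
A small, self-contained illustration of the chunked comparison used in `percolate_down` above (plain `f32` keys and an ascending order are simplifying assumptions; the real comparator orders by descending activity): because every node's children sit in four consecutive slots and padding slots hold the comparator's maximum, the best child is found with three unconditional comparisons.

// Returns (offset within the chunk, key) of the smallest of four children.
fn min_of_four(children: &[f32; 4]) -> (usize, f32) {
    let min = |x: (usize, f32), y: (usize, f32)| if x.1 < y.1 { x } else { y };
    let b01 = min((0, children[0]), (1, children[1]));
    let b23 = min((2, children[2]), (3, children[3]));
    min(b01, b23)
}

fn main() {
    // Two real children plus two padding slots holding the "maximum value".
    let chunk = [2.5, 1.0, f32::INFINITY, f32::INFINITY];
    assert_eq!(min_of_four(&chunk), (1, 1.0));
}
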
@@ -158,36 +216,45 @@ impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K,
pub fn insert(&mut self, k: K) {
self.indices.reserve(k, -1);
debug_assert!(!self.in_heap(k));

self.indices[k] = self.heap.len() as i32;
self.data.heap.push((k, self.comp.value(k)));
let k_index = self.indices[k];
self.percolate_up(k_index as u32);
let k_index = self.heap_push(k, self.comp.value(k));
self.indices[k] = k_index as i32;
self.percolate_up(k_index);
}

pub fn remove_min(&mut self) -> K {
let x = *self.heap.first().expect("heap is empty");
let lastval = *self.heap.last().expect("heap is empty");
self.heap[0] = lastval;
self.indices[lastval.0] = 0;
assert!(!self.is_empty(), "cannot pop from empty heap");
let x = self.heap[ROOT as usize];
let last = self.next_slot - 1;
self.next_slot = last;
self.indices[x.0] = -1;
self.heap.pop().expect("cannot pop from empty heap");
if self.heap.len() > 1 {
self.percolate_down(0);
if self.is_empty() {
self.heap[last] = self.comp.max_value();
return x.0;
}
let lastval = self.heap[last];
self.heap[last] = self.comp.max_value();
self.heap[ROOT as usize] = lastval;
self.indices[lastval.0] = ROOT as i32;
self.percolate_down(ROOT);
x.0
}
}

/// Root of the quaternary heap
/// By using 3 as the root we ensure each chunk of 4 children has a multiple of 4 starting index
/// This gives the chunks a better chance of being cache aligned, i.e. they are cache aligned if
/// the allocation is cache aligned
const ROOT: u32 = 3;

#[inline(always)]
fn left_index(i: u32) -> u32 {
i * 2 + 1
(i - 2) << 2
}
#[inline(always)]
fn right_index(i: u32) -> u32 {
(i + 1) * 2
left_index(i) + 3
}
#[inline(always)]
fn parent_index(i: u32) -> u32 {
(i.wrapping_sub(1)) >> 1
(i >> 2) + 2
}
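
Finally, a quick standalone check of the `ROOT = 3` index arithmetic (the test harness below is hypothetical, not part of the repository): each node's four children start at a multiple of 4, and `parent_index` inverts the child computation.

fn left_index(i: u32) -> u32 {
    (i - 2) << 2
}
fn right_index(i: u32) -> u32 {
    left_index(i) + 3
}
fn parent_index(i: u32) -> u32 {
    (i >> 2) + 2
}

fn main() {
    const ROOT: u32 = 3;
    // The root's children are slots 4..=7, matching the initial allocation of 8.
    assert_eq!((left_index(ROOT), right_index(ROOT)), (4, 7));
    for i in ROOT..1_000 {
        let l = left_index(i);
        // Children start on a multiple of 4, so a chunk of four (K, V) pairs
        // has a good chance of sharing a cache line.
        assert_eq!(l % 4, 0);
        for c in l..=right_index(i) {
            assert_eq!(parent_index(c), i);
        }
    }
    println!("index arithmetic checks out");
}
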
