Skip to content

Commit

Permalink
refactor & power-of-2-ize segtree
Browse files Browse the repository at this point in the history
  • Loading branch information
ngtkana committed Jul 31, 2024
1 parent 9e7e282 commit 6b9ea3b
Showing 1 changed file with 57 additions and 41 deletions.
98 changes: 57 additions & 41 deletions libs/segtree/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,49 +35,64 @@ pub trait Op {

/// A segment tree.
pub struct Segtree<O: Op> {
len: usize,
offset: usize,
values: Vec<O::Value>,
}
impl<O: Op> Segtree<O> {
/// Constructs a new segment tree with the specified length.
pub fn from_len(n: usize) -> Self
pub fn from_len(len: usize) -> Self
where
O::Value: Clone,
{
Self::new(&vec![O::identity(); n])
let offset = len.next_power_of_two();
Self {
len,
offset,
values: vec![O::identity(); 2 * offset],
}
}

/// Constructs with the specified values.
pub fn new(values: &[O::Value]) -> Self
pub fn new(elms: &[O::Value]) -> Self
where
O::Value: Clone,
{
let values_ = values;
let n = values.len();
let mut values = vec![O::identity(); 2 * n];
values[n..].clone_from_slice(values_);
for i in (1..n).rev() {
let len = elms.len();
let offset = len.next_power_of_two();
let mut values = vec![O::identity(); 2 * offset];
values[offset..offset + len].clone_from_slice(elms);
for i in (1..offset).rev() {
values[i] = O::op(&values[i * 2], &values[i * 2 + 1]);
}
Self { values }
Self {
len,
offset,
values,
}
}

/// Returns $x_l \cdot x_{l+1} \cdot \ldots \cdot x_{r-1}$.
pub fn fold<R: RangeBounds<usize>>(&self, range: R) -> O::Value {
let n = self.values.len() / 2;
let (mut start, mut end) = open(range, n);
assert!(start <= end && end <= n);
start += n;
end += n;
let Self {
len,
offset,
ref values,
} = *self;
let (mut start, mut end) = open(range, len);
assert!((start..=len).contains(&len));
start += offset;
end += offset;
let mut left = O::identity();
let mut right = O::identity();
while start < end {
if start % 2 == 1 {
left = O::op(&left, &self.values[start]);
left = O::op(&left, &values[start]);
start += 1;
}
if end % 2 == 1 {
end -= 1;
right = O::op(&self.values[end], &right);
right = O::op(&values[end], &right);
}
start /= 2;
end /= 2;
Expand All @@ -87,21 +102,21 @@ impl<O: Op> Segtree<O> {

/// Returns the entry of $x_i$.
pub fn entry(&mut self, index: usize) -> Entry<O> {
let n = self.values.len() / 2;
let offset = self.offset;
Entry {
segtree: self,
index: n + index,
index: offset + index,
}
}

/// Returns an iterator of $x_0, x_1, \ldots, x_{n-1}$.
pub fn iter(&self) -> impl Iterator<Item = &O::Value> {
self.values[self.values.len() / 2..].iter()
pub fn iter(&self) -> std::slice::Iter<O::Value> {
self.values[self.offset..self.offset + self.len].iter()
}

/// Returns a slice of $x_0, x_1, \ldots, x_{n-1}$.
pub fn as_slice(&self) -> &[O::Value] {
&self.values[self.values.len() / 2..]
&self.values[self.offset..self.offset + self.len]
}
}

Expand Down Expand Up @@ -172,11 +187,12 @@ where
}

/// A sparse (compressed) segment tree.
pub struct SparseSegtree<K, O: Op> {
/// Use [`collect_map()`](Self::collect_map()) to debug.
pub struct SegtreeWithCompression<K, O: Op> {
inner: Segtree<O>,
keys: Vec<K>,
}
impl<K: Ord, O: Op> SparseSegtree<K, O> {
impl<K: Ord, O: Op> SegtreeWithCompression<K, O> {
/// Constructs with the specified key-value pairs.
pub fn new(kv: &[(K, O::Value)]) -> Self
where
Expand All @@ -194,18 +210,15 @@ impl<K: Ord, O: Op> SparseSegtree<K, O> {

/// Folds $\left \lbrace x_k \mid k \in \text{{range}} \right \rbrace$.
pub fn fold<R: RangeBounds<K>>(&self, range: R) -> O::Value {
let (start, end) = open_key(range, &self.keys);
let (start, end) = to_range(range, &self.keys);
self.inner.fold(start..end)
}

/// Returns the entry of $x_k$.
/// If $k$ is not found, it panics.
pub fn entry(&mut self, key: &K) -> Entry<'_, O> {
let index = self.keys.binary_search(key).unwrap() + self.keys.len();
Entry {
segtree: &mut self.inner,
index,
}
let index = self.keys.binary_search(key).unwrap();
self.inner.entry(index)
}

/// Returns the keys.
Expand All @@ -232,7 +245,7 @@ impl<K: Ord, O: Op> SparseSegtree<K, O> {
}
}

impl<K, O: Op> fmt::Debug for SparseSegtree<K, O>
impl<K, O: Op> fmt::Debug for SegtreeWithCompression<K, O>
where
K: fmt::Debug,
O::Value: fmt::Debug,
Expand All @@ -245,7 +258,7 @@ where
}
}

impl<K: Ord, O: Op> FromIterator<(K, O::Value)> for SparseSegtree<K, O>
impl<K: Ord, O: Op> FromIterator<(K, O::Value)> for SegtreeWithCompression<K, O>
where
K: Clone,
O::Value: Clone,
Expand All @@ -255,7 +268,7 @@ where
}
}

impl<K: Ord, O: Op> Index<K> for SparseSegtree<K, O> {
impl<K: Ord, O: Op> Index<K> for SegtreeWithCompression<K, O> {
type Output = O::Value;

fn index(&self, key: K) -> &Self::Output {
Expand All @@ -266,7 +279,7 @@ impl<K: Ord, O: Op> Index<K> for SparseSegtree<K, O> {
/// A segment tree of segment trees (2D segment tree).
/// The multiplication must be commutative.
pub struct Sparse2dSegtree<K, L, O: Op> {
segtrees: Vec<SparseSegtree<L, O>>,
segtrees: Vec<SegtreeWithCompression<L, O>>,
keys: Vec<K>,
}
impl<K, L, O: Op> Sparse2dSegtree<K, L, O>
Expand Down Expand Up @@ -303,15 +316,15 @@ where
let i = ls.binary_search(&l).unwrap();
lvs[i].1 = O::op(&lvs[i].1, v);
}
SparseSegtree::new(&lvs)
SegtreeWithCompression::new(&lvs)
})
.collect::<Vec<_>>();
Self { segtrees, keys }
}

/// Folds $\left \lbrace x_{k, l} \mid (k, l) \in \text{{range}} \right \rbrace$.
pub fn fold(&self, i: impl RangeBounds<K>, j: impl RangeBounds<L> + Clone) -> O::Value {
let (mut i0, mut i1) = open_key(i, &self.keys);
let (mut i0, mut i1) = to_range(i, &self.keys);
i0 += self.keys.len();
i1 += self.keys.len();
let mut left = O::identity();
Expand Down Expand Up @@ -395,7 +408,7 @@ where
}

impl<K: Ord, L: Ord, O: Op> Index<K> for Sparse2dSegtree<K, L, O> {
type Output = SparseSegtree<L, O>;
type Output = SegtreeWithCompression<L, O>;

fn index(&self, i: K) -> &Self::Output {
&self.segtrees[self.keys.binary_search(&i).unwrap() + self.keys.len()]
Expand Down Expand Up @@ -598,7 +611,7 @@ fn open<B: RangeBounds<usize>>(bounds: B, n: usize) -> (usize, usize) {
(start, end)
}

fn open_key<K: Ord, B: RangeBounds<K>>(bounds: B, keys: &[K]) -> (usize, usize) {
fn to_range<K: Ord, B: RangeBounds<K>>(bounds: B, keys: &[K]) -> (usize, usize) {
use std::ops::Bound;
let start = match bounds.start_bound() {
Bound::Unbounded => 0,
Expand Down Expand Up @@ -734,7 +747,10 @@ mod tests {
.copied()
.map(|key| (key, (rng.gen_range(0..BASE), BASE)))
.collect::<Vec<_>>();
let mut segtree = vec.iter().copied().collect::<SparseSegtree<_, O>>();
let mut segtree = vec
.iter()
.copied()
.collect::<SegtreeWithCompression<_, O>>();
for _ in 0..q {
match rng.gen_range(0..2) {
// fold
Expand Down Expand Up @@ -767,9 +783,9 @@ mod tests {
#[test]
fn test_sparse_segtree_usability() {
use rolling_hash::O;
let _ = SparseSegtree::<usize, O>::new(&[(0, (1, 1))]);
let _ = SparseSegtree::<usize, O>::from_iter(vec![(0, (1, 1))]);
let mut segtree = SparseSegtree::<usize, O>::new(&[(0, (1, 1))]);
let _ = SegtreeWithCompression::<usize, O>::new(&[(0, (1, 1))]);
let _ = SegtreeWithCompression::<usize, O>::from_iter(vec![(0, (1, 1))]);
let mut segtree = SegtreeWithCompression::<usize, O>::new(&[(0, (1, 1))]);
let _ = segtree.fold(0..1);
let _ = segtree.entry(&0);
assert_eq!(segtree[0], (1, 1));
Expand Down

0 comments on commit 6b9ea3b

Please sign in to comment.