diff --git a/libs/segtree/src/lib.rs b/libs/segtree/src/lib.rs index 75a075fa..9c8e692f 100644 --- a/libs/segtree/src/lib.rs +++ b/libs/segtree/src/lib.rs @@ -1,3 +1,4 @@ +use std::iter::FromIterator; use std::ops::Deref; use std::ops::DerefMut; use std::ops::Index; @@ -17,6 +18,13 @@ pub struct Segtree { values: Vec, } impl Segtree { + pub fn from_len(n: usize) -> Self + where + O::Value: Clone, + { + Self::new(&vec![O::identity(); n]) + } + pub fn new(values: &[O::Value]) -> Self where O::Value: Clone, @@ -31,7 +39,7 @@ impl Segtree { Self { values } } - pub fn fold>(&mut self, range: R) -> O::Value { + pub fn fold>(&self, range: R) -> O::Value { let n = self.values.len() / 2; let (mut start, mut end) = open(range, n); start += n; @@ -66,6 +74,15 @@ impl Segtree { } } +impl FromIterator for Segtree +where + O::Value: Clone, +{ + fn from_iter>(iter: I) -> Self { + Self::new(&iter.into_iter().collect::>()) + } +} + impl Index for Segtree { type Output = O::Value; @@ -82,10 +99,10 @@ impl<'a, O: Op> Drop for Entry<'a, O> { fn drop(&mut self) { let mut index = self.index; while index != 0 { - index >>= 1; + index /= 2; self.segtree.values[index] = O::op( - &self.segtree.values[index << 1], - &self.segtree.values[(index << 1) + 1], + &self.segtree.values[index * 2], + &self.segtree.values[index * 2 + 1], ); } } @@ -103,6 +120,298 @@ impl<'a, O: Op> DerefMut for Entry<'a, O> { } } +pub struct SparseSegtree { + inner: Segtree, + keys: Vec, +} +impl SparseSegtree { + pub fn new(kv: &[(K, O::Value)]) -> Self + where + K: Clone, + O::Value: Clone, + { + let mut keys = kv.iter().map(|(k, _)| k.clone()).collect::>(); + keys.sort(); + let values = kv.iter().map(|(_, v)| v.clone()).collect::>(); + Self { + inner: Segtree::new(&values), + keys: keys.to_vec(), + } + } + + pub fn fold>(&self, range: R) -> O::Value { + let (start, end) = open_key(range, &self.keys); + self.inner.fold(start..end) + } + + pub fn entry(&mut self, key: K) -> Entry<'_, O> { + let index = self.keys.binary_search(&key).unwrap() + self.keys.len(); + Entry { + segtree: &mut self.inner, + index, + } + } +} + +impl FromIterator<(K, O::Value)> for SparseSegtree +where + K: Clone, + O::Value: Clone, +{ + fn from_iter>(iter: I) -> Self { + Self::new(&iter.into_iter().collect::>()) + } +} + +impl Index for SparseSegtree { + type Output = O::Value; + + fn index(&self, key: K) -> &Self::Output { + &self.inner[self.keys.binary_search(&key).unwrap()] + } +} + +pub struct SegtreeOfSegtrees { + segtrees: Vec>, + keys: Vec, +} +impl SegtreeOfSegtrees +where + K: Ord + Clone, + L: Ord + Clone, + O::Value: Clone, +{ + pub fn new(points: &[(K, L, O::Value)]) -> Self { + let mut keys = points.iter().map(|(k, _, _)| k.clone()).collect::>(); + keys.sort(); + keys.dedup(); + let mut lvs = vec![vec![]; keys.len() * 2]; + for (k, l, v) in points { + let mut i = keys.binary_search(k).unwrap(); + i += keys.len(); + while i != 0 { + lvs[i].push((l.clone(), v.clone())); + i /= 2; + } + } + let segtrees = lvs + .into_iter() + .map(|lvs_| { + let mut ls = lvs_.iter().map(|(l, _)| l).collect::>(); + ls.sort(); + ls.dedup(); + let mut lvs = ls + .iter() + .map(|&l| (l.clone(), O::identity())) + .collect::>(); + for (l, v) in &lvs_ { + let i = ls.binary_search(&l).unwrap(); + lvs[i].1 = O::op(&lvs[i].1, v); + } + SparseSegtree::new(&lvs) + }) + .collect::>(); + Self { segtrees, keys } + } + + pub fn fold(&self, i: impl RangeBounds, j: impl RangeBounds + Clone) -> O::Value { + let (mut i0, mut i1) = open_key(i, &self.keys); + i0 += self.keys.len(); + i1 += self.keys.len(); + let mut left = O::identity(); + let mut right = O::identity(); + while i0 < i1 { + if i0 % 2 == 1 { + left = O::op(&left, &self.segtrees[i0].fold(j.clone())); + i0 += 1; + } + if i1 % 2 == 1 { + i1 -= 1; + right = O::op(&self.segtrees[i1].fold(j.clone()), &right); + } + i0 /= 2; + i1 /= 2; + } + O::op(&left, &right) + } + + pub fn apply(&mut self, k: K, l: L, mut f: impl FnMut(&mut O::Value)) { + let mut i = self.keys.binary_search(&k).unwrap(); + i += self.keys.len(); + while i != 0 { + f(&mut self.segtrees[i].entry(l.clone())); + i /= 2; + } + } +} + +impl FromIterator<(K, L, O::Value)> for SegtreeOfSegtrees +where + K: Ord + Clone, + L: Ord + Clone, + O::Value: Clone, +{ + fn from_iter>(iter: I) -> Self { + Self::new(&iter.into_iter().collect::>()) + } +} + +impl Index for SegtreeOfSegtrees { + type Output = SparseSegtree; + + fn index(&self, i: K) -> &Self::Output { + &self.segtrees[self.keys.binary_search(&i).unwrap() + self.keys.len()] + } +} + +impl Index<(K, L)> for SegtreeOfSegtrees { + type Output = O::Value; + + fn index(&self, (i, j): (K, L)) -> &Self::Output { + &self.segtrees[self.keys.binary_search(&i).unwrap() + self.keys.len()][j] + } +} + +pub struct Dense2dSegtree { + values: Vec>, +} +impl Dense2dSegtree { + pub fn new(values: &[Vec]) -> Self + where + O::Value: Clone, + { + let values_ = values; + let h = values.len(); + let w = values.get(0).map_or(0, |v| v.len()); + let mut values = vec![vec![O::identity(); 2 * w]; 2 * h]; + for (values, values_) in values[h..].iter_mut().zip(values_) { + values[w..].clone_from_slice(values_); + for j in (1..w).rev() { + values[j] = O::op(&values[j * 2], &values[j * 2 + 1]); + } + } + for i in (1..h).rev() { + for j in 0..2 * w { + values[i][j] = O::op(&values[i * 2][j], &values[i * 2 + 1][j]); + } + } + Self { values } + } + + pub fn fold(&self, i: impl RangeBounds, j: impl RangeBounds) -> O::Value { + let h = self.values.len() / 2; + let w = self.values.get(0).map_or(0, |v| v.len() / 2); + let (mut i0, mut i1) = open(i, h); + let (mut j0, mut j1) = open(j, w); + i0 += h; + i1 += h; + j0 += w; + j1 += w; + let mut left = O::identity(); + let mut right = O::identity(); + while i0 < i1 { + if i0 % 2 == 1 { + let mut j0 = j0; + let mut j1 = j1; + while j0 < j1 { + if j0 % 2 == 1 { + left = O::op(&left, &self.values[i0][j0]); + j0 += 1; + } + if j1 % 2 == 1 { + j1 -= 1; + right = O::op(&self.values[i0][j1], &right); + } + j0 /= 2; + j1 /= 2; + } + i0 += 1; + } + if i1 % 2 == 1 { + i1 -= 1; + let mut j0 = j0; + let mut j1 = j1; + while j0 < j1 { + if j0 % 2 == 1 { + left = O::op(&left, &self.values[i1][j0]); + j0 += 1; + } + if j1 % 2 == 1 { + j1 -= 1; + right = O::op(&self.values[i1][j1], &right); + } + j0 /= 2; + j1 /= 2; + } + } + i0 /= 2; + i1 /= 2; + } + O::op(&left, &right) + } + + pub fn entry(&mut self, i: usize, j: usize) -> Dense2dEntry { + let h = self.values.len() / 2; + let w = self.values.get(0).map_or(0, |v| v.len() / 2); + Dense2dEntry { + segtree: self, + i: h + i, + j: w + j, + } + } +} + +impl Index for Dense2dSegtree { + type Output = [O::Value]; + + fn index(&self, index: usize) -> &Self::Output { + &self.values[self.values.len() / 2 + index] + } +} + +pub struct Dense2dEntry<'a, O: Op> { + segtree: &'a mut Dense2dSegtree, + i: usize, + j: usize, +} +impl<'a, O: Op> Drop for Dense2dEntry<'a, O> { + fn drop(&mut self) { + let mut i = self.i; + let mut j = self.j / 2; + while j != 0 { + self.segtree.values[i][j] = O::op( + &self.segtree.values[i][2 * j], + &self.segtree.values[i][2 * j + 1], + ); + j /= 2; + } + i /= 2; + while i != 0 { + let mut j = self.j; + while j != 0 { + self.segtree.values[i][j] = O::op( + &self.segtree.values[i * 2][j], + &self.segtree.values[i * 2 + 1][j], + ); + j /= 2; + } + i /= 2; + } + } +} +impl<'a, O: Op> Deref for Dense2dEntry<'a, O> { + type Target = O::Value; + + fn deref(&self) -> &Self::Target { + &self.segtree.values[self.i][self.j] + } +} +impl<'a, O: Op> DerefMut for Dense2dEntry<'a, O> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.segtree.values[self.i][self.j] + } +} + fn open>(bounds: B, n: usize) -> (usize, usize) { use std::ops::Bound; let start = match bounds.start_bound() { @@ -118,6 +427,21 @@ fn open>(bounds: B, n: usize) -> (usize, usize) { (start, end) } +fn open_key>(bounds: B, keys: &[K]) -> (usize, usize) { + use std::ops::Bound; + let start = match bounds.start_bound() { + Bound::Unbounded => 0, + Bound::Included(x) => keys.binary_search(x).unwrap_or_else(|i| i), + Bound::Excluded(x) => keys.binary_search(x).unwrap_or_else(|i| i + 1), + }; + let end = match bounds.end_bound() { + Bound::Unbounded => keys.len(), + Bound::Included(x) => keys.binary_search(x).map_or_else(|i| i + 1, |i| i + 1), + Bound::Excluded(x) => keys.binary_search(x).unwrap_or_else(|i| i), + }; + (start, end) +} + #[cfg(test)] mod tests { use super::*; @@ -127,23 +451,44 @@ mod tests { use std::iter::repeat_with; use std::ops::Range; - const P: u64 = 998244353; - const BASE: u64 = 10; - enum O {} - impl Op for O { - type Value = (u64, u64); + mod rolling_hash { + use super::*; + pub const P: u64 = 998244353; + pub const BASE: u64 = 10; + pub enum O {} + impl Op for O { + type Value = (u64, u64); - fn identity() -> Self::Value { - (0, 1) + fn identity() -> Self::Value { + (0, 1) + } + + fn op(lhs: &Self::Value, rhs: &Self::Value) -> Self::Value { + ((lhs.0 * rhs.1 + rhs.0) % P, lhs.1 * rhs.1 % P) + } } + } + + mod xor { + use super::*; + pub enum O {} + impl Op for O { + type Value = u64; + + fn identity() -> Self::Value { + 0 + } - fn op(lhs: &Self::Value, rhs: &Self::Value) -> Self::Value { - ((lhs.0 * rhs.1 + rhs.0) % P, lhs.1 * rhs.1 % P) + fn op(lhs: &Self::Value, rhs: &Self::Value) -> Self::Value { + lhs ^ rhs + } } } #[test] - fn test() { + fn test_segtree() { + use rolling_hash::*; + let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { let n = rng.gen_range(1..=100); @@ -176,6 +521,180 @@ mod tests { } } + #[test] + fn test_segtree_usability() { + use rolling_hash::*; + let _ = Segtree::::from_len(1); + let _ = Segtree::::new(&[(0, 1)]); + let _ = Segtree::::from_iter(vec![(0, 1)]); + let mut segtree = Segtree::::new(&[(0, 1)]); + let _ = segtree.fold(0..1); + let _ = segtree.entry(0); + assert_eq!(segtree.as_slice()[0], (0, 1)); + assert_eq!(segtree[0], (0, 1)); + } + + #[test] + fn test_sparse_segtree() { + use rolling_hash::*; + + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..100 { + let n = rng.gen_range(1..=100); + let q = rng.gen_range(1..=100); + let mut keys = repeat_with(|| rng.gen_range(0..n)) + .take(n) + .collect::>(); + keys.sort_unstable(); + let mut vec = keys + .iter() + .copied() + .map(|key| (key, (rng.gen_range(0..BASE), BASE))) + .collect::>(); + let mut segtree = SparseSegtree::::from_iter(vec.iter().copied()); + for _ in 0..q { + match rng.gen_range(0..2) { + // fold + 0 => { + let range = random_range(&mut rng, n); + let start = keys.binary_search(&range.start).unwrap_or_else(|i| i); + let end = keys.binary_search(&range.end).unwrap_or_else(|i| i); + let expected = vec[start..end] + .iter() + .map(|(_, x)| x) + .fold(O::identity(), |acc, x| O::op(&acc, x)); + let result = segtree.fold(range.clone()); + assert_eq!(expected, result); + } + // update + 1 => { + let k = rng.gen_range(0..n); + let x = (rng.gen_range(0..BASE), BASE); + match keys.binary_search(&k) { + Ok(j) => { + *segtree.entry(k) = x; + vec[j].1 = x; + } + Err(_) => {} + } + } + _ => unreachable!(), + } + } + } + } + + #[test] + fn test_sparse_segtree_usability() { + use rolling_hash::*; + let _ = SparseSegtree::::new(&[(0, (1, 1))]); + let _ = SparseSegtree::::from_iter(vec![(0, (1, 1))]); + let mut segtree = SparseSegtree::::new(&[(0, (1, 1))]); + let _ = segtree.fold(0..1); + let _ = segtree.entry(0); + assert_eq!(segtree[0], (1, 1)); + } + + #[test] + fn test_segtree_of_segtree() { + use xor::*; + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..30 { + let h = rng.gen_range(1..=20); + let w = rng.gen_range(1..=20); + let n = rng.gen_range(1..=100); + let q = rng.gen_range(1..=400); + let mut vec = repeat_with(|| { + ( + rng.gen_range(0..h), + rng.gen_range(0..w), + rng.gen_range(0..u64::MAX), + ) + }) + .take(n) + .collect::>(); + let mut segtree = SegtreeOfSegtrees::::new(&vec); + for _ in 0..q { + match rng.gen_range(0..1) { + // fold + 0 => { + let i = random_range(&mut rng, h); + let j = random_range(&mut rng, w); + let expected = vec + .iter() + .filter(|(x, y, _)| i.contains(x) && j.contains(y)) + .map(|(_, _, v)| v) + .fold(O::identity(), |acc, x| O::op(&acc, x)); + let result = segtree.fold(i.clone(), j.clone()); + assert_eq!(expected, result); + } + // update + 1 => { + let k = rng.gen_range(0..n); + let y = rng.gen_range(0..u64::MAX); + let (i, j, x) = vec[k]; + vec[k].2 = x ^ y; + segtree.apply(i, j, |v| *v ^= y); + } + _ => unreachable!(), + } + } + } + } + + #[test] + fn test_dense_2d_segtree() { + use xor::*; + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..20 { + let h = rng.gen_range(1..=10); + let w = rng.gen_range(1..=10); + let q = rng.gen_range(1..=100); + let mut vec = repeat_with(|| { + repeat_with(|| rng.gen_range(0..u64::MAX)) + .take(w) + .collect::>() + }) + .take(h) + .collect::>(); + let mut segtree = Dense2dSegtree::::new(&vec); + for _ in 0..q { + match rng.gen_range(0..2) { + // fold + 0 => { + let i = random_range(&mut rng, h); + let j = random_range(&mut rng, w); + let expected = vec[i.clone()] + .iter() + .flat_map(|v| v[j.clone()].iter()) + .fold(O::identity(), |acc, x| O::op(&acc, x)); + let result = segtree.fold(i.clone(), j.clone()); + assert_eq!(expected, result); + } + // update + 1 => { + let i = rng.gen_range(0..h); + let j = rng.gen_range(0..w); + let x = rng.gen_range(0..u64::MAX); + vec[i][j] = x; + *segtree.entry(i, j) = x; + } + _ => unreachable!(), + } + } + } + } + + #[test] + fn test_dense_2d_segtree_usability() { + use xor::*; + let _ = Dense2dSegtree::::new(&vec![vec![0]]); + let mut segtree = Dense2dSegtree::::new(&vec![vec![0]]); + let _ = segtree.fold(0..1, 0..1); + let _ = segtree.entry(0, 0); + assert_eq!(segtree[0][0], 0); + } + fn random_range(rng: &mut StdRng, n: usize) -> Range { let start = rng.gen_range(0..=n + 1); let end = rng.gen_range(0..=n);