Skip to content

Commit

Permalink
Add Seg::{nth,nth_mut}
Browse files Browse the repository at this point in the history
  • Loading branch information
ngtkana committed Nov 10, 2023
1 parent 7f1d902 commit 7fbce3c
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 23 deletions.
33 changes: 25 additions & 8 deletions libs/rb/src/balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ pub mod test_utils {
black_height: u8,
red_vio: bool,
black_vio: bool,
force_black_leaf: bool,
) -> (Self, Violations<T>) {
#[allow(clippy::too_many_lines)]
fn random_tree<T: Balance>(
Expand All @@ -531,8 +532,12 @@ pub mod test_utils {
mut black_height: u8,
red_vio: bool,
black_vio: bool,
force_black_leaf: bool,
parent: Option<Ptr<T>>,
) -> (Option<Ptr<T>>, Violations<T>) {
if black_height == 0 && force_black_leaf {
return (None, Violations::default());
}
// Select the violation position here
let parent_color = parent.map_or(Color::Black, |mut p| *p.color());
let here_red_vio;
Expand Down Expand Up @@ -632,6 +637,7 @@ pub mod test_utils {
children_h,
left_red_vio,
left_black_vio,
force_black_leaf,
Some(here),
);
vios.append(left_vio);
Expand All @@ -641,6 +647,7 @@ pub mod test_utils {
children_h,
right_red_vio,
right_black_vio,
force_black_leaf,
Some(here),
);
vios.append(right_vio);
Expand All @@ -656,8 +663,15 @@ pub mod test_utils {
}
(here, vios)
}
let (root, vios) =
random_tree(rng, &mut new_node, black_height, red_vio, black_vio, None);
let (root, vios) = random_tree(
rng,
&mut new_node,
black_height,
red_vio,
black_vio,
force_black_leaf,
None,
);
(Tree { root, black_height }, vios)
}

Expand Down Expand Up @@ -846,7 +860,8 @@ mod test_fix {
let mut rng = StdRng::seed_from_u64(0);
for _ in 0..200 {
let h = rng.gen_range(0..=4);
let (tree, expected_violations) = Tree::random(&mut rng, new_node, h, false, false);
let (tree, expected_violations) =
Tree::random(&mut rng, new_node, h, false, false, false);
assert_eq!(tree.black_height, h, "{}", tree);
assert_eq!(expected_violations.red_vios.len(), 0, "{}", tree);
assert_eq!(expected_violations.black_vios.len(), 0, "{}", tree);
Expand All @@ -862,7 +877,8 @@ mod test_fix {
let mut rng = StdRng::seed_from_u64(0);
for _ in 0..200 {
let h = rng.gen_range(0..=4);
let (tree, expected_violations) = Tree::random(&mut rng, new_node, h, true, false);
let (tree, expected_violations) =
Tree::random(&mut rng, new_node, h, true, false, false);
assert_eq!(tree.black_height, h, "{}", tree);
assert_eq!(expected_violations.red_vios.len(), 1, "{}", tree);
assert_eq!(expected_violations.black_vios.len(), 0, "{}", tree);
Expand All @@ -878,7 +894,8 @@ mod test_fix {
let mut rng = StdRng::seed_from_u64(0);
for _ in 0..200 {
let h = rng.gen_range(1..=4);
let (tree, expected_violations) = Tree::random(&mut rng, new_node, h, false, true);
let (tree, expected_violations) =
Tree::random(&mut rng, new_node, h, false, true, false);
assert_eq!(tree.black_height, h, "{}", tree);
assert_eq!(expected_violations.red_vios.len(), 0, "{}", tree);
assert_eq!(expected_violations.black_vios.len(), 1, "{}", tree);
Expand All @@ -894,7 +911,7 @@ mod test_fix {
let mut rng = StdRng::seed_from_u64(0);
for _ in 0..200 {
let h = rng.gen_range(1..=4);
let (mut tree, vios) = Tree::random(&mut rng, new_node, h, true, false);
let (mut tree, vios) = Tree::random(&mut rng, new_node, h, true, false, false);
let before = tree.collect();

tree.fix_red(vios.red_vios[0]);
Expand All @@ -913,7 +930,7 @@ mod test_fix {
let mut rng = StdRng::seed_from_u64(0);
for _ in 0..200 {
let h = rng.gen_range(1..=4);
let (mut tree, vios) = Tree::random(&mut rng, new_node, h, false, true);
let (mut tree, vios) = Tree::random(&mut rng, new_node, h, false, true, false);
let before = tree.collect();

tree.fix_black(vios.black_vios[0]);
Expand Down Expand Up @@ -991,7 +1008,7 @@ mod test_update {
}
}

let (tree, vios) = Self::random(rng, new_node, black_height, red_vio, black_vio);
let (tree, vios) = Self::random(rng, new_node, black_height, red_vio, black_vio, false);
update(tree.root);
(tree, vios)
}
Expand Down
144 changes: 129 additions & 15 deletions libs/rb/src/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::balance::Tree;
use std::cmp::Reverse;
use std::fmt;
use std::marker::PhantomData;
use std::ops;
use std::ops::Bound;
use std::ops::Range;
use std::ops::RangeBounds;
Expand Down Expand Up @@ -139,6 +140,15 @@ impl<O: Op> Seg<O> {

pub fn table(&self) -> SegTable<'_, O> { SegTable(self) }

pub fn nth(&self, index: usize) -> &O::Value { &self.nth_ptr(index).as_longlife_ref().value }

pub fn nth_mut(&mut self, index: usize) -> Entry<'_, O> {
Entry {
p: self.nth_ptr(index),
marker: PhantomData,
}
}

pub fn fold(&self, range: impl RangeBounds<usize>) -> Option<O::Value> {
let (start, end) = into_range(range, self.len());
assert!(
Expand Down Expand Up @@ -269,6 +279,21 @@ impl<O: Op> Seg<O> {
other.tree,
);
}

fn nth_ptr(&self, mut index: usize) -> Ptr<Node<O>> {
assert!(index < self.len());
let mut x = self.tree.root.unwrap();
while !x.is_leaf() {
let left_len = x.left.unwrap().len;
x = if index < left_len {
x.left.unwrap()
} else {
index -= left_len;
x.right.unwrap()
}
}
x
}
}
impl<O: Op> Default for Seg<O> {
fn default() -> Self { Self::new() }
Expand Down Expand Up @@ -340,6 +365,24 @@ impl<'a, O: Op> IntoIterator for &'a Seg<O> {

fn into_iter(self) -> Self::IntoIter { self.iter() }
}
pub struct Entry<'a, O: Op> {
p: Ptr<Node<O>>,
marker: PhantomData<&'a O>,
}
impl<'a, O: Op> ops::Deref for Entry<'a, O> {
type Target = O::Value;

fn deref(&self) -> &Self::Target { &self.p.as_longlife_ref().value }
}
impl<'a, O: Op> ops::DerefMut for Entry<'a, O> {
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.p.as_longlife_mut().value }
}
impl<'a, O: Op> Drop for Entry<'a, O> {
fn drop(&mut self) {
self.p.update();
self.p.update_ancestors();
}
}
struct SegTableCell<'a, O: Op> {
start: usize,
end: usize,
Expand Down Expand Up @@ -429,6 +472,10 @@ where
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
if rows.is_empty() {
writeln!(f, "SegTable (empty)")?;
return Ok(());
}
let n = rows[0].len();
let mut colomn_width = (0..n).map(|i| i.to_string().len()).collect::<Vec<_>>();
for (range, value) in rows.iter().flatten() {
Expand Down Expand Up @@ -588,6 +635,8 @@ mod test_seg {
use super::Op;
use super::Ptr;
use super::Seg;
use crate::balance::Balance;
use crate::balance::Tree;
use rand::distributions::Alphanumeric;
use rand::rngs::StdRng;
use rand::Rng;
Expand All @@ -605,6 +654,37 @@ mod test_seg {
}

impl Seg<O> {
fn random(rng: &mut StdRng, black_height: u8) -> Self {
fn update_all(mut p: Ptr<Node<O>>) {
if p.left.is_some() {
update_all(p.left.unwrap());
update_all(p.right.unwrap());
p.update();
}
}
let (tree, _) = Tree::random(
rng,
|rng, color| {
Node::new(
rng.sample_iter(&Alphanumeric)
.take(1)
.map(char::from)
.collect::<String>(),
color,
1,
)
},
black_height,
false,
false,
true,
);
if let Some(root) = tree.root {
update_all(root);
}
Self { tree }
}

fn validate_value(&self) {
fn validate_value(p: Option<Ptr<Node<O>>>) -> Result<(), String> {
if let Some(p) = p {
Expand All @@ -617,8 +697,8 @@ mod test_seg {
expected.push_str(&p.right.unwrap().value);
(p.value == expected).then_some(()).ok_or_else(|| {
format!(
"Len is incorrect at {:?}. Expected {}, but cached {}",
p, expected, p.len
"Value is incorrect at {:?}. Expected {}, but cached {}",
p, expected, &p.value
)
})?;
validate_value(p.right)?;
Expand Down Expand Up @@ -648,6 +728,53 @@ mod test_seg {
}
}

#[test]
fn test_seg_nth_mut() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..20 {
let black_height = rng.gen_range(1..=4);
let mut seg = Seg::random(&mut rng, black_height);
let mut vec = seg.iter().cloned().collect::<Vec<_>>();
for _ in 0..200 {
let i = rng.gen_range(0..seg.len());
let x = (&mut rng)
.sample_iter(&Alphanumeric)
.take(1)
.map(char::from)
.collect::<String>();
*seg.nth_mut(i) = x.clone();
vec[i] = x;
seg.tree.validate();
seg.validate_value();
assert_eq!(seg.iter().cloned().collect::<Vec<_>>(), vec);
}
}
}

#[test]
fn test_fold() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..20 {
let black_height = rng.gen_range(0..=4);
let seg = Seg::random(&mut rng, black_height);
seg.tree.validate();
seg.validate_value();
let n = seg.len();
let vec = seg.iter().cloned().collect::<Vec<_>>();
for _ in 0..200 {
let mut i = rng.gen_range(0..=n + 1);
let mut j = rng.gen_range(0..=n);
if i > j {
std::mem::swap(&mut i, &mut j);
j -= 1;
}
let result = seg.fold(i..j).unwrap_or_default();
let expected = vec[i..j].iter().flat_map(|s| s.chars()).collect::<String>();
assert_eq!(result, expected, "fold({i}..{j})");
}
}
}

#[test]
fn test_seg_insert() {
let mut rng = StdRng::seed_from_u64(42);
Expand All @@ -665,19 +792,6 @@ mod test_seg {
vec.insert(i, s);
assert_eq!(seg.iter().cloned().collect::<Vec<_>>(), vec);
seg.validate_value();

// Validate `fold`
{
let mut i = rng.gen_range(0..=vec.len() + 1);
let mut j = rng.gen_range(0..=vec.len());
if i > j {
std::mem::swap(&mut i, &mut j);
j -= 1;
}
let result = seg.fold(i..j).unwrap_or_default();
let expected = vec[i..j].iter().flat_map(|s| s.chars()).collect::<String>();
assert_eq!(result, expected, "fold({i}..{j})");
}
}
}
}
Expand Down

0 comments on commit 7fbce3c

Please sign in to comment.