Skip to content

Commit

Permalink
Seg::{from_iter, append}
Browse files Browse the repository at this point in the history
  • Loading branch information
ngtkana committed Nov 10, 2023
1 parent 20c6ac0 commit 7f1d902
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 21 deletions.
150 changes: 138 additions & 12 deletions libs/rb/src/balance.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::fmt;
use std::cmp::Ordering;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ptr::NonNull;
Expand Down Expand Up @@ -235,6 +236,95 @@ impl<T: Balance> Tree<T> {
}
}
}
impl<T: Balance> Clone for Tree<T> {
fn clone(&self) -> Self {
Self {
root: self.root,
black_height: self.black_height,
}
}
}
impl<T: Balance> Copy for Tree<T> {}

pub fn join<T: Balance>(
mut left: Tree<T>,
center: impl FnOnce(Ptr<T>, Ptr<T>) -> Ptr<T>,
mut right: Tree<T>,
) -> Tree<T> {
debug_assert!(left.root.is_some());
debug_assert!(right.root.is_some());
match left.black_height.cmp(&right.black_height) {
Ordering::Less => {
if color(left.root) == Color::Red {
*left.root.unwrap().color() = Color::Black;
left.black_height += 1;
}
debug_assert!(left.black_height > 0);
let mut right1 = Tree {
root: right.root,
black_height: right.black_height,
};
while left.black_height < right1.black_height || color(right1.root) == Color::Red {
let mut root = right1.root.unwrap();
if *root.color() == Color::Black {
right1.black_height -= 1;
}
right1.root = Some(root.left().unwrap());
}
let mut center = center(left.root.unwrap(), right1.root.unwrap());
right.transplant(right1.root.unwrap(), Some(center));
*center.color() = Color::Red;
*center.left() = left.root;
*center.right() = right1.root;
*right1.root.unwrap().parent() = Some(center);
*left.root.unwrap().parent() = Some(center);
center.update();
right.fix_red(center);
right
}
Ordering::Greater => {
if color(right.root) == Color::Red {
*right.root.unwrap().color() = Color::Black;
right.black_height += 1;
}
debug_assert!(right.black_height > 0);
let mut left1 = Tree {
root: left.root,
black_height: left.black_height,
};
while left1.black_height > right.black_height || color(left1.root) == Color::Red {
let mut root = left1.root.unwrap();
if *root.color() == Color::Black {
left1.black_height -= 1;
}
left1.root = Some(root.right().unwrap());
}
let mut center = center(left1.root.unwrap(), right.root.unwrap());
left.transplant(left1.root.unwrap(), Some(center));
*center.color() = Color::Red;
*center.left() = left1.root;
*center.right() = right.root;
*left1.root.unwrap().parent() = Some(center);
*right.root.unwrap().parent() = Some(center);
center.update();
left.fix_red(center);
left
}
Ordering::Equal => {
let mut center = center(left.root.unwrap(), right.root.unwrap());
*center.color() = Color::Black;
*center.left() = left.root;
*center.right() = right.root;
*left.root.unwrap().parent() = Some(center);
*right.root.unwrap().parent() = Some(center);
center.update();
Tree {
root: Some(center),
black_height: left.black_height + 1,
}
}
}
}

pub struct BlackViolation<T: Balance> {
pub p: Option<Ptr<T>>,
Expand Down Expand Up @@ -535,11 +625,11 @@ pub mod test_utils {

// Recurse
if let Some(mut here) = here {
let children_black_height = black_height - u8::from(color == Color::Black);
let children_h = black_height - u8::from(color == Color::Black);
let (mut left, left_vio) = random_tree(
rng,
new_node,
children_black_height,
children_h,
left_red_vio,
left_black_vio,
Some(here),
Expand All @@ -548,7 +638,7 @@ pub mod test_utils {
let (mut right, right_vio) = random_tree(
rng,
new_node,
children_black_height,
children_h,
right_red_vio,
right_black_vio,
Some(here),
Expand Down Expand Up @@ -840,6 +930,7 @@ mod test_fix {

#[cfg(test)]
mod test_update {
use super::join;
use super::test_utils::Violations;
use super::Balance as _;
use super::Color;
Expand All @@ -851,16 +942,16 @@ mod test_update {

const VALUE_LIM: i32 = 20;

struct SumNode {
struct Node {
pub value: i32,
pub sum: i32,
pub color: Color,
pub parent: Option<Ptr<Self>>,
pub left: Option<Ptr<Self>>,
pub right: Option<Ptr<Self>>,
}
fn sum(p: Option<Ptr<SumNode>>) -> i32 { p.map_or(0, |p| p.sum) }
impl super::Balance for SumNode {
fn sum(p: Option<Ptr<Node>>) -> i32 { p.map_or(0, |p| p.sum) }
impl super::Balance for Node {
fn update(&mut self) { self.sum = sum(self.left) + self.value + sum(self.right); }

fn push(&mut self) {}
Expand All @@ -874,16 +965,16 @@ mod test_update {
fn right(&mut self) -> &mut Option<Ptr<Self>> { &mut self.right }
}

impl Tree<SumNode> {
impl Tree<Node> {
fn random_sum(
rng: &mut StdRng,
black_height: u8,
red_vio: bool,
black_vio: bool,
) -> (Self, Violations<SumNode>) {
fn new_node(rng: &mut StdRng, color: Color) -> SumNode {
) -> (Self, Violations<Node>) {
fn new_node(rng: &mut StdRng, color: Color) -> Node {
let value = rng.gen_range(0..VALUE_LIM);
SumNode {
Node {
value,
sum: value,
color,
Expand All @@ -892,7 +983,7 @@ mod test_update {
right: None,
}
}
fn update(p: Option<Ptr<SumNode>>) {
fn update(p: Option<Ptr<Node>>) {
if let Some(mut p) = p {
update(p.left);
update(p.right);
Expand All @@ -906,7 +997,7 @@ mod test_update {
}

fn validate_sum(&self) {
fn validate_sum(p: Option<Ptr<SumNode>>) -> Result<(), String> {
fn validate_sum(p: Option<Ptr<Node>>) -> Result<(), String> {
if let Some(p) = p {
validate_sum(p.left)?;
let expected = sum(p.left) + p.value + sum(p.right);
Expand Down Expand Up @@ -935,6 +1026,7 @@ mod test_update {
for _ in 0..200 {
let h = rng.gen_range(0..=4);
let (tree, _) = Tree::random_sum(&mut rng, h, false, false);
tree.validate();
tree.validate_sum();
}
}
Expand All @@ -953,6 +1045,7 @@ mod test_update {
p.update();

tree.fix_red(p);
tree.validate();
tree.validate_sum();
}
}
Expand All @@ -964,6 +1057,7 @@ mod test_update {
let h = rng.gen_range(1..=4);
let (mut tree, vios) = Tree::random_sum(&mut rng, h, false, true);
let vio = vios.black_vios[0];
tree.validate();
tree.validate_sum();

// Change this value to make sure `fix_black` updates all the proper ancestors.
Expand All @@ -972,7 +1066,39 @@ mod test_update {
}

tree.fix_black(vio);
tree.validate();
tree.validate_sum();
}
}

#[test]
fn test_join() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..200 {
let left_h = rng.gen_range(1..=4);
let right_h = rng.gen_range(1..=4);
let (left, _) = Tree::random_sum(&mut rng, left_h, false, false);
let (right, _) = Tree::random_sum(&mut rng, right_h, false, false);
left.validate();
right.validate();
left.validate_sum();
right.validate_sum();
let result = join(
left,
|_, _| {
Ptr::new(Node {
value: rng.gen_range(0..VALUE_LIM),
sum: 0,
color: if rng.gen() { Color::Red } else { Color::Black },
parent: None,
left: None,
right: None,
})
},
right,
);
result.validate();
result.validate_sum();
}
}
}
Loading

0 comments on commit 7f1d902

Please sign in to comment.