Skip to content

Commit

Permalink
use newtype pattern for TreeIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Nov 6, 2024
1 parent 3ffd27c commit ac1aad8
Showing 1 changed file with 56 additions and 35 deletions.
91 changes: 56 additions & 35 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use pyo3::{
buffer::PyBuffer,
exceptions::{PyAttributeError, PyValueError},
pyclass, pymethods, PyResult, Python,
pyclass, pymethods, FromPyObject, IntoPy, PyObject, PyResult, Python,
};

use clvmr::sha2::Sha256;
Expand All @@ -14,7 +14,22 @@ use std::mem::size_of;
use std::ops::Range;
use thiserror::Error;

type TreeIndex = u32;
#[cfg_attr(feature = "py-bindings", derive(FromPyObject), pyo3(transparent))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TreeIndex(u32);

impl IntoPy<PyObject> for TreeIndex {
fn into_py(self, py: Python<'_>) -> pyo3::PyObject {
self.0.into_py(py)
}
}

impl std::fmt::Display for TreeIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

type Parent = Option<TreeIndex>;
type Hash = [u8; 32];
// key and value ids are provided from outside of this code and are implemented as
Expand Down Expand Up @@ -171,7 +186,7 @@ pub enum InsertLocation {
Leaf { index: TreeIndex, side: Side },
}

const NULL_PARENT: TreeIndex = 0xffff_ffffu32;
const NULL_PARENT: TreeIndex = TreeIndex(0xffff_ffffu32);

#[derive(Debug, PartialEq)]
pub struct NodeMetadata {
Expand Down Expand Up @@ -250,8 +265,8 @@ impl Node {
hash: Self::hash_from_bytes(&blob),
specific: match metadata.node_type {
NodeType::Internal => NodeSpecific::Internal {
left: TreeIndex::from_be_bytes(blob[LEFT_RANGE].try_into().unwrap()),
right: TreeIndex::from_be_bytes(blob[RIGHT_RANGE].try_into().unwrap()),
left: TreeIndex(u32::from_be_bytes(blob[LEFT_RANGE].try_into().unwrap())),
right: TreeIndex(u32::from_be_bytes(blob[RIGHT_RANGE].try_into().unwrap())),
},
NodeType::Leaf => NodeSpecific::Leaf {
key: KvId::from_be_bytes(blob[KEY_RANGE].try_into().unwrap()),
Expand All @@ -262,7 +277,7 @@ impl Node {
}

fn parent_from_bytes(blob: &DataBytes) -> Parent {
let parent_integer = TreeIndex::from_be_bytes(blob[PARENT_RANGE].try_into().unwrap());
let parent_integer = TreeIndex(u32::from_be_bytes(blob[PARENT_RANGE].try_into().unwrap()));
match parent_integer {
NULL_PARENT => None,
_ => Some(parent_integer),
Expand All @@ -286,9 +301,9 @@ impl Node {
Some(parent) => *parent,
};
blob[HASH_RANGE].copy_from_slice(hash);
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes());
blob[LEFT_RANGE].copy_from_slice(&left.to_be_bytes());
blob[RIGHT_RANGE].copy_from_slice(&right.to_be_bytes());
blob[PARENT_RANGE].copy_from_slice(&parent_integer.0.to_be_bytes());
blob[LEFT_RANGE].copy_from_slice(&left.0.to_be_bytes());
blob[RIGHT_RANGE].copy_from_slice(&right.0.to_be_bytes());
}
Node {
parent,
Expand All @@ -300,7 +315,7 @@ impl Node {
Some(parent) => *parent,
};
blob[HASH_RANGE].copy_from_slice(hash);
blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes());
blob[PARENT_RANGE].copy_from_slice(&parent_integer.0.to_be_bytes());
blob[KEY_RANGE].copy_from_slice(&key.to_be_bytes());
blob[VALUE_RANGE].copy_from_slice(&value.to_be_bytes());
}
Expand Down Expand Up @@ -359,7 +374,7 @@ impl Node {
}

fn block_range(index: TreeIndex) -> Range<usize> {
let block_start = index as usize * BLOCK_SIZE;
let block_start = index.0 as usize * BLOCK_SIZE;
block_start..block_start + BLOCK_SIZE
}

Expand Down Expand Up @@ -403,7 +418,7 @@ fn get_free_indexes_and_keys_values_indexes(
let mut key_to_index: HashMap<KvId, TreeIndex> = HashMap::default();

for (index, block) in MerkleBlobLeftChildFirstIterator::new(blob) {
seen_indexes[index as usize] = true;
seen_indexes[index.0 as usize] = true;

if let NodeSpecific::Leaf { key, .. } = block.node.specific {
key_to_index.insert(key, index);
Expand All @@ -413,7 +428,7 @@ fn get_free_indexes_and_keys_values_indexes(
let mut free_indexes: HashSet<TreeIndex> = HashSet::new();
for (index, seen) in seen_indexes.iter().enumerate() {
if !seen {
free_indexes.insert(index as TreeIndex);
free_indexes.insert(TreeIndex(index as u32));
}
}

Expand Down Expand Up @@ -561,7 +576,7 @@ impl MerkleBlob {
return Err(Error::OldLeafUnexpectedlyNotALeaf);
};

node.parent = Some(0);
node.parent = Some(TreeIndex(0));

let nodes = [
(
Expand All @@ -570,7 +585,7 @@ impl MerkleBlob {
Side::Right => left_index,
},
Node {
parent: Some(0),
parent: Some(TreeIndex(0)),
specific: NodeSpecific::Leaf {
key: old_leaf_key,
value: old_leaf_value,
Expand Down Expand Up @@ -872,11 +887,11 @@ impl MerkleBlob {

let Some(grandparent_index) = parent.parent else {
sibling_block.node.parent = None;
self.insert_entry_to_blob(0, &sibling_block)?;
self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;

if let NodeSpecific::Internal { left, right } = sibling_block.node.specific {
for child_index in [left, right] {
self.update_parent(child_index, Some(0))?;
self.update_parent(child_index, Some(TreeIndex(0)))?;
}
};

Expand Down Expand Up @@ -977,7 +992,7 @@ impl MerkleBlob {
let total_count = leaf_count + internal_count + self.free_indexes.len();
let extend_index = self.extend_index();
assert_eq!(
total_count, extend_index as usize,
total_count, extend_index.0 as usize,
"expected total node count {extend_index:?} found: {total_count:?}",
);
assert_eq!(child_to_parent.len(), 0);
Expand Down Expand Up @@ -1047,7 +1062,7 @@ impl MerkleBlob {
} else {
Side::Right
};
let mut next_index: TreeIndex = 0;
let mut next_index = TreeIndex(0);
let mut node = self.get_node(next_index)?;

loop {
Expand Down Expand Up @@ -1080,7 +1095,7 @@ impl MerkleBlob {

fn extend_index(&self) -> TreeIndex {
let blob_length = self.blob.len();
let index: TreeIndex = (blob_length / BLOCK_SIZE) as TreeIndex;
let index: TreeIndex = TreeIndex((blob_length / BLOCK_SIZE) as u32);
let remainder = blob_length % BLOCK_SIZE;
assert_eq!(remainder, 0, "blob length {blob_length:?} not a multiple of {BLOCK_SIZE:?}, remainder: {remainder:?}");

Expand Down Expand Up @@ -1378,7 +1393,7 @@ impl MerkleBlob {
{
use pyo3::conversion::IntoPy;
use pyo3::types::PyListMethods;
list.append((index, node.into_py(py)))?;
list.append((index.into_py(py), node.into_py(py)))?;
}

Ok(list.into())
Expand All @@ -1391,7 +1406,7 @@ impl MerkleBlob {
for (index, block) in MerkleBlobParentFirstIterator::new(&self.blob) {
use pyo3::conversion::IntoPy;
use pyo3::types::PyListMethods;
list.append((index, block.node.into_py(py)))?;
list.append((index.into_py(py), block.node.into_py(py)))?;
}

Ok(list.into())
Expand All @@ -1404,7 +1419,7 @@ impl MerkleBlob {

#[pyo3(name = "get_root_hash")]
pub fn py_get_root_hash(&self) -> PyResult<Option<Hash>> {
self.py_get_hash_at_index(0)
self.py_get_hash_at_index(TreeIndex(0))
}

#[pyo3(name = "get_hash_at_index")]
Expand Down Expand Up @@ -1463,7 +1478,7 @@ impl<'a> MerkleBlobLeftChildFirstIterator<'a> {
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(MerkleBlobLeftChildFirstIteratorItem {
visited: false,
index: 0,
index: TreeIndex(0),
});
}

Expand Down Expand Up @@ -1516,7 +1531,7 @@ impl<'a> MerkleBlobParentFirstIterator<'a> {
fn new(blob: &'a [u8]) -> Self {
let mut deque = VecDeque::new();
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(0);
deque.push_back(TreeIndex(0));
}

Self { blob, deque }
Expand Down Expand Up @@ -1552,7 +1567,7 @@ impl<'a> MerkleBlobBreadthFirstIterator<'a> {
fn new(blob: &'a [u8]) -> Self {
let mut deque = VecDeque::new();
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(0);
deque.push_back(TreeIndex(0));
}

Self { blob, deque }
Expand Down Expand Up @@ -1673,7 +1688,7 @@ mod tests {

#[rstest]
fn test_get_lineage(small_blob: MerkleBlob) {
let lineage = small_blob.get_lineage_with_indexes(2).unwrap();
let lineage = small_blob.get_lineage_with_indexes(TreeIndex(2)).unwrap();
for (_, node) in &lineage {
println!("{node:?}");
}
Expand All @@ -1683,8 +1698,8 @@ mod tests {
}

#[rstest]
#[case::right(0, 2, Side::Left)]
#[case::left(0xff, 1, Side::Right)]
#[case::right(0, TreeIndex(2), Side::Left)]
#[case::left(0xff, TreeIndex(1), Side::Right)]
fn test_get_random_insert_location_by_seed(
#[case] seed: u8,
#[case] expected_index: TreeIndex,
Expand Down Expand Up @@ -1868,7 +1883,10 @@ mod tests {
let index = small_blob.key_to_index[&key];
small_blob.delete(key).unwrap();

assert_eq!(small_blob.free_indexes, HashSet::from([index, 2]));
assert_eq!(
small_blob.free_indexes,
HashSet::from([index, TreeIndex(2)])
);
}

#[rstest]
Expand All @@ -1879,7 +1897,7 @@ mod tests {
small_blob.delete(key).unwrap();
open_dot(small_blob.to_dot().set_note("after delete"));

let expected = HashSet::from([1, 2]);
let expected = HashSet::from([TreeIndex(1), TreeIndex(2)]);
assert_eq!(small_blob.free_indexes, expected);
}

Expand All @@ -1905,20 +1923,23 @@ mod tests {
#[should_panic(expected = "unable to get sibling index from a leaf")]
fn test_node_specific_sibling_index_panics_for_leaf() {
let leaf = NodeSpecific::Leaf { key: 0, value: 0 };
leaf.sibling_index(0);
leaf.sibling_index(TreeIndex(0));
}

#[test]
#[should_panic(expected = "index not a child: 2")]
fn test_node_specific_sibling_index_panics_for_unknown_sibling() {
let node = NodeSpecific::Internal { left: 0, right: 1 };
node.sibling_index(2);
let node = NodeSpecific::Internal {
left: TreeIndex(0),
right: TreeIndex(1),
};
node.sibling_index(TreeIndex(2));
}

#[rstest]
fn test_get_free_indexes(small_blob: MerkleBlob) {
let mut blob = small_blob.blob.clone();
let expected_free_index = (blob.len() / BLOCK_SIZE) as TreeIndex;
let expected_free_index = TreeIndex((blob.len() / BLOCK_SIZE) as u32);
blob.extend_from_slice(&[0; BLOCK_SIZE]);
let (free_indexes, _) = get_free_indexes_and_keys_values_indexes(&blob);
assert_eq!(free_indexes, HashSet::from([expected_free_index]));
Expand Down

0 comments on commit ac1aad8

Please sign in to comment.