From ac1aad82234d80dc73c84dc7b51e8b047e56e87c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 5 Nov 2024 19:39:57 -0500 Subject: [PATCH] use newtype pattern for `TreeIndex` --- crates/chia-datalayer/src/merkle.rs | 91 ++++++++++++++++++----------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index 91c7cbfee..2684cd884 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -2,7 +2,7 @@ use pyo3::{ buffer::PyBuffer, exceptions::{PyAttributeError, PyValueError}, - pyclass, pymethods, PyResult, Python, + pyclass, pymethods, FromPyObject, IntoPy, PyObject, PyResult, Python, }; use clvmr::sha2::Sha256; @@ -14,7 +14,22 @@ use std::mem::size_of; use std::ops::Range; use thiserror::Error; -type TreeIndex = u32; +#[cfg_attr(feature = "py-bindings", derive(FromPyObject), pyo3(transparent))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TreeIndex(u32); + +impl IntoPy for TreeIndex { + fn into_py(self, py: Python<'_>) -> pyo3::PyObject { + self.0.into_py(py) + } +} + +impl std::fmt::Display for TreeIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + type Parent = Option; type Hash = [u8; 32]; // key and value ids are provided from outside of this code and are implemented as @@ -171,7 +186,7 @@ pub enum InsertLocation { Leaf { index: TreeIndex, side: Side }, } -const NULL_PARENT: TreeIndex = 0xffff_ffffu32; +const NULL_PARENT: TreeIndex = TreeIndex(0xffff_ffffu32); #[derive(Debug, PartialEq)] pub struct NodeMetadata { @@ -250,8 +265,8 @@ impl Node { hash: Self::hash_from_bytes(&blob), specific: match metadata.node_type { NodeType::Internal => NodeSpecific::Internal { - left: TreeIndex::from_be_bytes(blob[LEFT_RANGE].try_into().unwrap()), - right: TreeIndex::from_be_bytes(blob[RIGHT_RANGE].try_into().unwrap()), + left: TreeIndex(u32::from_be_bytes(blob[LEFT_RANGE].try_into().unwrap())), + right: TreeIndex(u32::from_be_bytes(blob[RIGHT_RANGE].try_into().unwrap())), }, NodeType::Leaf => NodeSpecific::Leaf { key: KvId::from_be_bytes(blob[KEY_RANGE].try_into().unwrap()), @@ -262,7 +277,7 @@ impl Node { } fn parent_from_bytes(blob: &DataBytes) -> Parent { - let parent_integer = TreeIndex::from_be_bytes(blob[PARENT_RANGE].try_into().unwrap()); + let parent_integer = TreeIndex(u32::from_be_bytes(blob[PARENT_RANGE].try_into().unwrap())); match parent_integer { NULL_PARENT => None, _ => Some(parent_integer), @@ -286,9 +301,9 @@ impl Node { Some(parent) => *parent, }; blob[HASH_RANGE].copy_from_slice(hash); - blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes()); - blob[LEFT_RANGE].copy_from_slice(&left.to_be_bytes()); - blob[RIGHT_RANGE].copy_from_slice(&right.to_be_bytes()); + blob[PARENT_RANGE].copy_from_slice(&parent_integer.0.to_be_bytes()); + blob[LEFT_RANGE].copy_from_slice(&left.0.to_be_bytes()); + blob[RIGHT_RANGE].copy_from_slice(&right.0.to_be_bytes()); } Node { parent, @@ -300,7 +315,7 @@ impl Node { Some(parent) => *parent, }; blob[HASH_RANGE].copy_from_slice(hash); - blob[PARENT_RANGE].copy_from_slice(&parent_integer.to_be_bytes()); + blob[PARENT_RANGE].copy_from_slice(&parent_integer.0.to_be_bytes()); blob[KEY_RANGE].copy_from_slice(&key.to_be_bytes()); blob[VALUE_RANGE].copy_from_slice(&value.to_be_bytes()); } @@ -359,7 +374,7 @@ impl Node { } fn block_range(index: TreeIndex) -> Range { - let block_start = index as usize * BLOCK_SIZE; + let block_start = index.0 as usize * BLOCK_SIZE; block_start..block_start + BLOCK_SIZE } @@ -403,7 +418,7 @@ fn get_free_indexes_and_keys_values_indexes( let mut key_to_index: HashMap = HashMap::default(); for (index, block) in MerkleBlobLeftChildFirstIterator::new(blob) { - seen_indexes[index as usize] = true; + seen_indexes[index.0 as usize] = true; if let NodeSpecific::Leaf { key, .. } = block.node.specific { key_to_index.insert(key, index); @@ -413,7 +428,7 @@ fn get_free_indexes_and_keys_values_indexes( let mut free_indexes: HashSet = HashSet::new(); for (index, seen) in seen_indexes.iter().enumerate() { if !seen { - free_indexes.insert(index as TreeIndex); + free_indexes.insert(TreeIndex(index as u32)); } } @@ -561,7 +576,7 @@ impl MerkleBlob { return Err(Error::OldLeafUnexpectedlyNotALeaf); }; - node.parent = Some(0); + node.parent = Some(TreeIndex(0)); let nodes = [ ( @@ -570,7 +585,7 @@ impl MerkleBlob { Side::Right => left_index, }, Node { - parent: Some(0), + parent: Some(TreeIndex(0)), specific: NodeSpecific::Leaf { key: old_leaf_key, value: old_leaf_value, @@ -872,11 +887,11 @@ impl MerkleBlob { let Some(grandparent_index) = parent.parent else { sibling_block.node.parent = None; - self.insert_entry_to_blob(0, &sibling_block)?; + self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?; if let NodeSpecific::Internal { left, right } = sibling_block.node.specific { for child_index in [left, right] { - self.update_parent(child_index, Some(0))?; + self.update_parent(child_index, Some(TreeIndex(0)))?; } }; @@ -977,7 +992,7 @@ impl MerkleBlob { let total_count = leaf_count + internal_count + self.free_indexes.len(); let extend_index = self.extend_index(); assert_eq!( - total_count, extend_index as usize, + total_count, extend_index.0 as usize, "expected total node count {extend_index:?} found: {total_count:?}", ); assert_eq!(child_to_parent.len(), 0); @@ -1047,7 +1062,7 @@ impl MerkleBlob { } else { Side::Right }; - let mut next_index: TreeIndex = 0; + let mut next_index = TreeIndex(0); let mut node = self.get_node(next_index)?; loop { @@ -1080,7 +1095,7 @@ impl MerkleBlob { fn extend_index(&self) -> TreeIndex { let blob_length = self.blob.len(); - let index: TreeIndex = (blob_length / BLOCK_SIZE) as TreeIndex; + let index: TreeIndex = TreeIndex((blob_length / BLOCK_SIZE) as u32); let remainder = blob_length % BLOCK_SIZE; assert_eq!(remainder, 0, "blob length {blob_length:?} not a multiple of {BLOCK_SIZE:?}, remainder: {remainder:?}"); @@ -1378,7 +1393,7 @@ impl MerkleBlob { { use pyo3::conversion::IntoPy; use pyo3::types::PyListMethods; - list.append((index, node.into_py(py)))?; + list.append((index.into_py(py), node.into_py(py)))?; } Ok(list.into()) @@ -1391,7 +1406,7 @@ impl MerkleBlob { for (index, block) in MerkleBlobParentFirstIterator::new(&self.blob) { use pyo3::conversion::IntoPy; use pyo3::types::PyListMethods; - list.append((index, block.node.into_py(py)))?; + list.append((index.into_py(py), block.node.into_py(py)))?; } Ok(list.into()) @@ -1404,7 +1419,7 @@ impl MerkleBlob { #[pyo3(name = "get_root_hash")] pub fn py_get_root_hash(&self) -> PyResult> { - self.py_get_hash_at_index(0) + self.py_get_hash_at_index(TreeIndex(0)) } #[pyo3(name = "get_hash_at_index")] @@ -1463,7 +1478,7 @@ impl<'a> MerkleBlobLeftChildFirstIterator<'a> { if blob.len() / BLOCK_SIZE > 0 { deque.push_back(MerkleBlobLeftChildFirstIteratorItem { visited: false, - index: 0, + index: TreeIndex(0), }); } @@ -1516,7 +1531,7 @@ impl<'a> MerkleBlobParentFirstIterator<'a> { fn new(blob: &'a [u8]) -> Self { let mut deque = VecDeque::new(); if blob.len() / BLOCK_SIZE > 0 { - deque.push_back(0); + deque.push_back(TreeIndex(0)); } Self { blob, deque } @@ -1552,7 +1567,7 @@ impl<'a> MerkleBlobBreadthFirstIterator<'a> { fn new(blob: &'a [u8]) -> Self { let mut deque = VecDeque::new(); if blob.len() / BLOCK_SIZE > 0 { - deque.push_back(0); + deque.push_back(TreeIndex(0)); } Self { blob, deque } @@ -1673,7 +1688,7 @@ mod tests { #[rstest] fn test_get_lineage(small_blob: MerkleBlob) { - let lineage = small_blob.get_lineage_with_indexes(2).unwrap(); + let lineage = small_blob.get_lineage_with_indexes(TreeIndex(2)).unwrap(); for (_, node) in &lineage { println!("{node:?}"); } @@ -1683,8 +1698,8 @@ mod tests { } #[rstest] - #[case::right(0, 2, Side::Left)] - #[case::left(0xff, 1, Side::Right)] + #[case::right(0, TreeIndex(2), Side::Left)] + #[case::left(0xff, TreeIndex(1), Side::Right)] fn test_get_random_insert_location_by_seed( #[case] seed: u8, #[case] expected_index: TreeIndex, @@ -1868,7 +1883,10 @@ mod tests { let index = small_blob.key_to_index[&key]; small_blob.delete(key).unwrap(); - assert_eq!(small_blob.free_indexes, HashSet::from([index, 2])); + assert_eq!( + small_blob.free_indexes, + HashSet::from([index, TreeIndex(2)]) + ); } #[rstest] @@ -1879,7 +1897,7 @@ mod tests { small_blob.delete(key).unwrap(); open_dot(small_blob.to_dot().set_note("after delete")); - let expected = HashSet::from([1, 2]); + let expected = HashSet::from([TreeIndex(1), TreeIndex(2)]); assert_eq!(small_blob.free_indexes, expected); } @@ -1905,20 +1923,23 @@ mod tests { #[should_panic(expected = "unable to get sibling index from a leaf")] fn test_node_specific_sibling_index_panics_for_leaf() { let leaf = NodeSpecific::Leaf { key: 0, value: 0 }; - leaf.sibling_index(0); + leaf.sibling_index(TreeIndex(0)); } #[test] #[should_panic(expected = "index not a child: 2")] fn test_node_specific_sibling_index_panics_for_unknown_sibling() { - let node = NodeSpecific::Internal { left: 0, right: 1 }; - node.sibling_index(2); + let node = NodeSpecific::Internal { + left: TreeIndex(0), + right: TreeIndex(1), + }; + node.sibling_index(TreeIndex(2)); } #[rstest] fn test_get_free_indexes(small_blob: MerkleBlob) { let mut blob = small_blob.blob.clone(); - let expected_free_index = (blob.len() / BLOCK_SIZE) as TreeIndex; + let expected_free_index = TreeIndex((blob.len() / BLOCK_SIZE) as u32); blob.extend_from_slice(&[0; BLOCK_SIZE]); let (free_indexes, _) = get_free_indexes_and_keys_values_indexes(&blob); assert_eq!(free_indexes, HashSet::from([expected_free_index]));