diff --git a/CHANGELOG.md b/CHANGELOG.md index ff66e1ed..eb91062c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## 0.5.0 (2023-05-26) +* Implemented `TieredSmt` (#152, #153). +* Implemented ability to extract a subset of a `MerkleStore` (#151). +* Cleaned up `SimpleSmt` interface (#149). +* Decoupled hashing and padding of peaks in `Mmr` (#148). +* Added `inner_nodes()` to `MerkleStore` (#146). + ## 0.4.0 (2023-04-21) - Exported `MmrProof` from the crate (#137). diff --git a/Cargo.toml b/Cargo.toml index 9b1555f4..f29a1bb8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "miden-crypto" -version = "0.4.0" +version = "0.5.0" description = "Miden Cryptographic primitives" authors = ["miden contributors"] readme = "README.md" license = "MIT" repository = "https://github.com/0xPolygonMiden/crypto" -documentation = "https://docs.rs/miden-crypto/0.4.0" +documentation = "https://docs.rs/miden-crypto/0.5.0" categories = ["cryptography", "no-std"] keywords = ["miden", "crypto", "hash", "merkle"] edition = "2021" @@ -35,6 +35,6 @@ winter_math = { version = "0.6", package = "winter-math", default-features = fal winter_utils = { version = "0.6", package = "winter-utils", default-features = false } [dev-dependencies] -criterion = { version = "0.4", features = ["html_reports"] } +criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.1.0" rand_utils = { version = "0.6", package = "winter-rand-utils" } diff --git a/benches/smt.rs b/benches/smt.rs index 3c63c044..44e3ea5d 100644 --- a/benches/smt.rs +++ b/benches/smt.rs @@ -18,8 +18,8 @@ fn smt_rpo(c: &mut Criterion) { (i, word) }) .collect(); - let tree = SimpleSmt::new(depth).unwrap().with_leaves(entries).unwrap(); - trees.push(tree); + let tree = SimpleSmt::with_leaves(depth, entries).unwrap(); + trees.push((tree, count)); } } @@ -29,10 +29,9 @@ fn smt_rpo(c: &mut Criterion) { let mut insert = c.benchmark_group(format!("smt update_leaf")); - for tree in 
trees.iter_mut() { + for (tree, count) in trees.iter_mut() { let depth = tree.depth(); - let count = tree.leaves_count() as u64; - let key = count >> 2; + let key = *count >> 2; insert.bench_with_input( format!("simple smt(depth:{depth},count:{count})"), &(key, leaf), @@ -48,10 +47,9 @@ fn smt_rpo(c: &mut Criterion) { let mut path = c.benchmark_group(format!("smt get_leaf_path")); - for tree in trees.iter_mut() { + for (tree, count) in trees.iter_mut() { let depth = tree.depth(); - let count = tree.leaves_count() as u64; - let key = count >> 2; + let key = *count >> 2; path.bench_with_input( format!("simple smt(depth:{depth},count:{count})"), &key, diff --git a/benches/store.rs b/benches/store.rs index 3ad79c47..aaa98060 100644 --- a/benches/store.rs +++ b/benches/store.rs @@ -104,10 +104,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) { .enumerate() .map(|(c, v)| (c.try_into().unwrap(), v.into())) .collect::>(); - let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH) - .unwrap() - .with_leaves(smt_leaves.clone()) - .unwrap(); + let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap(); let store = MerkleStore::from(&smt); let depth = smt.depth(); let root = smt.root(); @@ -215,10 +212,7 @@ fn get_node_simplesmt(c: &mut Criterion) { .enumerate() .map(|(c, v)| (c.try_into().unwrap(), v.into())) .collect::>(); - let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH) - .unwrap() - .with_leaves(smt_leaves.clone()) - .unwrap(); + let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap(); let store = MerkleStore::from(&smt); let root = smt.root(); let half_depth = smt.depth() / 2; @@ -292,10 +286,7 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) { .enumerate() .map(|(c, v)| (c.try_into().unwrap(), v.into())) .collect::>(); - let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH) - .unwrap() - .with_leaves(smt_leaves.clone()) - .unwrap(); + let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap(); let store = 
MerkleStore::from(&smt); let depth = smt.depth(); let root = smt.root(); @@ -361,7 +352,7 @@ fn new(c: &mut Criterion) { .map(|(c, v)| (c.try_into().unwrap(), v.into())) .collect::>() }, - |l| black_box(SimpleSmt::new(SimpleSmt::MAX_DEPTH).unwrap().with_leaves(l)), + |l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)), BatchSize::SmallInput, ) }); @@ -376,7 +367,7 @@ fn new(c: &mut Criterion) { .collect::>() }, |l| { - let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH).unwrap().with_leaves(l).unwrap(); + let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap(); black_box(MerkleStore::from(&smt)); }, BatchSize::SmallInput, @@ -442,10 +433,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) { .enumerate() .map(|(c, v)| (c.try_into().unwrap(), v.into())) .collect::>(); - let mut smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH) - .unwrap() - .with_leaves(smt_leaves.clone()) - .unwrap(); + let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap(); let mut store = MerkleStore::from(&smt); let depth = smt.depth(); let root = smt.root(); diff --git a/src/hash/rpo/digest.rs b/src/hash/rpo/digest.rs index 2edfd27d..0e6c3109 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rpo/digest.rs @@ -2,7 +2,7 @@ use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO}; use crate::utils::{ string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, }; -use core::{cmp::Ordering, ops::Deref}; +use core::{cmp::Ordering, fmt::Display, ops::Deref}; // DIGEST TRAIT IMPLEMENTATIONS // ================================================================================================ @@ -85,6 +85,28 @@ impl From for [Felt; DIGEST_SIZE] { } } +impl From<&RpoDigest> for [u64; DIGEST_SIZE] { + fn from(value: &RpoDigest) -> Self { + [ + value.0[0].as_int(), + value.0[1].as_int(), + value.0[2].as_int(), + value.0[3].as_int(), + ] + } +} + +impl From for [u64; DIGEST_SIZE] { + fn from(value: RpoDigest) -> Self { + [ + 
value.0[0].as_int(), + value.0[1].as_int(), + value.0[2].as_int(), + value.0[3].as_int(), + ] + } +} + impl From<&RpoDigest> for [u8; 32] { fn from(value: &RpoDigest) -> Self { value.as_bytes() @@ -134,6 +156,15 @@ impl PartialOrd for RpoDigest { } } +impl Display for RpoDigest { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + for byte in self.as_bytes() { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + // TESTS // ================================================================================================ diff --git a/src/merkle/empty_roots.rs b/src/merkle/empty_roots.rs index 7f8c4a2a..b1b0b30f 100644 --- a/src/merkle/empty_roots.rs +++ b/src/merkle/empty_roots.rs @@ -1,6 +1,12 @@ -use super::{Felt, RpoDigest, WORD_SIZE, ZERO}; +use super::{Felt, RpoDigest, Word, WORD_SIZE, ZERO}; use core::slice; +// CONSTANTS +// ================================================================================================ + +/// A word consisting of 4 ZERO elements. +pub const EMPTY_WORD: Word = [ZERO; WORD_SIZE]; + // EMPTY NODES SUBTREES // ================================================================================================ @@ -1570,7 +1576,7 @@ fn all_depths_opens_to_zero() { assert_eq!(depth as usize + 1, subtree.len()); // assert the opening is zero - let initial = RpoDigest::new([ZERO; WORD_SIZE]); + let initial = RpoDigest::new(EMPTY_WORD); assert_eq!(initial, subtree.remove(0)); // compute every node of the path manually and compare with the output diff --git a/src/merkle/index.rs b/src/merkle/index.rs index b1c73892..564d9816 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -1,4 +1,5 @@ use super::{Felt, MerkleError, RpoDigest, StarkField}; +use core::fmt::Display; // NODE INDEX // ================================================================================================ @@ -40,6 +41,12 @@ impl NodeIndex { } } + /// Creates a new node index without checking its validity. 
+ pub const fn new_unchecked(depth: u8, value: u64) -> Self { + debug_assert!((64 - value.leading_zeros()) <= depth as u32); + Self { depth, value } + } + /// Creates a new node index for testing purposes. /// /// # Panics @@ -67,12 +74,26 @@ impl NodeIndex { Self { depth: 0, value: 0 } } - /// Computes the value of the sibling of the current node. - pub fn sibling(mut self) -> Self { + /// Computes sibling index of the current node. + pub const fn sibling(mut self) -> Self { self.value ^= 1; self } + /// Returns left child index of the current node. + pub const fn left_child(mut self) -> Self { + self.depth += 1; + self.value <<= 1; + self + } + + /// Returns right child index of the current node. + pub const fn right_child(mut self) -> Self { + self.depth += 1; + self.value = (self.value << 1) + 1; + self + } + // PROVIDERS // -------------------------------------------------------------------------------------------- @@ -117,11 +138,26 @@ impl NodeIndex { // STATE MUTATORS // -------------------------------------------------------------------------------------------- - /// Traverse one level towards the root, decrementing the depth by `1`. - pub fn move_up(&mut self) -> &mut Self { + /// Traverses one level towards the root, decrementing the depth by `1`. + pub fn move_up(&mut self) { self.depth = self.depth.saturating_sub(1); self.value >>= 1; - self + } + + /// Traverses towards the root until the specified depth is reached. + /// + /// Assumes that the specified depth is smaller than the current depth. 
+ pub fn move_up_to(&mut self, depth: u8) { + debug_assert!(depth < self.depth); + let delta = self.depth.saturating_sub(depth); + self.depth = self.depth.saturating_sub(delta); + self.value >>= delta as u32; + } +} + +impl Display for NodeIndex { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "depth={}, value={}", self.depth, self.value) } } diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index 42f289aa..45487398 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -114,6 +114,28 @@ impl MerkleTree { Ok(path.into()) } + // ITERATORS + // -------------------------------------------------------------------------------------------- + + /// Returns an iterator over the leaves of this [MerkleTree]. + pub fn leaves(&self) -> impl Iterator { + let leaves_start = self.nodes.len() / 2; + self.nodes.iter().skip(leaves_start).enumerate().map(|(i, v)| (i as u64, v)) + } + + /// Returns an iterator over every inner node of this [MerkleTree]. + /// + /// The iterator order is unspecified. + pub fn inner_nodes(&self) -> InnerNodeIterator { + InnerNodeIterator { + nodes: &self.nodes, + index: 1, // index 0 is just padding, start at 1 + } + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + /// Replaces the leaf at the specified index with the provided value. /// /// # Errors @@ -149,16 +171,6 @@ impl MerkleTree { Ok(()) } - - /// Returns n iterator over every inner node of this [MerkleTree]. - /// - /// The iterator order is unspecified. 
- pub fn inner_nodes(&self) -> InnerNodeIterator<'_> { - InnerNodeIterator { - nodes: &self.nodes, - index: 1, // index 0 is just padding, start at 1 - } - } } // ITERATORS diff --git a/src/merkle/mmr/accumulator.rs b/src/merkle/mmr/accumulator.rs index 610999ae..ce0ee496 100644 --- a/src/merkle/mmr/accumulator.rs +++ b/src/merkle/mmr/accumulator.rs @@ -1,8 +1,4 @@ -use super::{ - super::Vec, - super::{WORD_SIZE, ZERO}, - MmrProof, Rpo256, Word, -}; +use super::{super::Vec, super::ZERO, Felt, MmrProof, Rpo256, Word}; #[derive(Debug, Clone, PartialEq)] pub struct MmrPeaks { @@ -35,25 +31,49 @@ pub struct MmrPeaks { impl MmrPeaks { /// Hashes the peaks. /// - /// The hashing is optimized to work with the Miden VM, the procedure will: - /// - /// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO padding. - /// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of - /// hashing. + /// The procedure will: + /// - Flatten and pad the peaks to a vector of Felts. + /// - Hash the vector of Felts. pub fn hash_peaks(&self) -> Word { - let mut copy = self.peaks.clone(); - - if copy.len() < 16 { - copy.resize(16, [ZERO; WORD_SIZE]) - } else if copy.len() % 2 == 1 { - copy.push([ZERO; WORD_SIZE]) - } - - Rpo256::hash_elements(&copy.as_slice().concat()).into() + Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into() } pub fn verify(&self, value: Word, opening: MmrProof) -> bool { let root = &self.peaks[opening.peak_index()]; opening.merkle_path.verify(opening.relative_pos() as u64, value, root) } + + /// Flattens and pads the peaks to make hashing inside of the Miden VM easier. + /// + /// The procedure will: + /// - Flatten the vector of Words into a vector of Felts. + /// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO + /// padding. + /// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of + /// hashing. 
+ pub fn flatten_and_pad_peaks(&self) -> Vec { + let num_peaks = self.peaks.len(); + + // To achieve the padding rules above we calculate the length of the final vector. + // This is calculated as the number of field elements. Each peak is 4 field elements. + // The length is calculated as follows: + // - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires + // 64 field elements. + // - If there are at least 16 peaks and the number of peaks is odd, the data is padded to + // an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements. + // - If there are at least 16 peaks and the number of peaks is even, the data is not padded + // and as such requires `num_peaks * 4` field elements. + let len = if num_peaks < 16 { + 64 + } else if num_peaks % 2 == 1 { + (num_peaks + 1) * 4 + } else { + num_peaks * 4 + }; + + let mut elements = Vec::with_capacity(len); + elements.extend_from_slice(&self.peaks.as_slice().concat()); + elements.resize(len, ZERO); + elements + } } diff --git a/src/merkle/mmr/mod.rs b/src/merkle/mmr/mod.rs index d8903ca5..118bb120 100644 --- a/src/merkle/mmr/mod.rs +++ b/src/merkle/mmr/mod.rs @@ -6,7 +6,7 @@ mod proof; #[cfg(test)] mod tests; -use super::{Rpo256, Word}; +use super::{Felt, Rpo256, Word}; // REEXPORTS // ================================================================================================ diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index aca3dde3..631b9601 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,6 +1,6 @@ use super::{ hash::rpo::{Rpo256, RpoDigest}, - utils::collections::{vec, BTreeMap, Vec}, + utils::collections::{vec, BTreeMap, BTreeSet, Vec}, Felt, StarkField, Word, WORD_SIZE, ZERO, }; use core::fmt; @@ -10,6 +10,7 @@ use core::fmt; mod empty_roots; pub use empty_roots::EmptySubtreeRoots; +use empty_roots::EMPTY_WORD; mod index; pub use index::NodeIndex; @@ -26,6 +27,9 @@ pub use path_set::MerklePathSet; mod simple_smt; pub use 
simple_smt::SimpleSmt; +mod tiered_smt; +pub use tiered_smt::TieredSmt; + mod mmr; pub use mmr::{Mmr, MmrPeaks, MmrProof}; @@ -43,13 +47,15 @@ pub enum MerkleError { ConflictingRoots(Vec), DepthTooSmall(u8), DepthTooBig(u64), - NodeNotInStore(Word, NodeIndex), - NumLeavesNotPowerOfTwo(usize), + DuplicateValuesForIndex(u64), + DuplicateValuesForKey(RpoDigest), InvalidIndex { depth: u8, value: u64 }, InvalidDepth { expected: u8, provided: u8 }, InvalidPath(MerklePath), - InvalidEntriesCount(usize, usize), - NodeNotInSet(u64), + InvalidNumEntries(usize, usize), + NodeNotInSet(NodeIndex), + NodeNotInStore(Word, NodeIndex), + NumLeavesNotPowerOfTwo(usize), RootNotInStore(Word), } @@ -60,9 +66,8 @@ impl fmt::Display for MerkleError { ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"), DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"), DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"), - NumLeavesNotPowerOfTwo(leaves) => { - write!(f, "the leaves count {leaves} is not a power of 2") - } + DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"), + DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"), InvalidIndex{ depth, value} => write!( f, "the index value {value} is not valid for the depth {depth}" @@ -72,9 +77,12 @@ impl fmt::Display for MerkleError { "the provided depth {provided} is not valid for {expected}" ), InvalidPath(_path) => write!(f, "the provided path is not valid"), - InvalidEntriesCount(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"), - NodeNotInSet(index) => write!(f, "the node indexed by {index} is not in the set"), - NodeNotInStore(hash, index) => write!(f, "the node {:?} indexed by {} and depth {} is not in the store", hash, index.value(), index.depth(),), + InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is 
{provided}, but the maximum for the given depth is {max}"), + NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"), + NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"), + NumLeavesNotPowerOfTwo(leaves) => { + write!(f, "the leaves count {leaves} is not a power of 2") + } RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root), } } diff --git a/src/merkle/path_set.rs b/src/merkle/path_set.rs index 653fab1c..ed945fb8 100644 --- a/src/merkle/path_set.rs +++ b/src/merkle/path_set.rs @@ -73,7 +73,7 @@ impl MerklePathSet { let path_key = index.value() - parity; self.paths .get(&path_key) - .ok_or(MerkleError::NodeNotInSet(path_key)) + .ok_or(MerkleError::NodeNotInSet(index)) .map(|path| path[parity as usize]) } @@ -104,11 +104,8 @@ impl MerklePathSet { let parity = index.value() & 1; let path_key = index.value() - parity; - let mut path = self - .paths - .get(&path_key) - .cloned() - .ok_or(MerkleError::NodeNotInSet(index.value()))?; + let mut path = + self.paths.get(&path_key).cloned().ok_or(MerkleError::NodeNotInSet(index))?; path.remove(parity as usize); Ok(path) } @@ -200,7 +197,7 @@ impl MerklePathSet { let path_key = index.value() - parity; let path = match self.paths.get_mut(&path_key) { Some(path) => path, - None => return Err(MerkleError::NodeNotInSet(base_index_value)), + None => return Err(MerkleError::NodeNotInSet(index)), }; // Fill old_hashes vector ----------------------------------------------------------------- diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 1f29fe00..9d253160 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,6 +1,6 @@ use super::{ - BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, - RpoDigest, Vec, Word, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, + Rpo256, RpoDigest, Vec, Word, 
EMPTY_WORD, }; #[cfg(test)] @@ -10,6 +10,7 @@ mod tests; // ================================================================================================ /// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction. +/// /// The root of the tree is recomputed on each new leaf update. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SimpleSmt { @@ -20,18 +21,6 @@ pub struct SimpleSmt { empty_hashes: Vec, } -#[derive(Debug, Default, Clone, PartialEq, Eq)] -struct BranchNode { - left: RpoDigest, - right: RpoDigest, -} - -impl BranchNode { - fn parent(&self) -> RpoDigest { - Rpo256::merge(&[self.left, self.right]) - } -} - impl SimpleSmt { // CONSTANTS // -------------------------------------------------------------------------------------------- @@ -45,7 +34,12 @@ impl SimpleSmt { // CONSTRUCTORS // -------------------------------------------------------------------------------------------- - /// Creates a new simple SMT with the provided depth. + /// Returns a new [SimpleSmt] instantiated with the specified depth. + /// + /// All leaves in the returned tree are set to [ZERO; 4]. + /// + /// # Errors + /// Returns an error if the depth is 0 or is greater than 64. pub fn new(depth: u8) -> Result { // validate the range of the depth. if depth < Self::MIN_DEPTH { @@ -66,36 +60,47 @@ impl SimpleSmt { }) } - /// Appends the provided entries as leaves of the tree. + /// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves + /// set as specified by the provided entries. /// - /// # Errors + /// All leaves omitted from the entries list are set to [ZERO; 4]. /// - /// The function will fail if the provided entries count exceed the maximum tree capacity, that - /// is `2^{depth}`. - pub fn with_leaves(mut self, entries: R) -> Result + /// # Errors + /// Returns an error if: + /// - The depth is 0 or is greater than 64. + /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}. 
+ /// - The provided entries contain multiple values for the same key. + pub fn with_leaves(depth: u8, entries: R) -> Result where R: IntoIterator, I: Iterator + ExactSizeIterator, { - // check if the leaves count will fit the depth setup - let mut entries = entries.into_iter(); - let max = 1 << self.depth.min(63); + // create an empty tree + let mut tree = Self::new(depth)?; + + // check if the number of leaves can be accommodated by the tree's depth; we use a min + // depth of 63 because we consider passing in a vector of size 2^64 infeasible. + let entries = entries.into_iter(); + let max = 1 << tree.depth.min(63); if entries.len() > max { - return Err(MerkleError::InvalidEntriesCount(max, entries.len())); + return Err(MerkleError::InvalidNumEntries(max, entries.len())); } - // append leaves and return - entries.try_for_each(|(key, leaf)| self.insert_leaf(key, leaf))?; - Ok(self) - } - - /// Replaces the internal empty digests used when a given depth doesn't contain a node. - pub fn with_empty_subtrees(mut self, hashes: I) -> Self - where - I: IntoIterator, - { - self.replace_empty_subtrees(hashes.into_iter().collect()); - self + // append leaves to the tree returning an error if a duplicate entry for the same key + // is found + let mut empty_entries = BTreeSet::new(); + for (key, value) in entries { + let old_value = tree.update_leaf(key, value)?; + if old_value != EMPTY_WORD || empty_entries.contains(&key) { + return Err(MerkleError::DuplicateValuesForIndex(key)); + } + // if we've processed an empty entry, add the key to the set of empty entry keys, and + // if this key was already in the set, return an error + if value == EMPTY_WORD && !empty_entries.insert(key) { + return Err(MerkleError::DuplicateValuesForIndex(key)); + } + } + Ok(tree) } // PUBLIC ACCESSORS @@ -111,40 +116,43 @@ impl SimpleSmt { self.depth } - // PROVIDERS - // -------------------------------------------------------------------------------------------- - - /// Returns the set count of 
the keys of the leaves. - pub fn leaves_count(&self) -> usize { - self.leaves.len() - } - /// Returns a node at the specified index. /// /// # Errors - /// Returns an error if: - /// * The specified depth is greater than the depth of the tree. + /// Returns an error if the specified index has depth set to 0 or the depth is greater than + /// the depth of this Merkle tree. pub fn get_node(&self, index: NodeIndex) -> Result { if index.is_root() { Err(MerkleError::DepthTooSmall(index.depth())) } else if index.depth() > self.depth() { Err(MerkleError::DepthTooBig(index.depth() as u64)) } else if index.depth() == self.depth() { - self.get_leaf_node(index.value()) - .or_else(|| self.empty_hashes.get(index.depth() as usize).copied().map(Word::from)) - .ok_or(MerkleError::NodeNotInSet(index.value())) + // the lookup in empty_hashes could fail only if empty_hashes were not built correctly + // by the constructor as we check the depth of the lookup above. + Ok(self + .get_leaf_node(index.value()) + .unwrap_or_else(|| self.empty_hashes[index.depth() as usize].into())) } else { - let branch_node = self.get_branch_node(&index); - Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into()) + Ok(self.get_branch_node(&index).parent().into()) } } - /// Returns a Merkle path from the node at the specified key to the root. The node itself is - /// not included in the path. + /// Returns a value of the leaf at the specified index. /// /// # Errors - /// Returns an error if: - /// * The specified depth is greater than the depth of the tree. + /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}. + pub fn get_leaf(&self, index: u64) -> Result { + let index = NodeIndex::new(self.depth, index)?; + self.get_node(index) + } + + /// Returns a Merkle path from the node at the specified index to the root. + /// + /// The node itself is not included in the path. 
+ /// + /// # Errors + /// Returns an error if the specified index has depth set to 0 or the depth is greater than + /// the depth of this Merkle tree. pub fn get_path(&self, mut index: NodeIndex) -> Result { if index.is_root() { return Err(MerkleError::DepthTooSmall(index.depth())); @@ -163,18 +171,26 @@ impl SimpleSmt { Ok(path.into()) } - /// Return a Merkle path from the leaf at the specified key to the root. The leaf itself is not - /// included in the path. + /// Return a Merkle path from the leaf at the specified index to the root. + /// + /// The leaf itself is not included in the path. /// /// # Errors - /// Returns an error if: - /// * The specified key does not exist as a leaf node. - pub fn get_leaf_path(&self, key: u64) -> Result { - let index = NodeIndex::new(self.depth(), key)?; + /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}. + pub fn get_leaf_path(&self, index: u64) -> Result { + let index = NodeIndex::new(self.depth(), index)?; self.get_path(index) } - /// Iterator over the inner nodes of the [SimpleSmt]. + // ITERATORS + // -------------------------------------------------------------------------------------------- + + /// Returns an iterator over the leaves of this [SimpleSmt]. + pub fn leaves(&self) -> impl Iterator { + self.leaves.iter().map(|(i, w)| (*i, w)) + } + + /// Returns an iterator over the inner nodes of this Merkle tree. pub fn inner_nodes(&self) -> impl Iterator + '_ { self.branches.values().map(|e| InnerNodeInfo { value: e.parent().into(), @@ -186,27 +202,21 @@ impl SimpleSmt { // STATE MUTATORS // -------------------------------------------------------------------------------------------- - /// Replaces the leaf located at the specified key, and recomputes hashes by walking up the - /// tree. + /// Updates value of the leaf at the specified index returning the old leaf value. + /// + /// This also recomputes all hashes between the leaf and the root, updating the root itself. 
/// /// # Errors - /// Returns an error if the specified key is not a valid leaf index for this tree. - pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { - let index = NodeIndex::new(self.depth(), key)?; - if !self.check_leaf_node_exists(key) { - return Err(MerkleError::NodeNotInSet(index.value())); - } - self.insert_leaf(key, value)?; - - Ok(()) - } + /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}. + pub fn update_leaf(&mut self, index: u64, value: Word) -> Result { + let old_value = self.insert_leaf_node(index, value).unwrap_or(EMPTY_WORD); - /// Inserts a leaf located at the specified key, and recomputes hashes by walking up the tree - pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { - self.insert_leaf_node(key, value); + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return Ok(value); + } - // TODO consider using a map `index |-> word` instead of `index |-> (word, word)` - let mut index = NodeIndex::new(self.depth(), key)?; + let mut index = NodeIndex::new(self.depth(), index)?; let mut value = RpoDigest::from(value); for _ in 0..index.depth() { let is_right = index.is_value_odd(); @@ -217,26 +227,18 @@ impl SimpleSmt { value = Rpo256::merge(&[left, right]); } self.root = value.into(); - Ok(()) + Ok(old_value) } // HELPER METHODS // -------------------------------------------------------------------------------------------- - fn replace_empty_subtrees(&mut self, hashes: Vec) { - self.empty_hashes = hashes; - } - - fn check_leaf_node_exists(&self, key: u64) -> bool { - self.leaves.contains_key(&key) - } - fn get_leaf_node(&self, key: u64) -> Option { self.leaves.get(&key).copied() } - fn insert_leaf_node(&mut self, key: u64, node: Word) { - self.leaves.insert(key, node); + fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option { + self.leaves.insert(key, node) } fn get_branch_node(&self, 
index: &NodeIndex) -> BranchNode { @@ -254,3 +256,18 @@ impl SimpleSmt { self.branches.insert(index, branch); } } + +// BRANCH NODE +// ================================================================================================ + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct BranchNode { + left: RpoDigest, + right: RpoDigest, +} + +impl BranchNode { + fn parent(&self) -> RpoDigest { + Rpo256::merge(&[self.left, self.right]) + } +} diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 582e0030..0174e15f 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -1,9 +1,10 @@ use super::{ super::{int_to_node, InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt}, - NodeIndex, Rpo256, Vec, Word, + NodeIndex, Rpo256, Vec, Word, EMPTY_WORD, }; -use proptest::prelude::*; -use rand_utils::prng_array; + +// TEST DATA +// ================================================================================================ const KEYS4: [u64; 4] = [0, 1, 2, 3]; const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; @@ -23,25 +24,17 @@ const VALUES8: [Word; 8] = [ const ZERO_VALUES8: [Word; 8] = [int_to_node(0); 8]; +// TESTS +// ================================================================================================ + #[test] fn build_empty_tree() { + // tree of depth 3 let smt = SimpleSmt::new(3).unwrap(); let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap(); assert_eq!(mt.root(), smt.root()); } -#[test] -fn empty_digests_are_consistent() { - let depth = 5; - let root = SimpleSmt::new(depth).unwrap().root(); - let computed: [RpoDigest; 2] = (0..depth).fold([Default::default(); 2], |state, _| { - let digest = Rpo256::merge(&state); - [digest; 2] - }); - - assert_eq!(Word::from(computed[0]), root); -} - #[test] fn build_sparse_tree() { let mut smt = SimpleSmt::new(3).unwrap(); @@ -51,80 +44,59 @@ fn build_sparse_tree() { let key = 6; let new_node = int_to_node(7); values[key as usize] = new_node; - 
smt.insert_leaf(key, new_node).expect("Failed to insert leaf"); + let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf"); let mt2 = MerkleTree::new(values.clone()).unwrap(); assert_eq!(mt2.root(), smt.root()); assert_eq!( mt2.get_path(NodeIndex::make(3, 6)).unwrap(), smt.get_path(NodeIndex::make(3, 6)).unwrap() ); + assert_eq!(old_value, EMPTY_WORD); // insert second value at distinct leaf branch let key = 2; let new_node = int_to_node(3); values[key as usize] = new_node; - smt.insert_leaf(key, new_node).expect("Failed to insert leaf"); + let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf"); let mt3 = MerkleTree::new(values).unwrap(); assert_eq!(mt3.root(), smt.root()); assert_eq!( mt3.get_path(NodeIndex::make(3, 2)).unwrap(), smt.get_path(NodeIndex::make(3, 2)).unwrap() ); + assert_eq!(old_value, EMPTY_WORD); } #[test] -fn build_full_tree() { - let tree = SimpleSmt::new(2) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter())) - .unwrap(); +fn test_depth2_tree() { + let tree = SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(VALUES4.into_iter())).unwrap(); + // check internal structure let (root, node2, node3) = compute_internal_nodes(); assert_eq!(root, tree.root()); assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap()); assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap()); -} - -#[test] -fn get_values() { - let tree = SimpleSmt::new(2) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter())) - .unwrap(); - // check depth 2 + // check get_node() assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap()); assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap()); assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap()); assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap()); -} -#[test] -fn get_path() { - let tree = SimpleSmt::new(2) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter())) - 
.unwrap(); - - let (_, node2, node3) = compute_internal_nodes(); - - // check depth 2 + // check get_path(): depth 2 assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap()); assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap()); assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap()); assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap()); - // check depth 1 + // check get_path(): depth 1 assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap()); assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap()); } #[test] -fn test_parent_node_iterator() -> Result<(), MerkleError> { - let tree = SimpleSmt::new(2) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter())) - .unwrap(); +fn test_inner_node_iterator() -> Result<(), MerkleError> { + let tree = SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(VALUES4.into_iter())).unwrap(); // check depth 2 assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap()); @@ -166,35 +138,28 @@ fn test_parent_node_iterator() -> Result<(), MerkleError> { #[test] fn update_leaf() { - let mut tree = SimpleSmt::new(3) - .unwrap() - .with_leaves(KEYS8.into_iter().zip(VALUES8.into_iter())) - .unwrap(); + let mut tree = SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(VALUES8.into_iter())).unwrap(); // update one value let key = 3; let new_node = int_to_node(9); let mut expected_values = VALUES8.to_vec(); expected_values[key] = new_node; - let expected_tree = SimpleSmt::new(3) - .unwrap() - .with_leaves(KEYS8.into_iter().zip(expected_values.clone().into_iter())) - .unwrap(); + let expected_tree = MerkleTree::new(expected_values.clone()).unwrap(); - tree.update_leaf(key as u64, new_node).unwrap(); - assert_eq!(expected_tree.root, tree.root); + let old_leaf = tree.update_leaf(key as u64, new_node).unwrap(); + assert_eq!(expected_tree.root(), tree.root); + assert_eq!(old_leaf, 
VALUES8[key]); // update another value let key = 6; let new_node = int_to_node(10); expected_values[key] = new_node; - let expected_tree = SimpleSmt::new(3) - .unwrap() - .with_leaves(KEYS8.into_iter().zip(expected_values.into_iter())) - .unwrap(); + let expected_tree = MerkleTree::new(expected_values.clone()).unwrap(); - tree.update_leaf(key as u64, new_node).unwrap(); - assert_eq!(expected_tree.root, tree.root); + let old_leaf = tree.update_leaf(key as u64, new_node).unwrap(); + assert_eq!(expected_tree.root(), tree.root); + assert_eq!(old_leaf, VALUES8[key]); } #[test] @@ -226,7 +191,7 @@ fn small_tree_opening_is_consistent() { let depth = 3; let entries = vec![(0, a), (1, b), (4, c), (7, d)]; - let tree = SimpleSmt::new(depth).unwrap().with_leaves(entries).unwrap(); + let tree = SimpleSmt::with_leaves(depth, entries).unwrap(); assert_eq!(tree.root(), Word::from(k)); @@ -250,56 +215,30 @@ fn small_tree_opening_is_consistent() { } } -proptest! { - #[test] - fn arbitrary_openings_single_leaf( - depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH, - key in prop::num::u64::ANY, - leaf in prop::num::u64::ANY, - ) { - let mut tree = SimpleSmt::new(depth).unwrap(); - - let key = key % (1 << depth as u64); - let leaf = int_to_node(leaf); - - tree.insert_leaf(key, leaf.into()).unwrap(); - tree.get_leaf_path(key).unwrap(); - - // traverse to root, fetching all paths - for d in 1..depth { - let k = key >> (depth - d); - tree.get_path(NodeIndex::make(d, k)).unwrap(); - } - } +#[test] +fn fail_on_duplicates() { + let entries = [(1_u64, int_to_node(1)), (5, int_to_node(2)), (1_u64, int_to_node(3))]; + let smt = SimpleSmt::with_leaves(64, entries); + assert!(smt.is_err()); + + let entries = [(1_u64, int_to_node(0)), (5, int_to_node(2)), (1_u64, int_to_node(0))]; + let smt = SimpleSmt::with_leaves(64, entries); + assert!(smt.is_err()); + + let entries = [(1_u64, int_to_node(0)), (5, int_to_node(2)), (1_u64, int_to_node(1))]; + let smt = SimpleSmt::with_leaves(64, entries); + 
assert!(smt.is_err()); + + let entries = [(1_u64, int_to_node(1)), (5, int_to_node(2)), (1_u64, int_to_node(0))]; + let smt = SimpleSmt::with_leaves(64, entries); + assert!(smt.is_err()); +} - #[test] - fn arbitrary_openings_multiple_leaves( - depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH, - count in 2u8..10u8, - ref seed in any::<[u8; 32]>() - ) { - let mut tree = SimpleSmt::new(depth).unwrap(); - let mut seed = *seed; - let leaves = (1 << depth) - 1; - - for _ in 0..count { - seed = prng_array(seed); - - let mut key = [0u8; 8]; - let mut leaf = [0u8; 8]; - - key.copy_from_slice(&seed[..8]); - leaf.copy_from_slice(&seed[8..16]); - - let key = u64::from_le_bytes(key); - let key = key % leaves; - let leaf = u64::from_le_bytes(leaf); - let leaf = int_to_node(leaf); - - tree.insert_leaf(key, leaf).unwrap(); - tree.get_leaf_path(key).unwrap(); - } - } +#[test] +fn with_no_duplicates_empty_node() { + let entries = [(1_u64, int_to_node(0)), (5, int_to_node(2))]; + let smt = SimpleSmt::with_leaves(64, entries); + assert!(smt.is_ok()); } // HELPER FUNCTIONS diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 96ca54a4..d4c1ffba 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -1,9 +1,9 @@ -use super::mmr::Mmr; use super::{ - BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerklePathSet, MerkleTree, - NodeIndex, RootPath, Rpo256, RpoDigest, SimpleSmt, ValuePath, Vec, Word, + mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerklePathSet, + MerkleTree, NodeIndex, RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, ValuePath, Vec, Word, }; use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use core::borrow::Borrow; #[cfg(test)] mod tests; @@ -14,7 +14,7 @@ pub struct Node { right: RpoDigest, } -/// An in-memory data store for Merkle-lized data. +/// An in-memory data store for Merkelized data. 
/// /// This is a in memory data store for Merkle trees, this store allows all the nodes of multiple /// trees to live as long as necessary and without duplication, this allows the implementation of @@ -152,7 +152,6 @@ impl MerkleStore { /// The path starts at the sibling of the target leaf. /// /// # Errors - /// /// This method can return the following errors: /// - `RootNotInStore` if the `root` is not present in the store. /// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the store. @@ -257,6 +256,35 @@ impl MerkleStore { Ok(tree_depth) } + // DATA EXTRACTORS + // -------------------------------------------------------------------------------------------- + + /// Returns a subset of this Merkle store such that the returned Merkle store contains all + /// nodes which are descendants of the specified roots. + /// + /// The roots for which no descendants exist in this Merkle store are ignored. + pub fn subset(&self, roots: I) -> MerkleStore + where + I: Iterator, + R: Borrow, + { + let mut store = MerkleStore::new(); + for root in roots { + let root = RpoDigest::from(*root.borrow()); + store.clone_tree_from(root, self); + } + store + } + + /// Iterator over the inner nodes of the [MerkleStore]. + pub fn inner_nodes(&self) -> impl Iterator + '_ { + self.nodes.iter().map(|(r, n)| InnerNodeInfo { + value: r.into(), + left: n.left.into(), + right: n.right.into(), + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -364,6 +392,24 @@ impl MerkleStore { Ok(parent.into()) } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Recursively clones a tree with the specified root from the specified source into self. + /// + /// If the source store does not contain a tree with the specified root, this is a noop. 
+ fn clone_tree_from(&mut self, root: RpoDigest, source: &Self) { + // process the node only if it is in the source + if let Some(node) = source.nodes.get(&root) { + // if the node has already been inserted, no need to process it further as all of its + // descendants should be already cloned from the source store + if matches!(self.nodes.insert(root, *node), None) { + self.clone_tree_from(node.left, source); + self.clone_tree_from(node.right, source); + } + } + } } // CONVERSIONS @@ -393,6 +439,14 @@ impl From<&Mmr> for MerkleStore { } } +impl From<&TieredSmt> for MerkleStore { + fn from(value: &TieredSmt) -> Self { + let mut store = MerkleStore::new(); + store.extend(value.inner_nodes()); + store + } +} + impl FromIterator for MerkleStore { fn from_iter>(iter: T) -> Self { let mut store = MerkleStore::new(); diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index 02e12b2a..0bdfa4f2 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -1,29 +1,48 @@ -use super::*; +use super::{ + super::EMPTY_WORD, Deserializable, EmptySubtreeRoots, MerkleError, MerklePath, MerkleStore, + NodeIndex, RpoDigest, Serializable, +}; use crate::{ hash::rpo::Rpo256, merkle::{int_to_node, MerklePathSet, MerkleTree, SimpleSmt}, - Felt, Word, WORD_SIZE, ZERO, + Felt, Word, WORD_SIZE, }; #[cfg(feature = "std")] use std::error::Error; +// TEST DATA +// ================================================================================================ + const KEYS4: [u64; 4] = [0, 1, 2, 3]; -const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)]; -const EMPTY: Word = [ZERO; WORD_SIZE]; +const VALUES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)]; + +const VALUES8: [Word; 8] = [ + int_to_node(1), + int_to_node(2), + int_to_node(3), + int_to_node(4), + int_to_node(5), + int_to_node(6), + int_to_node(7), + int_to_node(8), +]; + +// TESTS +// 
================================================================================================ #[test] fn test_root_not_in_store() -> Result<(), MerkleError> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let store = MerkleStore::from(&mtree); assert_eq!( - store.get_node(LEAVES4[0], NodeIndex::make(mtree.depth(), 0)), - Err(MerkleError::RootNotInStore(LEAVES4[0])), + store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)), + Err(MerkleError::RootNotInStore(VALUES4[0])), "Leaf 0 is not a root" ); assert_eq!( - store.get_path(LEAVES4[0], NodeIndex::make(mtree.depth(), 0)), - Err(MerkleError::RootNotInStore(LEAVES4[0])), + store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)), + Err(MerkleError::RootNotInStore(VALUES4[0])), "Leaf 0 is not a root" ); @@ -32,33 +51,33 @@ fn test_root_not_in_store() -> Result<(), MerkleError> { #[test] fn test_merkle_tree() -> Result<(), MerkleError> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let store = MerkleStore::from(&mtree); - // STORE LEAVES ARE CORRECT ============================================================== + // STORE LEAVES ARE CORRECT ------------------------------------------------------------------- // checks the leaves in the store corresponds to the expected values assert_eq!( store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)), - Ok(LEAVES4[0]), + Ok(VALUES4[0]), "node 0 must be in the tree" ); assert_eq!( store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)), - Ok(LEAVES4[1]), + Ok(VALUES4[1]), "node 1 must be in the tree" ); assert_eq!( store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)), - Ok(LEAVES4[2]), + Ok(VALUES4[2]), "node 2 must be in the tree" ); assert_eq!( store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)), - Ok(LEAVES4[3]), + Ok(VALUES4[3]), "node 3 must be in the tree" ); - // STORE LEAVES MATCH TREE 
=============================================================== + // STORE LEAVES MATCH TREE -------------------------------------------------------------------- // sanity check the values returned by the store and the tree assert_eq!( mtree.get_node(NodeIndex::make(mtree.depth(), 0)), @@ -85,7 +104,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> { // assert the merkle path returned by the store is the same as the one in the tree let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(); assert_eq!( - LEAVES4[0], result.value, + VALUES4[0], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -96,7 +115,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(); assert_eq!( - LEAVES4[1], result.value, + VALUES4[1], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -107,7 +126,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(); assert_eq!( - LEAVES4[2], result.value, + VALUES4[2], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -118,7 +137,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(); assert_eq!( - LEAVES4[3], result.value, + VALUES4[3], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -133,7 +152,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> { #[test] fn test_empty_roots() { let store = MerkleStore::default(); - let mut root = RpoDigest::new(EMPTY); + let mut root = RpoDigest::new(EMPTY_WORD); for depth in 0..255 { root = Rpo256::merge(&[root; 2]); @@ -157,13 +176,13 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> { let index = NodeIndex::make(depth, 0); let store_path = 
store.get_path(smt.root(), index)?; let smt_path = smt.get_path(index)?; - assert_eq!(store_path.value, EMPTY, "the leaf of an empty tree is always ZERO"); + assert_eq!(store_path.value, EMPTY_WORD, "the leaf of an empty tree is always ZERO"); assert_eq!( store_path.path, smt_path, "the returned merkle path does not match the computed values" ); assert_eq!( - store_path.path.compute_root(depth.into(), EMPTY).unwrap(), + store_path.path.compute_root(depth.into(), EMPTY_WORD).unwrap(), smt.root(), "computed root from the path must match the empty tree root" ); @@ -174,7 +193,7 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> { #[test] fn test_get_invalid_node() { - let mtree = MerkleTree::new(LEAVES4.to_vec()).expect("creating a merkle tree must work"); + let mtree = MerkleTree::new(VALUES4.to_vec()).expect("creating a merkle tree must work"); let store = MerkleStore::from(&mtree); let _ = store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)); } @@ -183,10 +202,7 @@ fn test_get_invalid_node() { fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> { let keys2: [u64; 2] = [0, 1]; let leaves2: [Word; 2] = [int_to_node(1), int_to_node(2)]; - let smt = SimpleSmt::new(1) - .unwrap() - .with_leaves(keys2.into_iter().zip(leaves2.into_iter())) - .unwrap(); + let smt = SimpleSmt::with_leaves(1, keys2.into_iter().zip(leaves2.into_iter())).unwrap(); let store = MerkleStore::from(&smt); let idx = NodeIndex::make(1, 0); @@ -202,10 +218,9 @@ fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> { #[test] fn test_sparse_merkle_tree() -> Result<(), MerkleError> { - let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(LEAVES4.into_iter())) - .unwrap(); + let smt = + SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, KEYS4.into_iter().zip(VALUES4.into_iter())) + .unwrap(); let store = MerkleStore::from(&smt); @@ -213,27 +228,27 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { 
// checks the leaves in the store corresponds to the expected values assert_eq!( store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)), - Ok(LEAVES4[0]), + Ok(VALUES4[0]), "node 0 must be in the tree" ); assert_eq!( store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)), - Ok(LEAVES4[1]), + Ok(VALUES4[1]), "node 1 must be in the tree" ); assert_eq!( store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)), - Ok(LEAVES4[2]), + Ok(VALUES4[2]), "node 2 must be in the tree" ); assert_eq!( store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)), - Ok(LEAVES4[3]), + Ok(VALUES4[3]), "node 3 must be in the tree" ); assert_eq!( store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)), - Ok(EMPTY), + Ok(EMPTY_WORD), "unmodified node 4 must be ZERO" ); @@ -269,7 +284,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { // assert the merkle path returned by the store is the same as the one in the tree let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 0)).unwrap(); assert_eq!( - LEAVES4[0], result.value, + VALUES4[0], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -280,7 +295,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 1)).unwrap(); assert_eq!( - LEAVES4[1], result.value, + VALUES4[1], result.value, "Value for merkle path at index 1 must match leaf value" ); assert_eq!( @@ -291,7 +306,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 2)).unwrap(); assert_eq!( - LEAVES4[2], result.value, + VALUES4[2], result.value, "Value for merkle path at index 2 must match leaf value" ); assert_eq!( @@ -302,7 +317,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 3)).unwrap(); assert_eq!( - LEAVES4[3], result.value, + VALUES4[3], result.value, "Value for 
merkle path at index 3 must match leaf value" ); assert_eq!( @@ -312,7 +327,10 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { ); let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 4)).unwrap(); - assert_eq!(EMPTY, result.value, "Value for merkle path at index 4 must match leaf value"); + assert_eq!( + EMPTY_WORD, result.value, + "Value for merkle path at index 4 must match leaf value" + ); assert_eq!( smt.get_path(NodeIndex::make(smt.depth(), 4)), Ok(result.path), @@ -324,7 +342,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> { #[test] fn test_add_merkle_paths() -> Result<(), MerkleError> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let i0 = 0; let p0 = mtree.get_path(NodeIndex::make(2, i0)).unwrap(); @@ -339,10 +357,10 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { let p3 = mtree.get_path(NodeIndex::make(2, i3)).unwrap(); let paths = [ - (i0, LEAVES4[i0 as usize], p0), - (i1, LEAVES4[i1 as usize], p1), - (i2, LEAVES4[i2 as usize], p2), - (i3, LEAVES4[i3 as usize], p3), + (i0, VALUES4[i0 as usize], p0), + (i1, VALUES4[i1 as usize], p1), + (i2, VALUES4[i2 as usize], p2), + (i3, VALUES4[i3 as usize], p3), ]; let mut store = MerkleStore::default(); @@ -355,22 +373,22 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { // checks the leaves in the store corresponds to the expected values assert_eq!( store.get_node(set.root(), NodeIndex::make(set.depth(), 0)), - Ok(LEAVES4[0]), + Ok(VALUES4[0]), "node 0 must be in the set" ); assert_eq!( store.get_node(set.root(), NodeIndex::make(set.depth(), 1)), - Ok(LEAVES4[1]), + Ok(VALUES4[1]), "node 1 must be in the set" ); assert_eq!( store.get_node(set.root(), NodeIndex::make(set.depth(), 2)), - Ok(LEAVES4[2]), + Ok(VALUES4[2]), "node 2 must be in the set" ); assert_eq!( store.get_node(set.root(), NodeIndex::make(set.depth(), 3)), - Ok(LEAVES4[3]), + Ok(VALUES4[3]), "node 3 must be in the set" ); @@ -401,7 
+419,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { // assert the merkle path returned by the store is the same as the one in the set let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 0)).unwrap(); assert_eq!( - LEAVES4[0], result.value, + VALUES4[0], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -412,7 +430,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 1)).unwrap(); assert_eq!( - LEAVES4[1], result.value, + VALUES4[1], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -423,7 +441,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 2)).unwrap(); assert_eq!( - LEAVES4[2], result.value, + VALUES4[2], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -434,7 +452,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 3)).unwrap(); assert_eq!( - LEAVES4[3], result.value, + VALUES4[3], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( @@ -502,7 +520,7 @@ fn store_path_opens_from_leaf() { #[test] fn test_set_node() -> Result<(), MerkleError> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let mut store = MerkleStore::from(&mtree); let value = int_to_node(42); let index = NodeIndex::make(mtree.depth(), 0); @@ -514,7 +532,7 @@ fn test_set_node() -> Result<(), MerkleError> { #[test] fn test_constructors() -> Result<(), MerkleError> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let store = MerkleStore::from(&mtree); let depth = mtree.depth(); @@ -526,10 +544,7 @@ fn test_constructors() -> Result<(), MerkleError> { } let depth = 32; - let smt = 
SimpleSmt::new(depth) - .unwrap() - .with_leaves(KEYS4.into_iter().zip(LEAVES4.into_iter())) - .unwrap(); + let smt = SimpleSmt::with_leaves(depth, KEYS4.into_iter().zip(VALUES4.into_iter())).unwrap(); let store = MerkleStore::from(&smt); let depth = smt.depth(); @@ -541,20 +556,20 @@ fn test_constructors() -> Result<(), MerkleError> { let d = 2; let paths = [ - (0, LEAVES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()), - (1, LEAVES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()), - (2, LEAVES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()), - (3, LEAVES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()), + (0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()), + (1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()), + (2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()), + (3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()), ]; let mut store1 = MerkleStore::default(); store1.add_merkle_paths(paths.clone())?; let mut store2 = MerkleStore::default(); - store2.add_merkle_path(0, LEAVES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?; - store2.add_merkle_path(1, LEAVES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?; - store2.add_merkle_path(2, LEAVES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?; - store2.add_merkle_path(3, LEAVES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?; + store2.add_merkle_path(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?; + store2.add_merkle_path(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?; + store2.add_merkle_path(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?; + store2.add_merkle_path(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?; let set = MerklePathSet::new(d).with_paths(paths).unwrap(); for key in [0, 1, 2, 3] { @@ -718,10 +733,67 @@ fn get_leaf_depth_works_with_depth_8() { assert_eq!(Err(MerkleError::DepthTooBig(9)), store.get_leaf_depth(root, 8, a)); } +// SUBSET EXTRACTION +// 
================================================================================================ + +#[test] +fn mstore_subset() { + // add a Merkle tree of depth 3 to the store + let mtree = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let mut store = MerkleStore::default(); + let empty_store_num_nodes = store.nodes.len(); + store.extend(mtree.inner_nodes()); + + // build 3 subtrees contained within the above Merkle tree; note that subtree2 is a subset + // of subtree1 + let subtree1 = MerkleTree::new(VALUES8[..4].to_vec()).unwrap(); + let subtree2 = MerkleTree::new(VALUES8[2..4].to_vec()).unwrap(); + let subtree3 = MerkleTree::new(VALUES8[6..].to_vec()).unwrap(); + + // --- extract all 3 subtrees --------------------------------------------- + + let substore = store.subset([subtree1.root(), subtree2.root(), subtree3.root()].iter()); + + // number of nodes should increase by 4: 3 nodes from subtree1 and 1 node from subtree3 + assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4); + + // make sure paths for all subtrees are in the store + check_mstore_subtree(&substore, &subtree1); + check_mstore_subtree(&substore, &subtree2); + check_mstore_subtree(&substore, &subtree3); + + // --- extract subtrees 1 and 3 ------------------------------------------- + // this should give the same result as above as subtree2 is nested within subtree1 + + let substore = store.subset([subtree1.root(), subtree3.root()].iter()); + + // number of nodes should increase by 4: 3 nodes from subtree1 and 1 node from subtree3 + assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4); + + // make sure paths for all subtrees are in the store + check_mstore_subtree(&substore, &subtree1); + check_mstore_subtree(&substore, &subtree2); + check_mstore_subtree(&substore, &subtree3); +} + +fn check_mstore_subtree(store: &MerkleStore, subtree: &MerkleTree) { + for (i, value) in subtree.leaves() { + let index = NodeIndex::new(subtree.depth(), i).unwrap(); + let path1 = 
store.get_path(subtree.root(), index).unwrap(); + assert_eq!(&path1.value, value); + + let path2 = subtree.get_path(index).unwrap(); + assert_eq!(path1.path, path2); + } +} + +// SERIALIZATION +// ================================================================================================ + #[cfg(feature = "std")] #[test] fn test_serialization() -> Result<(), Box> { - let mtree = MerkleTree::new(LEAVES4.to_vec())?; + let mtree = MerkleTree::new(VALUES4.to_vec())?; let store = MerkleStore::from(&mtree); let decoded = MerkleStore::read_from_bytes(&store.to_bytes()).expect("deserialization failed"); assert_eq!(store, decoded); diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs new file mode 100644 index 00000000..eaf1cd87 --- /dev/null +++ b/src/merkle/tiered_smt/mod.rs @@ -0,0 +1,482 @@ +use super::{ + BTreeMap, BTreeSet, EmptySubtreeRoots, Felt, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, + Rpo256, RpoDigest, StarkField, Vec, Word, EMPTY_WORD, ZERO, +}; +use core::cmp; + +#[cfg(test)] +mod tests; + +// TIERED SPARSE MERKLE TREE +// ================================================================================================ + +/// Tiered (compacted) Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and +/// values are represented by 4 field elements. +/// +/// Leaves in the tree can exist only on specific depths called "tiers". These depths are: 16, 32, +/// 48, and 64. Initially, when a tree is empty, it is equivalent to an empty Sparse Merkle tree +/// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted +/// into the tree they are added to the first available tier. +/// +/// For example, when the first key-value is inserted, it will be stored in a node at depth 16 +/// such that the first 16 bits of the key determine the position of the node at depth 16. 
If +/// another value with a key sharing the same 16-bit prefix is inserted, both values move into +/// the next tier (depth 32). This process is repeated until values end up at tier 64. If multiple +/// values have keys with a common 64-bit prefix, such key-value pairs are stored in a sorted list +/// at the last tier (depth = 64). +/// +/// To differentiate between internal and leaf nodes, node values are computed as follows: +/// - Internal nodes: hash(left_child, right_child). +/// - Leaf node at depths 16, 32, or 48: hash(rem_key, value, domain=depth). +/// - Leaf node at depth 64: hash([rem_key_0, value_0, ..., rem_key_n, value_n, domain=64]). +/// +/// Where rem_key is computed by replacing d most significant bits of the key with zeros where d +/// is depth (i.e., for a leaf at depth 16, we replace 16 most significant bits of the key with 0). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TieredSmt { + root: RpoDigest, + nodes: BTreeMap, + upper_leaves: BTreeMap, // node_index |-> key map + bottom_leaves: BTreeMap, // leaves of depth 64 + values: BTreeMap, +} + +impl TieredSmt { + // CONSTANTS + // -------------------------------------------------------------------------------------------- + + /// The number of levels between tiers. + const TIER_SIZE: u8 = 16; + + /// Depths at which leaves can exist in a tiered SMT. + const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64]; + + /// Maximum node depth. This is also the bottom tier of the tree. + const MAX_DEPTH: u8 = 64; + + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Returns a new [TieredSmt] instantiated with the specified key-value pairs. + /// + /// # Errors + /// Returns an error if the provided entries contain multiple values for the same key.
+ pub fn with_leaves(entries: R) -> Result + where + R: IntoIterator, + I: Iterator + ExactSizeIterator, + { + // create an empty tree + let mut tree = Self::default(); + + // append leaves to the tree returning an error if a duplicate entry for the same key + // is found + let mut empty_entries = BTreeSet::new(); + for (key, value) in entries { + let old_value = tree.insert(key, value); + if old_value != EMPTY_WORD || empty_entries.contains(&key) { + return Err(MerkleError::DuplicateValuesForKey(key)); + } + // if we've processed an empty entry, add the key to the set of empty entry keys, and + // if this key was already in the set, return an error + if value == EMPTY_WORD && !empty_entries.insert(key) { + return Err(MerkleError::DuplicateValuesForKey(key)); + } + } + Ok(tree) + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the root of this Merkle tree. + pub const fn root(&self) -> RpoDigest { + self.root + } + + /// Returns a node at the specified index. + /// + /// # Errors + /// Returns an error if: + /// - The specified index depth is 0 or greater than 64. + /// - The node with the specified index does not exists in the Merkle tree. This is possible + /// when a leaf node with the same index prefix exists at a tier higher than the requested + /// node. + pub fn get_node(&self, index: NodeIndex) -> Result { + self.validate_node_access(index)?; + Ok(self.get_node_unchecked(&index)) + } + + /// Returns a Merkle path from the node at the specified index to the root. + /// + /// The node itself is not included in the path. + /// + /// # Errors + /// Returns an error if: + /// - The specified index depth is 0 or greater than 64. + /// - The node with the specified index does not exists in the Merkle tree. This is possible + /// when a leaf node with the same index prefix exists at a tier higher than the node to + /// which the path is requested. 
+ pub fn get_path(&self, mut index: NodeIndex) -> Result { + self.validate_node_access(index)?; + + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let node = self.get_node_unchecked(&index.sibling()); + path.push(node.into()); + index.move_up(); + } + + Ok(path.into()) + } + + /// Returns the value associated with the specified key. + /// + /// If nothing was inserted into this tree for the specified key, [ZERO; 4] is returned. + pub fn get_value(&self, key: RpoDigest) -> Word { + match self.values.get(&key) { + Some(value) => *value, + None => EMPTY_WORD, + } + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Inserts the provided value into the tree under the specified key and returns the value + /// previously stored under this key. + /// + /// If the value for the specified key was not previously set, [ZERO; 4] is returned. + pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word { + // insert the value into the key-value map, and if nothing has changed, return + let old_value = self.values.insert(key, value).unwrap_or(EMPTY_WORD); + if old_value == value { + return old_value; + } + + // determine the index for the value node; this index could have 3 different meanings: + // - it points to a root of an empty subtree (excluding depth = 64); in this case, we can + // replace the node with the value node immediately. + // - it points to a node at the bottom tier (i.e., depth = 64); in this case, we need to + // process bottom-tier insertion which will be handled by insert_node(). 
+ // - it points to a leaf node; this node could be a node with the same key or a different + // key with a common prefix; in the latter case, we'll need to move the leaf to a lower + // tier; for this scenario the `leaf_key` will contain the key of the leaf node + let (mut index, leaf_key) = self.get_insert_location(&key); + + // if the returned index points to a leaf, and this leaf is for a different key, we need + // to move the leaf to a lower tier + if let Some(other_key) = leaf_key { + if other_key != key { + // determine how far down the tree should we move the existing leaf + let common_prefix_len = get_common_prefix_tier(&key, &other_key); + let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); + + // move the leaf to the new location; this requires first removing the existing + // index, re-computing node value, and inserting the node at a new location + let other_index = key_to_index(&other_key, depth); + let other_value = *self.values.get(&other_key).expect("no value for other key"); + self.upper_leaves.remove(&index).expect("other node key not in map"); + self.insert_node(other_index, other_key, other_value); + + // the new leaf also needs to move down to the same tier + index = key_to_index(&key, depth); + } + } + + // insert the node and return the old value + self.insert_node(index, key, value); + old_value + } + + // ITERATORS + // -------------------------------------------------------------------------------------------- + + /// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths 16 + /// 32, 48, or 64). + /// + /// The iterator order is unspecified. 
+ pub fn inner_nodes(&self) -> impl Iterator + '_ { + self.nodes.iter().filter_map(|(index, node)| { + if is_inner_node(index) { + Some(InnerNodeInfo { + value: node.into(), + left: self.get_node_unchecked(&index.left_child()).into(), + right: self.get_node_unchecked(&index.right_child()).into(), + }) + } else { + None + } + }) + } + + /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]. + /// + /// Each yielded item is a (node, key, value) tuple where key is a full un-truncated key (i.e., + /// with key[3] element unmodified). + /// + /// The iterator order is unspecified. + pub fn upper_leaves(&self) -> impl Iterator + '_ { + self.upper_leaves.iter().map(|(index, key)| { + let node = self.get_node_unchecked(index); + let value = self.get_value(*key); + (node, *key, value) + }) + } + + /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt]. + /// + /// Each yielded item consists of the hash of the leaf and its contents, where contents is + /// a vector containing key-value pairs of entries storied in this leaf. Note that keys are + /// un-truncated keys (i.e., with key[3] element unmodified). + /// + /// The iterator order is unspecified. + pub fn bottom_leaves(&self) -> impl Iterator)> + '_ { + self.bottom_leaves.values().map(|leaf| (leaf.hash(), leaf.contents())) + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Checks if the specified index is valid in the context of this Merkle tree. + /// + /// # Errors + /// Returns an error if: + /// - The specified index depth is 0 or greater than 64. + /// - The node for the specified index does not exists in the Merkle tree. This is possible + /// when an ancestors of the specified index is a leaf node. 
+ fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > Self::MAX_DEPTH { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } else { + // make sure that there are no leaf nodes in the ancestors of the index; since leaf + // nodes can live at specific depth, we just need to check these depths. + let tier = get_index_tier(&index); + let mut tier_index = index; + for &depth in Self::TIER_DEPTHS[..tier].iter().rev() { + tier_index.move_up_to(depth); + if self.upper_leaves.contains_key(&tier_index) { + return Err(MerkleError::NodeNotInSet(index)); + } + } + } + + Ok(()) + } + + /// Returns a node at the specified index. If the node does not exist at this index, a root + /// for an empty subtree at the index's depth is returned. + /// + /// Unlike [TieredSmt::get_node()] this does not perform any checks to verify that the returned + /// node is valid in the context of this tree. + fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest { + match self.nodes.get(index) { + Some(node) => *node, + None => EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[index.depth() as usize], + } + } + + /// Returns an index at which a node for the specified key should be inserted. If a leaf node + /// already exists at that index, returns the key associated with that leaf node. + /// + /// In case the index falls into the bottom tier (depth = 64), leaf node key is not returned + /// as the bottom tier may contain multiple key-value pairs in the same leaf. + fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, Option) { + // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if + // a node at any of the tiers is either a leaf or a root of an empty subtree. 
+ let mse = Word::from(key)[3].as_int(); + for depth in (Self::TIER_DEPTHS[0]..Self::MAX_DEPTH).step_by(Self::TIER_SIZE as usize) { + let index = NodeIndex::new_unchecked(depth, mse >> (Self::MAX_DEPTH - depth)); + if let Some(leaf_key) = self.upper_leaves.get(&index) { + return (index, Some(*leaf_key)); + } else if !self.nodes.contains_key(&index) { + return (index, None); + } + } + + // if we got here, that means all of the nodes checked so far are internal nodes, and + // the new node would need to be inserted in the bottom tier. + let index = NodeIndex::new_unchecked(Self::MAX_DEPTH, mse); + (index, None) + } + + /// Inserts the provided key-value pair at the specified index and updates the root of this + /// Merkle tree by recomputing the path to the root. + fn insert_node(&mut self, mut index: NodeIndex, key: RpoDigest, value: Word) { + let depth = index.depth(); + + // insert the key into index-key map and compute the new value of the node + let mut node = if index.depth() == Self::MAX_DEPTH { + // for the bottom tier, we add the key-value pair to the existing leaf, or create a + // new leaf with this key-value pair (built lazily, only when the entry is vacant) + self.bottom_leaves + .entry(index.value()) + .and_modify(|leaves| leaves.add_value(key, value)) + .or_insert_with(|| BottomLeaf::new(key, value)) + .hash() + } else { + // for the upper tiers, we just update the index-key map and compute the value of the + // node + self.upper_leaves.insert(index, key); + // the node value is computed as: hash(remaining_key || value, domain = depth) + let remaining_path = get_remaining_path(key, depth.into()); + Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into()) + }; + + // insert the node and update the path from the node to the root + for _ in 0..index.depth() { + self.nodes.insert(index, node); + let sibling = self.get_node_unchecked(&index.sibling()); + node = Rpo256::merge(&index.build_node(node, sibling)); + index.move_up(); + } + + // update the root + self.nodes.insert(NodeIndex::root(),
node); + self.root = node; + } +} + +impl Default for TieredSmt { + fn default() -> Self { + Self { + root: EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0], + nodes: BTreeMap::new(), + upper_leaves: BTreeMap::new(), + bottom_leaves: BTreeMap::new(), + values: BTreeMap::new(), + } + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Returns the remaining path for the specified key at the specified depth. +/// +/// Remaining path is computed by setting n most significant bits of the key to zeros, where n is +/// the specified depth. +fn get_remaining_path(key: RpoDigest, depth: u32) -> RpoDigest { + let mut key = Word::from(key); + key[3] = if depth == 64 { + ZERO + } else { + // remove `depth` bits from the most significant key element + ((key[3].as_int() << depth) >> depth).into() + }; + key.into() +} + +/// Returns index for the specified key inserted at the specified depth. +/// +/// The value for the key is computed by taking n most significant bits from the most significant +/// element of the key, where n is the specified depth. +fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { + let mse = Word::from(key)[3].as_int(); + let value = match depth { + 16 | 32 | 48 | 64 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32), + _ => unreachable!("invalid depth: {depth}"), + }; + NodeIndex::new_unchecked(depth, value) +} + +/// Returns tiered common prefix length between the most significant elements of the provided keys. +/// +/// Specifically: +/// - returns 64 if the most significant elements are equal. +/// - returns 48 if the common prefix is between 48 and 63 bits. +/// - returns 32 if the common prefix is between 32 and 47 bits. +/// - returns 16 if the common prefix is between 16 and 31 bits. +/// - returns 0 if the common prefix is fewer than 16 bits. 
+fn get_common_prefix_tier(key1: &RpoDigest, key2: &RpoDigest) -> u8 { + let e1 = Word::from(key1)[3].as_int(); + let e2 = Word::from(key2)[3].as_int(); + let ex = (e1 ^ e2).leading_zeros() as u8; + (ex / 16) * 16 +} + +/// Returns a tier for the specified index. +/// +/// The tiers are defined as follows: +/// - Tier 0: depth 0 through 16 (inclusive). +/// - Tier 1: depth 17 through 32 (inclusive). +/// - Tier 2: depth 33 through 48 (inclusive). +/// - Tier 3: depth 49 through 64 (inclusive). +const fn get_index_tier(index: &NodeIndex) -> usize { + debug_assert!(index.depth() <= TieredSmt::MAX_DEPTH, "invalid depth"); + match index.depth() { + 0..=16 => 0, + 17..=32 => 1, + 33..=48 => 2, + _ => 3, + } +} + +/// Returns true if the specified index is an index for an inner node (i.e., the depth is not 16, +/// 32, 48, or 64). +const fn is_inner_node(index: &NodeIndex) -> bool { + !matches!(index.depth(), 16 | 32 | 48 | 64) +} + +// BOTTOM LEAF +// ================================================================================================ + +/// Stores contents of the bottom leaf (i.e., leaf at depth = 64) in a [TieredSmt]. +/// +/// Bottom leaf can contain one or more key-value pairs all sharing the same 64-bit key prefix. +/// The values are sorted by key to make sure the structure of the leaf is independent of the +/// insertion order. This guarantees that a leaf with the same set of key-value pairs always has +/// the same hash value. +#[derive(Debug, Clone, PartialEq, Eq)] +struct BottomLeaf { + prefix: u64, + values: BTreeMap<[u64; 4], Word>, +} + +impl BottomLeaf { + /// Returns a new [BottomLeaf] with a single key-value pair added. + pub fn new(key: RpoDigest, value: Word) -> Self { + let prefix = Word::from(key)[3].as_int(); + let mut values = BTreeMap::new(); + let key = get_remaining_path(key, TieredSmt::MAX_DEPTH as u32); + values.insert(key.into(), value); + Self { prefix, values } + } + + /// Adds a new key-value pair to this leaf. 
+ pub fn add_value(&mut self, key: RpoDigest, value: Word) { + let key = get_remaining_path(key, TieredSmt::MAX_DEPTH as u32); + self.values.insert(key.into(), value); + } + + /// Computes a hash of this leaf over its (remaining key, value) pairs taken in key order. + pub fn hash(&self) -> RpoDigest { + let mut elements = Vec::with_capacity(self.values.len() * 8); // 8 elements per entry + for (key, val) in self.values.iter() { + key.iter().for_each(|&v| elements.push(Felt::new(v))); + elements.extend_from_slice(val); + } + // TODO: hash in domain + Rpo256::hash_elements(&elements) + } + + /// Returns contents of this leaf as a vector of (key, value) pairs. + /// + /// The keys are returned in their un-truncated form. + pub fn contents(&self) -> Vec<(RpoDigest, Word)> { + self.values + .iter() + .map(|(key, val)| { + let key = RpoDigest::from([ + Felt::new(key[0]), + Felt::new(key[1]), + Felt::new(key[2]), + Felt::new(self.prefix), + ]); + (key, *val) + }) + .collect() + } +} diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs new file mode 100644 index 00000000..f46a4aeb --- /dev/null +++ b/src/merkle/tiered_smt/tests.rs @@ -0,0 +1,441 @@ +use super::{ + super::{super::ONE, Felt, MerkleStore, WORD_SIZE, ZERO}, + get_remaining_path, EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, + Vec, Word, +}; + +#[test] +fn tsmt_insert_one() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let value = [ONE; WORD_SIZE]; + + // since the tree is empty, the first node will be inserted at depth 16 and the index will be + // 16 most significant bits of the key + let index = NodeIndex::make(16, raw >> 48); + let leaf_node = build_leaf_node(key, value, 16); + let tree_root = store.set_node(smt.root().into(), index, leaf_node.into()).unwrap().root; + + smt.insert(key, value); + + assert_eq!(smt.root(), tree_root.into()); + +
// make sure the value was inserted, and the node is at the expected index + assert_eq!(smt.get_value(key), value); + assert_eq!(smt.get_node(index).unwrap(), leaf_node); + + // make sure the paths we get from the store and the tree match + let expected_path = store.get_path(tree_root, index).unwrap(); + assert_eq!(smt.get_path(index).unwrap(), expected_path.path); + + // make sure inner nodes match + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + assert_eq!(actual_nodes.len(), expected_nodes.len()); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + // make sure leaves are returned correctly + let mut leaves = smt.upper_leaves(); + assert_eq!(leaves.next(), Some((leaf_node, key, value))); + assert_eq!(leaves.next(), None); +} + +#[test] +fn tsmt_insert_two_16() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // --- insert the first value --------------------------------------------- + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE; WORD_SIZE]; + smt.insert(key_a, val_a); + + // --- insert the second value -------------------------------------------- + // the key for this value has the same 16-bit prefix as the key for the first value, + // thus, on insertions, both values should be pushed to depth 32 tier + let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key_b, val_b); + + // --- build Merkle store with equivalent data ---------------------------- + let mut tree_root = get_init_root(); + let index_a = NodeIndex::make(32, raw_a >> 32); + let leaf_node_a = build_leaf_node(key_a, val_a, 32); + tree_root = store.set_node(tree_root, index_a, 
leaf_node_a.into()).unwrap().root; + + let index_b = NodeIndex::make(32, raw_b >> 32); + let leaf_node_b = build_leaf_node(key_b, val_b, 32); + tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root; + + // --- verify that data is consistent between store and tree -------------- + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key_a), val_a); + assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a); + let expected_path = store.get_path(tree_root, index_a).unwrap().path; + assert_eq!(smt.get_path(index_a).unwrap(), expected_path); + + assert_eq!(smt.get_value(key_b), val_b); + assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b); + let expected_path = store.get_path(tree_root, index_b).unwrap().path; + assert_eq!(smt.get_path(index_b).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + // make sure leaves are returned correctly + let mut leaves = smt.upper_leaves(); + assert_eq!(leaves.next(), Some((leaf_node_a, key_a, val_a))); + assert_eq!(leaves.next(), Some((leaf_node_b, key_b, val_b))); + assert_eq!(leaves.next(), None); +} + +#[test] +fn tsmt_insert_two_32() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // --- insert the first value --------------------------------------------- + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE; WORD_SIZE]; + smt.insert(key_a, val_a); + + // --- insert the second value -------------------------------------------- + // the key for this value has the same 32-bit 
prefix as the key for the first value, + // thus, on insertions, both values should be pushed to depth 48 tier + let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key_b, val_b); + + // --- build Merkle store with equivalent data ---------------------------- + let mut tree_root = get_init_root(); + let index_a = NodeIndex::make(48, raw_a >> 16); + let leaf_node_a = build_leaf_node(key_a, val_a, 48); + tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root; + + let index_b = NodeIndex::make(48, raw_b >> 16); + let leaf_node_b = build_leaf_node(key_b, val_b, 48); + tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root; + + // --- verify that data is consistent between store and tree -------------- + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key_a), val_a); + assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a); + let expected_path = store.get_path(tree_root, index_a).unwrap().path; + assert_eq!(smt.get_path(index_a).unwrap(), expected_path); + + assert_eq!(smt.get_value(key_b), val_b); + assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b); + let expected_path = store.get_path(tree_root, index_b).unwrap().path; + assert_eq!(smt.get_path(index_b).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); +} + +#[test] +fn tsmt_insert_three() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // --- insert the first value 
--------------------------------------------- + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE; WORD_SIZE]; + smt.insert(key_a, val_a); + + // --- insert the second value -------------------------------------------- + // the key for this value has the same 16-bit prefix as the key for the first value, + // thus, on insertions, both values should be pushed to depth 32 tier + let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let val_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key_b, val_b); + + // --- insert the third value --------------------------------------------- + // the key for this value has the same 16-bit prefix as the keys for the first two, + // values; thus, on insertions, it will be inserted into depth 32 tier, but will not + // affect locations of the other two values + let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let val_c = [Felt::new(3); WORD_SIZE]; + smt.insert(key_c, val_c); + + // --- build Merkle store with equivalent data ---------------------------- + let mut tree_root = get_init_root(); + let index_a = NodeIndex::make(32, raw_a >> 32); + let leaf_node_a = build_leaf_node(key_a, val_a, 32); + tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root; + + let index_b = NodeIndex::make(32, raw_b >> 32); + let leaf_node_b = build_leaf_node(key_b, val_b, 32); + tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root; + + let index_c = NodeIndex::make(32, raw_c >> 32); + let leaf_node_c = build_leaf_node(key_c, val_c, 32); + tree_root = store.set_node(tree_root, index_c, leaf_node_c.into()).unwrap().root; + + // --- verify that data is consistent 
between store and tree -------------- + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key_a), val_a); + assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a); + let expected_path = store.get_path(tree_root, index_a).unwrap().path; + assert_eq!(smt.get_path(index_a).unwrap(), expected_path); + + assert_eq!(smt.get_value(key_b), val_b); + assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b); + let expected_path = store.get_path(tree_root, index_b).unwrap().path; + assert_eq!(smt.get_path(index_b).unwrap(), expected_path); + + assert_eq!(smt.get_value(key_c), val_c); + assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c); + let expected_path = store.get_path(tree_root, index_c).unwrap().path; + assert_eq!(smt.get_path(index_c).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); +} + +#[test] +fn tsmt_update() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // --- insert a value into the tree --------------------------------------- + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let value_a = [ONE; WORD_SIZE]; + smt.insert(key, value_a); + + // --- update the value --------------------------------------------------- + let value_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key, value_b); + + // --- verify consistency ------------------------------------------------- + let mut tree_root = get_init_root(); + let index = NodeIndex::make(16, raw >> 48); + let leaf_node = build_leaf_node(key, value_b, 16); + tree_root = store.set_node(tree_root, index, 
leaf_node.into()).unwrap().root; + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key), value_b); + assert_eq!(smt.get_node(index).unwrap(), leaf_node); + let expected_path = store.get_path(tree_root, index).unwrap().path; + assert_eq!(smt.get_path(index).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); +} + +// BOTTOM TIER TESTS +// ================================================================================================ + +#[test] +fn tsmt_bottom_tier() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // common prefix for the keys + let prefix = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + // --- insert the first value --------------------------------------------- + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(prefix)]); + let val_a = [ONE; WORD_SIZE]; + smt.insert(key_a, val_a); + + // --- insert the second value -------------------------------------------- + // this key has the same 64-bit prefix and thus both values should end up in the same + // node at depth 64 + let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(prefix)]); + let val_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key_b, val_b); + + // --- build Merkle store with equivalent data ---------------------------- + let index = NodeIndex::make(64, prefix); + // to build bottom leaf we sort by key starting with the least significant element, thus + // key_b is smaller than key_a. 
+ let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a]); + let mut tree_root = get_init_root(); + tree_root = store.set_node(tree_root, index, leaf_node.into()).unwrap().root; + + // --- verify that data is consistent between store and tree -------------- + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key_a), val_a); + assert_eq!(smt.get_value(key_b), val_b); + + assert_eq!(smt.get_node(index).unwrap(), leaf_node); + let expected_path = store.get_path(tree_root, index).unwrap().path; + assert_eq!(smt.get_path(index).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + // make sure leaves are returned correctly + let mut leaves = smt.bottom_leaves(); + assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)]))); + assert_eq!(leaves.next(), None); +} + +#[test] +fn tsmt_bottom_tier_two() { + let mut smt = TieredSmt::default(); + let mut store = MerkleStore::default(); + + // --- insert the first value --------------------------------------------- + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let val_a = [ONE; WORD_SIZE]; + smt.insert(key_a, val_a); + + // --- insert the second value -------------------------------------------- + // the key for this value has the same 48-bit prefix as the key for the first value, + // thus, on insertions, both should end up in different nodes at depth 64 + let raw_b = 0b_10101010_10101010_00011111_11111111_10010110_10010011_01100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, 
Felt::new(raw_b)]); + let val_b = [Felt::new(2); WORD_SIZE]; + smt.insert(key_b, val_b); + + // --- build Merkle store with equivalent data ---------------------------- + let mut tree_root = get_init_root(); + let index_a = NodeIndex::make(64, raw_a); + let leaf_node_a = build_bottom_leaf_node(&[key_a], &[val_a]); + tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root; + + let index_b = NodeIndex::make(64, raw_b); + let leaf_node_b = build_bottom_leaf_node(&[key_b], &[val_b]); + tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root; + + // --- verify that data is consistent between store and tree -------------- + + assert_eq!(smt.root(), tree_root.into()); + + assert_eq!(smt.get_value(key_a), val_a); + assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a); + let expected_path = store.get_path(tree_root, index_a).unwrap().path; + assert_eq!(smt.get_path(index_a).unwrap(), expected_path); + + assert_eq!(smt.get_value(key_b), val_b); + assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b); + let expected_path = store.get_path(tree_root, index_b).unwrap().path; + assert_eq!(smt.get_path(index_b).unwrap(), expected_path); + + // make sure inner nodes match - the store contains more entries because it keeps track of + // all prior state - so, we don't check that the number of inner nodes is the same in both + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + // make sure leaves are returned correctly + let mut leaves = smt.bottom_leaves(); + assert_eq!(leaves.next(), Some((leaf_node_b, vec![(key_b, val_b)]))); + assert_eq!(leaves.next(), Some((leaf_node_a, vec![(key_a, val_a)]))); + assert_eq!(leaves.next(), None); +} + +// ERROR TESTS +// ================================================================================================ + +#[test] +fn tsmt_node_not_available() { + 
let mut smt = TieredSmt::default(); + + let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let value = [ONE; WORD_SIZE]; + + // build an index which is just below the inserted leaf node + let index = NodeIndex::make(17, raw >> 47); + + // since we haven't inserted the node yet, we should be able to get node and path to this index + assert!(smt.get_node(index).is_ok()); + assert!(smt.get_path(index).is_ok()); + + smt.insert(key, value); + + // but once the node is inserted, everything under it should be unavailable + assert!(smt.get_node(index).is_err()); + assert!(smt.get_path(index).is_err()); + + let index = NodeIndex::make(32, raw >> 32); + assert!(smt.get_node(index).is_err()); + assert!(smt.get_path(index).is_err()); + + let index = NodeIndex::make(34, raw >> 30); + assert!(smt.get_node(index).is_err()); + assert!(smt.get_path(index).is_err()); + + let index = NodeIndex::make(50, raw >> 14); + assert!(smt.get_node(index).is_err()); + assert!(smt.get_path(index).is_err()); + + let index = NodeIndex::make(64, raw); + assert!(smt.get_node(index).is_err()); + assert!(smt.get_path(index).is_err()); +} + +// HELPER FUNCTIONS +// ================================================================================================ + +fn get_init_root() -> Word { + EmptySubtreeRoots::empty_hashes(64)[0].into() +} + +fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest { + let remaining_path = get_remaining_path(key, depth as u32); + Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into()) +} + +fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest { + assert_eq!(keys.len(), values.len()); + + let mut elements = Vec::with_capacity(keys.len()); + for (key, val) in keys.iter().zip(values.iter()) { + let mut key = Word::from(key); + key[3] = ZERO; + elements.extend_from_slice(&key); + elements.extend_from_slice(val); + } + 
+ Rpo256::hash_elements(&elements) +} + +fn get_non_empty_nodes(store: &MerkleStore) -> Vec { + store + .inner_nodes() + .filter(|node| !is_empty_subtree(&RpoDigest::from(node.value))) + .collect::>() +} + +fn is_empty_subtree(node: &RpoDigest) -> bool { + EmptySubtreeRoots::empty_hashes(255).contains(&node) +}