From 340b55436e1ee4158f267f9137dc6f9fd62aac56 Mon Sep 17 00:00:00 2001 From: Victor Lopez Date: Thu, 8 Dec 2022 17:13:17 +0100 Subject: [PATCH] feat: add simple sparse merkle tree This commit moves the previous implementation of `SparseMerkleTree` from miden-core to this crate. It also include a couple of new tests, a bench suite, and a couple of minor fixes. The original API was preserved to maintain compatibility with `AdviceTape`. closes #21 --- Cargo.toml | 4 + benches/smt.rs | 89 +++++++++++ src/hash/mod.rs | 3 +- src/lib.rs | 6 +- src/merkle/merkle_tree.rs | 13 +- src/merkle/mod.rs | 36 ++++- src/merkle/simple_smt/mod.rs | 272 +++++++++++++++++++++++++++++++++ src/merkle/simple_smt/tests.rs | 263 +++++++++++++++++++++++++++++++ 8 files changed, 675 insertions(+), 11 deletions(-) create mode 100644 benches/smt.rs create mode 100644 src/merkle/simple_smt/mod.rs create mode 100644 src/merkle/simple_smt/tests.rs diff --git a/Cargo.toml b/Cargo.toml index dc7367bd..bcb9dbd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,10 @@ edition = "2021" name = "hash" harness = false +[[bench]] +name = "smt" +harness = false + [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] diff --git a/benches/smt.rs b/benches/smt.rs new file mode 100644 index 00000000..e23aa73e --- /dev/null +++ b/benches/smt.rs @@ -0,0 +1,89 @@ +use core::mem::swap; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use miden_crypto::{merkle::SimpleSmt, Felt, Word}; +use rand_utils::prng_array; + +fn smt_rpo(c: &mut Criterion) { + // parameters + + const DEPTH: u32 = 16; + const LEAVES: u64 = ((1 << DEPTH) - 1) as u64; + const KEY: u64 = (LEAVES) >> 2; + + let mut seed = [0u8; 32]; + + // setup trees + + let mut trees: Vec<_> = [1, LEAVES / 2, LEAVES] + .into_iter() + .scan([0u8; 32], |seed, count| { + let tree = create_simple_smt::(count, seed); + Some(tree) + }) + .collect(); + + let leaf = generate_word(&mut seed); + + // benchmarks + + let mut insert = c.benchmark_group(format!("smt update_leaf(depth{DEPTH})")); + + for tree in trees.iter_mut() { + let count = tree.leaves_count() as u64; + insert.bench_with_input( + format!("simple smt({count})"), + &(KEY % count.max(1), leaf), + |b, (key, leaf)| { + b.iter(|| { + tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap(); + }); + }, + ); + } + + insert.finish(); + + let mut path = c.benchmark_group(format!("smt get_leaf_path(depth{DEPTH})")); + + for tree in trees.iter_mut() { + let count = tree.leaves_count() as u64; + path.bench_with_input( + format!("simple smt({count})"), + &(KEY % count.max(1)), + |b, key| { + b.iter(|| { + tree.get_leaf_path(black_box(*key)).unwrap(); + }); + }, + ); + } + + path.finish(); +} + +criterion_group!(smt_group, smt_rpo); +criterion_main!(smt_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_word(seed: &mut [u8; 32]) -> Word { + swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [ + Felt::new(nums[0]), + Felt::new(nums[1]), + Felt::new(nums[2]), + Felt::new(nums[3]), + ] +} + +fn create_simple_smt(count: u64, seed: &mut [u8; 32]) -> SimpleSmt { + let entries: Vec<_> = (0..count) + .map(|i| { + let word = generate_word(seed); + (i, word) + }) + .collect(); + SimpleSmt::new(entries, DEPTH).unwrap() +} diff --git a/src/hash/mod.rs b/src/hash/mod.rs index 9508754c..ee87895b 100644 --- a/src/hash/mod.rs +++ b/src/hash/mod.rs @@ -1,5 +1,4 @@ -use super::{Felt, FieldElement, StarkField, ONE, ZERO}; -use winter_crypto::{Digest, ElementHasher, Hasher}; +use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO}; pub mod blake; pub mod rpo; diff --git a/src/lib.rs b/src/lib.rs index ff2eb43f..9782af1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub mod merkle; // RE-EXPORTS // ================================================================================================ +pub use winter_crypto::{Digest, ElementHasher, Hasher}; pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField}; pub mod utils { @@ -23,11 +24,14 @@ pub mod utils { // ================================================================================================ /// A group of four field elements in the Miden base field. -pub type Word = [Felt; 4]; +pub type Word = [Felt; WORD_SIZE]; // CONSTANTS // ================================================================================================ +/// Number of field elements in a word. +pub const WORD_SIZE: usize = 4; + /// Field element representing ZERO in the Miden base filed. pub const ZERO: Felt = Felt::ZERO; diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index 8763d96b..c0c381cd 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -1,4 +1,4 @@ -use super::{Digest, Felt, MerkleError, Rpo256, Vec, Word}; +use super::{Felt, MerkleError, Rpo256, RpoDigest, Vec, Word}; use crate::{utils::uninit_vector, FieldElement}; use core::slice; use winter_math::log2; @@ -22,7 +22,7 @@ impl MerkleTree { pub fn new(leaves: Vec) -> Result { let n = leaves.len(); if n <= 1 { - return Err(MerkleError::DepthTooSmall); + return Err(MerkleError::DepthTooSmall(n as u32)); } else if !n.is_power_of_two() { return Err(MerkleError::NumLeavesNotPowerOfTwo(n)); } @@ -35,7 +35,8 @@ impl MerkleTree { nodes[n..].copy_from_slice(&leaves); // re-interpret nodes as an array of two nodes fused together - let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [Digest; 2], n) }; + let two_nodes = + unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) }; // calculate all internal tree nodes for i in (1..n).rev() { @@ -68,7 +69,7 @@ impl MerkleTree { /// * The specified index not valid for the specified depth. pub fn get_node(&self, depth: u32, index: u64) -> Result { if depth == 0 { - return Err(MerkleError::DepthTooSmall); + return Err(MerkleError::DepthTooSmall(depth)); } else if depth > self.depth() { return Err(MerkleError::DepthTooBig(depth)); } @@ -89,7 +90,7 @@ impl MerkleTree { /// * The specified index not valid for the specified depth. pub fn get_path(&self, depth: u32, index: u64) -> Result, MerkleError> { if depth == 0 { - return Err(MerkleError::DepthTooSmall); + return Err(MerkleError::DepthTooSmall(depth)); } else if depth > self.depth() { return Err(MerkleError::DepthTooBig(depth)); } @@ -123,7 +124,7 @@ impl MerkleTree { let n = self.nodes.len() / 2; let two_nodes = - unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [Digest; 2], n) }; + unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) }; for _ in 0..depth { index /= 2; diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 1b137b9b..87cd80f5 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,8 +1,9 @@ use super::{ - hash::rpo::{Rpo256, RpoDigest as Digest}, + hash::rpo::{Rpo256, RpoDigest}, utils::collections::{BTreeMap, Vec}, Felt, Word, ZERO, }; +use core::fmt; mod merkle_tree; pub use merkle_tree::MerkleTree; @@ -10,20 +11,51 @@ pub use merkle_tree::MerkleTree; mod merkle_path_set; pub use merkle_path_set::MerklePathSet; +mod simple_smt; +pub use simple_smt::SimpleSmt; + // ERRORS // ================================================================================================ #[derive(Clone, Debug)] pub enum MerkleError { - DepthTooSmall, + DepthTooSmall(u32), DepthTooBig(u32), NumLeavesNotPowerOfTwo(usize), InvalidIndex(u32, u64), InvalidDepth(u32, u32), InvalidPath(Vec), + InvalidEntriesCount(usize, usize), NodeNotInSet(u64), } +impl fmt::Display for MerkleError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use MerkleError::*; + match self { + DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"), + DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"), + NumLeavesNotPowerOfTwo(leaves) => { + write!(f, "the leaves count {leaves} is not a power of 2") + } + InvalidIndex(depth, index) => write!( + f, + "the leaf index {index} is not valid for the depth {depth}" + ), + InvalidDepth(expected, provided) => write!( + f, + "the provided depth {provided} is not valid for {expected}" + ), + InvalidPath(_path) => write!(f, "the provided path is not valid"), + InvalidEntriesCount(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"), + NodeNotInSet(index) => write!(f, "the node indexed by {index} is not in the set"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for MerkleError {} + // HELPER FUNCTIONS // ================================================================================================ diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs new file mode 100644 index 00000000..3994b6b9 --- /dev/null +++ b/src/merkle/simple_smt/mod.rs @@ -0,0 +1,272 @@ +use super::{BTreeMap, MerkleError, Rpo256, RpoDigest, Vec, Word}; + +#[cfg(test)] +mod tests; + +// SPARSE MERKLE TREE +// ================================================================================================ + +/// A sparse Merkle tree with 63-bit keys and 4-element leaf values, without compaction. +/// Manipulation and retrieval of leaves and internal nodes is provided by its internal `Store`. +/// The root of the tree is recomputed on each new leaf update. +#[derive(Clone, Debug)] +pub struct SimpleSmt { + root: Word, + depth: u32, + store: Store, +} + +impl SimpleSmt { + // CONSTANTS + // -------------------------------------------------------------------------------------------- + + /// Minimum supported depth. + pub const MIN_DEPTH: u32 = 1; + + /// Maximum supported depth. + pub const MAX_DEPTH: u32 = 63; + + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Creates a new simple SMT. + /// + /// The provided entries will be tuples of the leaves and their corresponding keys. + /// + /// # Errors + /// + /// The function will fail if the provided entries count exceed the maximum tree capacity, that + /// is `2^{depth}`. + pub fn new(entries: R, depth: u32) -> Result + where + R: IntoIterator, + I: Iterator + ExactSizeIterator, + { + let mut entries = entries.into_iter(); + + // validate the range of the depth. + let max = 1 << depth; + if depth < Self::MIN_DEPTH { + return Err(MerkleError::DepthTooSmall(depth)); + } else if Self::MAX_DEPTH < depth { + return Err(MerkleError::DepthTooBig(depth)); + } else if entries.len() > max { + return Err(MerkleError::InvalidEntriesCount(max, entries.len())); + } + + let (store, root) = Store::new(depth); + let mut tree = Self { root, depth, store }; + entries.try_for_each(|(key, leaf)| tree.insert_leaf(key, leaf))?; + + Ok(tree) + } + + /// Returns the root of this Merkle tree. + pub const fn root(&self) -> Word { + self.root + } + + /// Returns the depth of this Merkle tree. + pub const fn depth(&self) -> u32 { + self.depth + } + + /// Returns the set count of the keys of the leaves. + pub fn leaves_count(&self) -> usize { + self.store.leaves_count() + } + + /// Returns a node at the specified key + /// + /// # Errors + /// Returns an error if: + /// * The specified depth is greater than the depth of the tree. + /// * The specified key does not exist + pub fn get_node(&self, depth: u32, key: u64) -> Result { + if depth == 0 { + Err(MerkleError::DepthTooSmall(depth)) + } else if depth > self.depth() { + Err(MerkleError::DepthTooBig(depth)) + } else if depth == self.depth() { + self.store.get_leaf_node(key) + } else { + let branch_node = self.store.get_branch_node(key, depth)?; + Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into()) + } + } + + /// Returns a Merkle path to the node at the specified key. The node itself is + /// not included in the path. + /// + /// # Errors + /// Returns an error if: + /// * The specified key does not exist as a branch or leaf node + /// * The specified depth is greater than the depth of the tree. + pub fn get_path(&self, depth: u32, key: u64) -> Result, MerkleError> { + if depth == 0 { + return Err(MerkleError::DepthTooSmall(depth)); + } else if depth > self.depth() { + return Err(MerkleError::DepthTooBig(depth)); + } else if depth == self.depth() && !self.store.check_leaf_node_exists(key) { + return Err(MerkleError::InvalidIndex(self.depth, key)); + } + + let mut path = Vec::with_capacity(depth as usize); + let mut curr_key = key; + for n in (0..depth).rev() { + let parent_key = curr_key >> 1; + let parent_node = self.store.get_branch_node(parent_key, n)?; + let sibling_node = if curr_key & 1 == 1 { + parent_node.left + } else { + parent_node.right + }; + path.push(sibling_node.into()); + curr_key >>= 1; + } + Ok(path) + } + + /// Returns a Merkle path to the node at the specified key. The node itself is + /// not included in the path. + /// + /// # Errors + /// Returns an error if: + /// * The specified key does not exist as a leaf node + pub fn get_leaf_path(&self, key: u64) -> Result, MerkleError> { + if !self.store.check_leaf_node_exists(key) { + return Err(MerkleError::InvalidIndex(self.depth, key)); + } + self.get_path(self.depth, key) + } + + /// Replaces the leaf located at the specified key, and recomputes hashes by walking up the tree + /// + /// # Errors + /// Returns an error if the specified key is not a valid leaf index for this tree. + pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { + if !self.store.check_leaf_node_exists(key) { + return Err(MerkleError::InvalidIndex(self.depth, key)); + } + self.insert_leaf(key, value)?; + + Ok(()) + } + + /// Inserts a leaf located at the specified key, and recomputes hashes by walking up the tree + pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { + self.store.insert_leaf_node(key, value); + + let depth = self.depth(); + let mut curr_key = key; + let mut curr_node: RpoDigest = value.into(); + for n in (0..depth).rev() { + let parent_key = curr_key >> 1; + let parent_node = self + .store + .get_branch_node(parent_key, n) + .unwrap_or_else(|_| self.store.get_empty_node((n + 1) as usize)); + let (left, right) = if curr_key & 1 == 1 { + (parent_node.left, curr_node) + } else { + (curr_node, parent_node.right) + }; + + self.store.insert_branch_node(parent_key, n, left, right); + curr_key = parent_key; + curr_node = Rpo256::merge(&[left, right]); + } + self.root = curr_node.into(); + + Ok(()) + } +} + +// STORE +// ================================================================================================ + +/// A data store for sparse Merkle tree key-value pairs. +/// Leaves and branch nodes are stored separately in B-tree maps, indexed by key and (key, depth) +/// respectively. Hashes for blank subtrees at each layer are stored in `empty_hashes`, beginning +/// with the root hash of an empty tree, and ending with the zero value of a leaf node. +#[derive(Clone, Debug)] +struct Store { + branches: BTreeMap<(u64, u32), BranchNode>, + leaves: BTreeMap, + empty_hashes: Vec, + depth: u32, +} + +#[derive(Clone, Debug, Default)] +struct BranchNode { + left: RpoDigest, + right: RpoDigest, +} + +impl Store { + fn new(depth: u32) -> (Self, Word) { + let branches = BTreeMap::new(); + let leaves = BTreeMap::new(); + + // Construct empty node digests for each layer of the tree + let empty_hashes: Vec = (0..depth + 1) + .scan(Word::default().into(), |state, _| { + let value = *state; + *state = Rpo256::merge(&[value, value]); + Some(value) + }) + .collect::>() + .into_iter() + .rev() + .collect(); + + let root = empty_hashes[0].into(); + let store = Self { + branches, + leaves, + empty_hashes, + depth, + }; + + (store, root) + } + + fn get_empty_node(&self, depth: usize) -> BranchNode { + let digest = self.empty_hashes[depth]; + BranchNode { + left: digest, + right: digest, + } + } + + fn check_leaf_node_exists(&self, key: u64) -> bool { + self.leaves.contains_key(&key) + } + + fn get_leaf_node(&self, key: u64) -> Result { + self.leaves + .get(&key) + .cloned() + .ok_or(MerkleError::InvalidIndex(self.depth, key)) + } + + fn insert_leaf_node(&mut self, key: u64, node: Word) { + self.leaves.insert(key, node); + } + + fn get_branch_node(&self, key: u64, depth: u32) -> Result { + self.branches + .get(&(key, depth)) + .cloned() + .ok_or(MerkleError::InvalidIndex(depth, key)) + } + + fn insert_branch_node(&mut self, key: u64, depth: u32, left: RpoDigest, right: RpoDigest) { + let node = BranchNode { left, right }; + self.branches.insert((key, depth), node); + } + + fn leaves_count(&self) -> usize { + self.leaves.len() + } +} diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs new file mode 100644 index 00000000..7042d1b6 --- /dev/null +++ b/src/merkle/simple_smt/tests.rs @@ -0,0 +1,263 @@ +use super::{ + super::{MerkleTree, RpoDigest, SimpleSmt}, + Rpo256, Vec, Word, +}; +use crate::{Felt, FieldElement}; +use core::iter; +use proptest::prelude::*; +use rand_utils::prng_array; + +const KEYS4: [u64; 4] = [0, 1, 2, 3]; +const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + +const VALUES4: [Word; 4] = [ + int_to_node(1), + int_to_node(2), + int_to_node(3), + int_to_node(4), +]; + +const VALUES8: [Word; 8] = [ + int_to_node(1), + int_to_node(2), + int_to_node(3), + int_to_node(4), + int_to_node(5), + int_to_node(6), + int_to_node(7), + int_to_node(8), +]; + +const ZERO_VALUES8: [Word; 8] = [int_to_node(0); 8]; + +#[test] +fn build_empty_tree() { + let smt = SimpleSmt::new(iter::empty(), 3).unwrap(); + let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap(); + assert_eq!(mt.root(), smt.root()); +} + +#[test] +fn empty_digests_are_consistent() { + let depth = 5; + let root = SimpleSmt::new(iter::empty(), depth).unwrap().root(); + let computed: [RpoDigest; 2] = (0..depth).fold([Default::default(); 2], |state, _| { + let digest = Rpo256::merge(&state); + [digest; 2] + }); + + assert_eq!(Word::from(computed[0]), root); +} + +#[test] +fn build_sparse_tree() { + let mut smt = SimpleSmt::new(iter::empty(), 3).unwrap(); + let mut values = ZERO_VALUES8.to_vec(); + + // insert single value + let key = 6; + let new_node = int_to_node(7); + values[key as usize] = new_node; + smt.insert_leaf(key, new_node) + .expect("Failed to insert leaf"); + let mt2 = MerkleTree::new(values.clone()).unwrap(); + assert_eq!(mt2.root(), smt.root()); + assert_eq!(mt2.get_path(3, 6).unwrap(), smt.get_path(3, 6).unwrap()); + + // insert second value at distinct leaf branch + let key = 2; + let new_node = int_to_node(3); + values[key as usize] = new_node; + smt.insert_leaf(key, new_node) + .expect("Failed to insert leaf"); + let mt3 = MerkleTree::new(values).unwrap(); + assert_eq!(mt3.root(), smt.root()); + assert_eq!(mt3.get_path(3, 2).unwrap(), smt.get_path(3, 2).unwrap()); +} + +#[test] +fn build_full_tree() { + let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap(); + + let (root, node2, node3) = compute_internal_nodes(); + assert_eq!(root, tree.root()); + assert_eq!(node2, tree.get_node(1, 0).unwrap()); + assert_eq!(node3, tree.get_node(1, 1).unwrap()); +} + +#[test] +fn get_values() { + let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap(); + + // check depth 2 + assert_eq!(VALUES4[0], tree.get_node(2, 0).unwrap()); + assert_eq!(VALUES4[1], tree.get_node(2, 1).unwrap()); + assert_eq!(VALUES4[2], tree.get_node(2, 2).unwrap()); + assert_eq!(VALUES4[3], tree.get_node(2, 3).unwrap()); +} + +#[test] +fn get_path() { + let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap(); + + let (_, node2, node3) = compute_internal_nodes(); + + // check depth 2 + assert_eq!(vec![VALUES4[1], node3], tree.get_path(2, 0).unwrap()); + assert_eq!(vec![VALUES4[0], node3], tree.get_path(2, 1).unwrap()); + assert_eq!(vec![VALUES4[3], node2], tree.get_path(2, 2).unwrap()); + assert_eq!(vec![VALUES4[2], node2], tree.get_path(2, 3).unwrap()); + + // check depth 1 + assert_eq!(vec![node3], tree.get_path(1, 0).unwrap()); + assert_eq!(vec![node2], tree.get_path(1, 1).unwrap()); +} + +#[test] +fn update_leaf() { + let mut tree = SimpleSmt::new(KEYS8.into_iter().zip(VALUES8.into_iter()), 3).unwrap(); + + // update one value + let key = 3; + let new_node = int_to_node(9); + let mut expected_values = VALUES8.to_vec(); + expected_values[key] = new_node; + let expected_tree = SimpleSmt::new( + KEYS8.into_iter().zip(expected_values.clone().into_iter()), + 3, + ) + .unwrap(); + + tree.update_leaf(key as u64, new_node).unwrap(); + assert_eq!(expected_tree.root, tree.root); + + // update another value + let key = 6; + let new_node = int_to_node(10); + expected_values[key] = new_node; + let expected_tree = + SimpleSmt::new(KEYS8.into_iter().zip(expected_values.into_iter()), 3).unwrap(); + + tree.update_leaf(key as u64, new_node).unwrap(); + assert_eq!(expected_tree.root, tree.root); +} + +#[test] +fn small_tree_opening_is_consistent() { + // ____k____ + // / \ + // _i_ _j_ + // / \ / \ + // e f g h + // / \ / \ / \ / \ + // a b 0 0 c 0 0 d + + let z = Word::from(RpoDigest::default()); + + let a = Word::from(Rpo256::merge(&[z.into(); 2])); + let b = Word::from(Rpo256::merge(&[a.into(); 2])); + let c = Word::from(Rpo256::merge(&[b.into(); 2])); + let d = Word::from(Rpo256::merge(&[c.into(); 2])); + + let e = Word::from(Rpo256::merge(&[a.into(), b.into()])); + let f = Word::from(Rpo256::merge(&[z.into(), z.into()])); + let g = Word::from(Rpo256::merge(&[c.into(), z.into()])); + let h = Word::from(Rpo256::merge(&[z.into(), d.into()])); + + let i = Word::from(Rpo256::merge(&[e.into(), f.into()])); + let j = Word::from(Rpo256::merge(&[g.into(), h.into()])); + + let k = Word::from(Rpo256::merge(&[i.into(), j.into()])); + + let depth = 3; + let entries = vec![(0, a), (1, b), (4, c), (7, d)]; + let tree = SimpleSmt::new(entries, depth).unwrap(); + + assert_eq!(tree.root(), Word::from(k)); + + let cases: Vec<(u32, u64, Vec)> = vec![ + (3, 0, vec![b, f, j]), + (3, 1, vec![a, f, j]), + (3, 4, vec![z, h, i]), + (3, 7, vec![z, g, i]), + (2, 0, vec![f, j]), + (2, 1, vec![e, j]), + (2, 2, vec![h, i]), + (2, 3, vec![g, i]), + (1, 0, vec![j]), + (1, 1, vec![i]), + ]; + + for (depth, key, path) in cases { + let opening = tree.get_path(depth, key).unwrap(); + + assert_eq!(path, opening); + } +} + +proptest! { + #[test] + fn arbitrary_openings_single_leaf( + depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH, + key in prop::num::u64::ANY, + leaf in prop::num::u64::ANY, + ) { + let mut tree = SimpleSmt::new(iter::empty(), depth).unwrap(); + + let key = key % (1 << depth as u64); + let leaf = int_to_node(leaf); + + tree.insert_leaf(key, leaf.into()).unwrap(); + tree.get_leaf_path(key).unwrap(); + + // traverse to root, fetching all paths + for d in 1..depth { + let k = key >> (depth - d); + tree.get_path(d, k).unwrap(); + } + } + + #[test] + fn arbitrary_openings_multiple_leaves( + depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH, + count in 2u8..10u8, + ref seed in any::<[u8; 32]>() + ) { + let mut tree = SimpleSmt::new(iter::empty(), depth).unwrap(); + let mut seed = *seed; + let leaves = (1 << depth) - 1; + + for _ in 0..count { + seed = prng_array(seed); + + let mut key = [0u8; 8]; + let mut leaf = [0u8; 8]; + + key.copy_from_slice(&seed[..8]); + leaf.copy_from_slice(&seed[8..16]); + + let key = u64::from_le_bytes(key); + let key = key % leaves; + let leaf = u64::from_le_bytes(leaf); + let leaf = int_to_node(leaf); + + tree.insert_leaf(key, leaf).unwrap(); + tree.get_leaf_path(key).unwrap(); + } + } +} + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn compute_internal_nodes() -> (Word, Word, Word) { + let node2 = Rpo256::hash_elements(&[VALUES4[0], VALUES4[1]].concat()); + let node3 = Rpo256::hash_elements(&[VALUES4[2], VALUES4[3]].concat()); + let root = Rpo256::merge(&[node2, node3]); + + (root.into(), node2.into(), node3.into()) +} + +const fn int_to_node(value: u64) -> Word { + [Felt::new(value), Felt::ZERO, Felt::ZERO, Felt::ZERO] +}