Skip to content

Commit

Permalink
feat: add simple sparse merkle tree
Browse files Browse the repository at this point in the history
This commit moves the previous implementation of `SparseMerkleTree` from
miden-core to this crate.

It also include a couple of new tests, a bench suite, and a couple of
minor fixes. The original API was preserved to maintain compatibility
with `AdviceTape`.

closes #21
  • Loading branch information
vlopes11 committed Dec 13, 2022
1 parent a41329f commit 340b554
Show file tree
Hide file tree
Showing 8 changed files with 675 additions and 11 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ edition = "2021"
name = "hash"
harness = false

[[bench]]
name = "smt"
harness = false

[features]
default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"]
std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
Expand Down
89 changes: 89 additions & 0 deletions benches/smt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use core::mem::swap;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use miden_crypto::{merkle::SimpleSmt, Felt, Word};
use rand_utils::prng_array;

fn smt_rpo(c: &mut Criterion) {
// parameters

const DEPTH: u32 = 16;
const LEAVES: u64 = ((1 << DEPTH) - 1) as u64;
const KEY: u64 = (LEAVES) >> 2;

let mut seed = [0u8; 32];

// setup trees

let mut trees: Vec<_> = [1, LEAVES / 2, LEAVES]
.into_iter()
.scan([0u8; 32], |seed, count| {
let tree = create_simple_smt::<DEPTH>(count, seed);
Some(tree)
})
.collect();

let leaf = generate_word(&mut seed);

// benchmarks

let mut insert = c.benchmark_group(format!("smt update_leaf(depth{DEPTH})"));

for tree in trees.iter_mut() {
let count = tree.leaves_count() as u64;
insert.bench_with_input(
format!("simple smt({count})"),
&(KEY % count.max(1), leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap();
});
},
);
}

insert.finish();

let mut path = c.benchmark_group(format!("smt get_leaf_path(depth{DEPTH})"));

for tree in trees.iter_mut() {
let count = tree.leaves_count() as u64;
path.bench_with_input(
format!("simple smt({count})"),
&(KEY % count.max(1)),
|b, key| {
b.iter(|| {
tree.get_leaf_path(black_box(*key)).unwrap();
});
},
);
}

path.finish();
}

criterion_group!(smt_group, smt_rpo);
criterion_main!(smt_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_word(seed: &mut [u8; 32]) -> Word {
swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[
Felt::new(nums[0]),
Felt::new(nums[1]),
Felt::new(nums[2]),
Felt::new(nums[3]),
]
}

fn create_simple_smt<const DEPTH: u32>(count: u64, seed: &mut [u8; 32]) -> SimpleSmt {
let entries: Vec<_> = (0..count)
.map(|i| {
let word = generate_word(seed);
(i, word)
})
.collect();
SimpleSmt::new(entries, DEPTH).unwrap()
}
3 changes: 1 addition & 2 deletions src/hash/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::{Felt, FieldElement, StarkField, ONE, ZERO};
use winter_crypto::{Digest, ElementHasher, Hasher};
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};

pub mod blake;
pub mod rpo;
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod merkle;
// RE-EXPORTS
// ================================================================================================

pub use winter_crypto::{Digest, ElementHasher, Hasher};
pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};

pub mod utils {
Expand All @@ -23,11 +24,14 @@ pub mod utils {
// ================================================================================================

/// A group of four field elements in the Miden base field.
pub type Word = [Felt; 4];
pub type Word = [Felt; WORD_SIZE];

// CONSTANTS
// ================================================================================================

/// Number of field elements in a word.
pub const WORD_SIZE: usize = 4;

/// Field element representing ZERO in the Miden base filed.
pub const ZERO: Felt = Felt::ZERO;

Expand Down
13 changes: 7 additions & 6 deletions src/merkle/merkle_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Digest, Felt, MerkleError, Rpo256, Vec, Word};
use super::{Felt, MerkleError, Rpo256, RpoDigest, Vec, Word};
use crate::{utils::uninit_vector, FieldElement};
use core::slice;
use winter_math::log2;
Expand All @@ -22,7 +22,7 @@ impl MerkleTree {
pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
let n = leaves.len();
if n <= 1 {
return Err(MerkleError::DepthTooSmall);
return Err(MerkleError::DepthTooSmall(n as u32));
} else if !n.is_power_of_two() {
return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
}
Expand All @@ -35,7 +35,8 @@ impl MerkleTree {
nodes[n..].copy_from_slice(&leaves);

// re-interpret nodes as an array of two nodes fused together
let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [Digest; 2], n) };
let two_nodes =
unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) };

// calculate all internal tree nodes
for i in (1..n).rev() {
Expand Down Expand Up @@ -68,7 +69,7 @@ impl MerkleTree {
/// * The specified index not valid for the specified depth.
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
if depth == 0 {
return Err(MerkleError::DepthTooSmall);
return Err(MerkleError::DepthTooSmall(depth));
} else if depth > self.depth() {
return Err(MerkleError::DepthTooBig(depth));
}
Expand All @@ -89,7 +90,7 @@ impl MerkleTree {
/// * The specified index not valid for the specified depth.
pub fn get_path(&self, depth: u32, index: u64) -> Result<Vec<Word>, MerkleError> {
if depth == 0 {
return Err(MerkleError::DepthTooSmall);
return Err(MerkleError::DepthTooSmall(depth));
} else if depth > self.depth() {
return Err(MerkleError::DepthTooBig(depth));
}
Expand Down Expand Up @@ -123,7 +124,7 @@ impl MerkleTree {

let n = self.nodes.len() / 2;
let two_nodes =
unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [Digest; 2], n) };
unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) };

for _ in 0..depth {
index /= 2;
Expand Down
36 changes: 34 additions & 2 deletions src/merkle/mod.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,61 @@
use super::{
hash::rpo::{Rpo256, RpoDigest as Digest},
hash::rpo::{Rpo256, RpoDigest},
utils::collections::{BTreeMap, Vec},
Felt, Word, ZERO,
};
use core::fmt;

mod merkle_tree;
pub use merkle_tree::MerkleTree;

mod merkle_path_set;
pub use merkle_path_set::MerklePathSet;

mod simple_smt;
pub use simple_smt::SimpleSmt;

// ERRORS
// ================================================================================================

#[derive(Clone, Debug)]
pub enum MerkleError {
DepthTooSmall,
DepthTooSmall(u32),
DepthTooBig(u32),
NumLeavesNotPowerOfTwo(usize),
InvalidIndex(u32, u64),
InvalidDepth(u32, u32),
InvalidPath(Vec<Word>),
InvalidEntriesCount(usize, usize),
NodeNotInSet(u64),
}

impl fmt::Display for MerkleError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use MerkleError::*;
match self {
DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
NumLeavesNotPowerOfTwo(leaves) => {
write!(f, "the leaves count {leaves} is not a power of 2")
}
InvalidIndex(depth, index) => write!(
f,
"the leaf index {index} is not valid for the depth {depth}"
),
InvalidDepth(expected, provided) => write!(
f,
"the provided depth {provided} is not valid for {expected}"
),
InvalidPath(_path) => write!(f, "the provided path is not valid"),
InvalidEntriesCount(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
NodeNotInSet(index) => write!(f, "the node indexed by {index} is not in the set"),
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for MerkleError {}

// HELPER FUNCTIONS
// ================================================================================================

Expand Down
Loading

0 comments on commit 340b554

Please sign in to comment.