From 4ebf95fc8908877d9fe0bc78f382a6eff63b3d1f Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Fri, 22 Mar 2024 11:39:03 -0300 Subject: [PATCH] refactor calculate_hashes --- src/accumulator/proof.rs | 121 +++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 70 deletions(-) diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index 19d40db..c0d8836 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -358,6 +358,9 @@ impl Proof { let mut calculated_root_hashes = Vec::::with_capacity(util::num_roots(num_leaves)); + // the positions that should be passed as a proof + let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); + // As we calculate nodes upwards, it accumulates here let mut nodes: Vec<_> = self .targets @@ -366,80 +369,59 @@ impl Proof { .zip(del_hashes.to_owned()) .collect(); + // add the proof positions to the nodes + nodes.extend( + proof_positions + .iter() + .copied() + .zip(self.hashes.iter().copied()), + ); + // Nodes must be sorted for finding siblings during hashing nodes.sort(); + let mut i = 0; + while i < nodes.len() { + let (pos1, hash1) = nodes[i]; + let next_to_prove = util::parent(pos1, total_rows); + + // If the current position is a root, we add that to our result and don't go any further + if util::is_root_position(pos1, num_leaves, total_rows) { + calculated_root_hashes.push(hash1); + i += 1; + continue; + } - // An iterator over proof hashes - let mut hashes_iter = self.hashes.iter(); - - for row in 0..=total_rows { - // An iterator that only contains nodes of the current row - // We can't use a iterator over nodes, because we also need no mutable borrow it, - // clippy will suggest to use nodes.iter().cloned, but this will cause row_nodes to - // immutably borrow nodes. - #[allow(clippy::unnecessary_to_owned)] - let mut row_nodes = nodes - .to_owned() - .into_iter() - .filter(|x| util::detect_row(x.0, total_rows) == row) - .peekable(); - - while let Some((pos, hash)) = row_nodes.next() { - let next_to_prove = util::parent(pos, total_rows); - // If the current position is a root, we add that to our result and don't go any further - if util::is_root_position(pos, num_leaves, total_rows) { - calculated_root_hashes.push(hash); - continue; - } - - if let Some((next_pos, next_hash)) = row_nodes.peek() { - // Is the next node our sibling? If so, we should be hashed together - if util::is_right_sibling(pos, *next_pos) { - // There are three possible cases: the current hash is null, - // and the sibling is present, we push the sibling to targets. - // If The sibling is null, we push the current node. - // If none of them is null, we compute the parent hash of both siblings - // and push this to the next target. - if hash.is_empty() { - Proof::sorted_push(&mut nodes, (next_to_prove, *next_hash)); - } else if next_hash.is_empty() { - Proof::sorted_push(&mut nodes, (next_to_prove, hash)); - } else { - let hash = NodeHash::parent_hash(&hash, next_hash); - - Proof::sorted_push(&mut nodes, (next_to_prove, hash)); - } - - // Since we consumed 2 elements from nodes, skip one more here - // We need make this explicitly because peek, by definition - // does not advance the iterator. - row_nodes.next(); - - continue; - } - } + let Some((pos2, hash2)) = nodes.get(i + 1) else { + return Err(format!( + "Proof is too short. Expected at least {} elements, got {}", + i + 1, + nodes.len() + )); + }; - // If the next node is not my sibling, the hash must be passed inside the proof - if let Some(next_proof_hash) = hashes_iter.next() { - if !hash.is_empty() { - let hash = if util::is_left_niece(pos) { - NodeHash::parent_hash(&hash, next_proof_hash) - } else { - NodeHash::parent_hash(next_proof_hash, &hash) - }; - - Proof::sorted_push(&mut nodes, (next_to_prove, hash)); - continue; - } else { - // If none of the above, push a null hash upwards - Proof::sorted_push(&mut nodes, (next_to_prove, *next_proof_hash)); - } - } else { - return Err(String::from("Proof too short")); - } + if pos1 != util::left_sibling(*pos2) { + return Err(format!( + "Invalid proof. Expected left sibling of {} to be {}, got {}", + pos2, + util::left_sibling(*pos2), + pos1 + )); } + + let parent_hash = match (hash1.is_empty(), hash2.is_empty()) { + (true, true) => NodeHash::empty(), + (true, false) => *hash2, + (false, true) => hash1, + (false, false) => NodeHash::parent_hash(&hash1, hash2), + }; + + Self::sorted_push(&mut nodes, (next_to_prove, parent_hash)); + i += 2; } + // we shouldn't return the hashes in the proof + nodes.retain(|(pos, _)| !proof_positions.contains(pos)); + Ok((nodes, calculated_root_hashes)) } /// Uses the data passed in to update a proof, creating a valid proof for a given @@ -1084,7 +1066,6 @@ mod tests { // Make sure we got the expect roots assert_eq!(roots, expected_roots); - // Did we compute all expected nodes? assert_eq!(nodes.len(), expected_computed.len()); // For each calculated position, check if the position and hashes are as expected @@ -1243,7 +1224,7 @@ mod bench { .map(|&preimage| hash_from_u8(preimage)) .collect::>(); - let (_, modified) = Stump::new().modify(&utxos, &[], &Proof::default()).unwrap(); + let (s, modified) = Stump::new().modify(&utxos, &[], &Proof::default()).unwrap(); let proof = Proof::default(); let (proof, cached_hashes) = proof .update(vec![], utxos.clone(), vec![], vec![0, 3, 5], modified) @@ -1253,7 +1234,7 @@ mod bench { .iter() .map(|&preimage| hash_from_u8(preimage)) .collect::>(); - let (_, modified) = Stump::new().modify(&utxos, &cached_hashes, &proof).unwrap(); + let (_, modified) = s.modify(&utxos, &cached_hashes, &proof).unwrap(); bencher.iter(move || { proof