chore: quality-of-life improvements for HasherMerkleTree and bytes_to_field_elements #335

Merged
merged 7 commits into from
Jul 5, 2023
Changes from 5 commits

2 changes: 1 addition & 1 deletion plonk/src/circuit/plonk_verifier/gadgets.rs
@@ -207,7 +207,7 @@ where
}
let mut transcript_var = RescueTranscriptVar::new(circuit);
if let Some(msg) = extra_transcript_init_msg {
let msg_fs = bytes_to_field_elements::<_, F>(msg);
let msg_fs = bytes_to_field_elements::<_, F>(msg.as_ref());
let msg_vars = msg_fs
.iter()
.map(|x| circuit.create_variable(*x))
2 changes: 1 addition & 1 deletion plonk/src/circuit/transcript.rs
@@ -256,7 +256,7 @@ mod tests {
for _ in 0..10 {
for i in 0..10 {
let msg = format!("message {}", i);
let vals = bytes_to_field_elements(&msg);
let vals = bytes_to_field_elements(msg.as_ref());
let message_vars: Vec<Variable> = vals
.iter()
.map(|x| circuit.create_variable(*x).unwrap())
67 changes: 65 additions & 2 deletions primitives/src/merkle_tree/hasher.rs
@@ -67,15 +67,78 @@ pub type HasherMerkleTree<H, E> = GenericHasherMerkleTree<H, E, u64, U3>;
pub type GenericHasherMerkleTree<H, E, I, Arity> =
MerkleTree<E, HasherDigestAlgorithm, I, Arity, HasherNode<H>>;

/// Convenience trait and blanket impl for downstream trait bounds.
///
/// Useful for downstream code that's generic over [`Digest`] hasher `H`.
///
/// # Example
///
/// Do this:
/// ```
/// # use jf_primitives::merkle_tree::{hasher::HasherMerkleTree, MerkleTreeScheme};
/// # use jf_primitives::merkle_tree::hasher::HasherDigest;
/// fn generic_over_hasher<H>()
/// where
/// H: HasherDigest,
/// {
/// let my_data = [1, 2, 3, 4, 5, 6, 7, 8, 9];
/// let mt = HasherMerkleTree::<H, usize>::from_elems(2, &my_data).unwrap();
/// }
/// ```
///
/// Instead of this:
/// ```
/// # use digest::{crypto_common::generic_array::ArrayLength, Digest, OutputSizeUser};
/// # use ark_serialize::Write;
/// # use jf_primitives::merkle_tree::{hasher::HasherMerkleTree, MerkleTreeScheme};
/// # use jf_primitives::merkle_tree::hasher::HasherDigest;
/// fn generic_over_hasher<H>()
/// where
/// H: Digest + Write,
/// <<H as OutputSizeUser>::OutputSize as ArrayLength<u8>>::ArrayType: Copy,
/// {
/// let my_data = [1, 2, 3, 4, 5, 6, 7, 8, 9];
/// let mt = HasherMerkleTree::<H, usize>::from_elems(2, &my_data).unwrap();
/// }
/// ```
///
/// Note that the complex trait bound for [`Copy`] is necessary:
/// ```compile_fail
/// # use digest::{crypto_common::generic_array::ArrayLength, Digest, OutputSizeUser};
/// # use ark_serialize::Write;
/// # use jf_primitives::merkle_tree::{hasher::HasherMerkleTree, MerkleTreeScheme};
/// # use jf_primitives::merkle_tree::hasher::HasherDigest;
/// fn generic_over_hasher<H>()
/// where
/// H: Digest + Write,
/// {
/// let my_data = [1, 2, 3, 4, 5, 6, 7, 8, 9];
/// let mt = HasherMerkleTree::<H, usize>::from_elems(2, &my_data).unwrap();
/// }
/// ```
pub trait HasherDigest: Digest<OutputSize = Self::Foo> + Write {
/// Associated type needed to express trait bounds.
type Foo: ArrayLength<u8, ArrayType = Self::Bar>;
/// Associated type needed to express trait bounds.
type Bar: Copy;
}
impl<T> HasherDigest for T
where
T: Digest + Write,
<T::OutputSize as ArrayLength<u8>>::ArrayType: Copy,
{
type Foo = T::OutputSize;
type Bar = <<T as HasherDigest>::Foo as ArrayLength<u8>>::ArrayType;
}

/// A struct that impls [`DigestAlgorithm`] for use with [`MerkleTree`].
pub struct HasherDigestAlgorithm;

impl<E, I, H> DigestAlgorithm<E, I, HasherNode<H>> for HasherDigestAlgorithm
where
E: Element + CanonicalSerialize,
I: Index + CanonicalSerialize,
H: Digest + Write,
<<H as OutputSizeUser>::OutputSize as ArrayLength<u8>>::ArrayType: Copy,
H: HasherDigest,
{
fn digest(data: &[HasherNode<H>]) -> Result<HasherNode<H>, PrimitivesError> {
let mut hasher = H::new();
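
As a usage sketch (not part of this diff), the new `HasherDigest` bound lets downstream code pick any concrete hasher that satisfies the blanket impl. The example below assumes the `sha2` crate, whose `Sha256` implements `Digest` and (with its default `std` feature) `std::io::Write`; the tree height and leaf type are arbitrary choices for illustration.

```rust
// Sketch only: `sha2` and its `std` feature (providing `io::Write` for `Sha256`)
// are assumptions, not part of this PR.
use jf_primitives::merkle_tree::{
    hasher::{HasherDigest, HasherMerkleTree},
    MerkleTreeScheme,
};
use sha2::Sha256;

/// Downstream code now needs only the single `HasherDigest` bound.
fn build_tree<H: HasherDigest>(leaves: &[usize]) -> HasherMerkleTree<H, usize> {
    // Height 2, as in the doc examples above.
    HasherMerkleTree::<H, usize>::from_elems(2, leaves).expect("failed to build tree")
}

fn main() {
    let my_data = [1usize, 2, 3, 4, 5, 6, 7, 8, 9];
    let _mt = build_tree::<Sha256>(&my_data);
}
```
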
41 changes: 18 additions & 23 deletions utilities/src/conversion.rs
@@ -7,6 +7,7 @@
use ark_ec::CurveConfig;
use ark_ff::{BigInteger, Field, PrimeField};
use ark_std::{
borrow::Borrow,
cmp::min,
iter::{once, repeat},
mem,
@@ -127,41 +128,37 @@ where
/// If any of the above conditions holds then this function *always* panics.
pub fn bytes_to_field_elements<B, F>(bytes: B) -> Vec<F>
where
B: AsRef<[u8]>,
B: Borrow<[u8]>,
F: Field,
{
let bytes = bytes.borrow();
let (primefield_bytes_len, extension_degree, field_bytes_len) = compile_time_checks::<F>();
if bytes.as_ref().is_empty() {
if bytes.is_empty() {
return Vec::new();
}

// Result length is always less than `bytes` length for sufficiently large
// `bytes`. Thus, the following should never panic.
let result_len = (field_bytes_len
.checked_add(bytes.as_ref().len())
.checked_add(bytes.len())
.expect("result len should fit into usize")
- 1)
/ field_bytes_len
+ 1;

let result = once(F::from(bytes.as_ref().len() as u64)) // the first field element encodes the bytes length as u64
.chain(
bytes
.as_ref()
.chunks(field_bytes_len)
.map(|field_elem_bytes| {
F::from_base_prime_field_elems(
&field_elem_bytes.chunks(primefield_bytes_len)
let result = once(F::from(bytes.len() as u64)) // the first field element encodes the bytes length as u64
.chain(bytes.chunks(field_bytes_len).map(|field_elem_bytes| {
F::from_base_prime_field_elems(
&field_elem_bytes.chunks(primefield_bytes_len)
.map(F::BasePrimeField::from_le_bytes_mod_order)
// not enough prime field elems? fill remaining elems with zero
.chain(repeat(F::BasePrimeField::ZERO).take(
extension_degree - (field_elem_bytes.len()-1) / primefield_bytes_len - 1)
)
.collect::<Vec<_>>(),
)
.expect("failed to construct field element")
}),
)
)
.expect("failed to construct field element")
}))
.collect::<Vec<_>>();

// sanity check
Expand All @@ -186,18 +183,16 @@ where
/// length of the return `Vec<u8>` overflows `usize`.
pub fn bytes_from_field_elements<T, F>(elems: T) -> Vec<u8>
where
T: AsRef<[F]>,
T: Borrow<[F]>,
F: Field,
{
let elems = elems.borrow();
let (primefield_bytes_len, _, field_bytes_len) = compile_time_checks::<F>();
if elems.as_ref().is_empty() {
if elems.is_empty() {
return Vec::new();
}

let (first_elem, elems) = elems
.as_ref()
.split_first()
.expect("elems should be non-empty");
let (first_elem, elems) = elems.split_first().expect("elems should be non-empty");

// the first element encodes the number of bytes to return
let result_len = usize::try_from(u64::from_le_bytes(
@@ -355,7 +350,7 @@ mod tests {
bytes[len..].fill(0);

// round trip
let encoded_bytes: Vec<F> = bytes_to_field_elements(&bytes);
let encoded_bytes: Vec<F> = bytes_to_field_elements(bytes.as_ref());
let result = bytes_from_field_elements(encoded_bytes);
assert_eq!(result, bytes);
}
@@ -364,7 +359,7 @@
// with random field elements
elems.resize(len, F::zero());
elems.iter_mut().for_each(|e| *e = F::rand(&mut rng));
bytes_from_field_elements(&elems);
bytes_from_field_elements(elems.as_ref());
}
}

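
As a round-trip sketch mirroring the test above (the field choice `ark_bn254::Fr` and the `jf_utils` import path are illustrative assumptions, not part of this diff): under the new `Borrow<[u8]>` bound, callers pass an explicit byte slice, e.g. via `.as_ref()`, and because the first field element records the byte length, decoding recovers the input exactly.

```rust
// Sketch only: any `ark_ff::Field` works here; `ark_bn254::Fr` is an arbitrary choice.
use ark_bn254::Fr;
use jf_utils::{bytes_from_field_elements, bytes_to_field_elements};

fn main() {
    let msg = String::from("message 0");

    // `String` is not `Borrow<[u8]>`, so the call sites updated in this PR pass
    // an explicit `&[u8]` via `.as_ref()`.
    let elems: Vec<Fr> = bytes_to_field_elements(msg.as_ref());

    // The first element encodes the original byte length, so the round trip
    // restores the input exactly.
    let recovered = bytes_from_field_elements(elems.as_ref());
    assert_eq!(recovered, msg.as_bytes());
}
```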