Skip to content

Commit

Permalink
Slight restructure in folding subprotocol (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElijahVlasov authored Dec 18, 2024
1 parent 4c00b2a commit 1b1a3c7
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 156 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

312 changes: 156 additions & 156 deletions latticefold/src/nifs/folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,98 +36,6 @@ pub use structs::*;

mod structs;

fn prepare_public_output<const C: usize, NTT: SuitableRing>(
r_0: Vec<NTT>,
v_0: Vec<NTT>,
cm_0: Commitment<C, NTT>,
u_0: Vec<NTT>,
x_0: Vec<NTT>,
h: NTT,
) -> LCCCS<C, NTT> {
LCCCS {
r: r_0,
v: v_0,
cm: cm_0,
u: u_0,
x_w: x_0[0..x_0.len() - 1].to_vec(),
h,
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingProver<NTT, T> {
fn setup_f_hat_mles(w_s: &mut [Witness<NTT>]) -> Vec<Vec<DenseMultilinearExtension<NTT>>> {
cfg_iter_mut!(w_s)
.map(|w| w.take_f_hat())
.collect::<Vec<Vec<DenseMultilinearExtension<NTT>>>>()
}

fn get_ris<const C: usize>(cm_i_s: &[LCCCS<C, NTT>]) -> Vec<Vec<NTT>> {
cm_i_s.iter().map(|cm_i| cm_i.r.clone()).collect::<Vec<_>>()
}

fn calculate_challenged_mz_mle(
Mz_mles_vec: &[Vec<DenseMultilinearExtension<NTT>>],
zeta_s: &[NTT],
) -> Result<DenseMultilinearExtension<NTT>, FoldingError<NTT>> {
let mut combined_mle: DenseMultilinearExtension<NTT> = DenseMultilinearExtension::zero();

zeta_s
.iter()
.zip(Mz_mles_vec)
.for_each(|(zeta_i, Mz_mles)| {
let mut mle: DenseMultilinearExtension<NTT> = DenseMultilinearExtension::zero();
for M in Mz_mles.iter().rev() {
mle += M;
mle *= *zeta_i;
}
combined_mle += mle;
});
Ok(combined_mle)
}

fn get_sumcheck_randomness(sumcheck_prover_state: ProverState<NTT>) -> Vec<NTT> {
sumcheck_prover_state
.randomness
.into_iter()
.map(|x| x.into())
.collect::<Vec<NTT>>()
}

fn get_thetas(
f_hat_mles: &[Vec<DenseMultilinearExtension<NTT>>],
r_0: &[NTT],
) -> Result<Vec<Vec<NTT>>, FoldingError<NTT>> {
let theta_s: Vec<Vec<NTT>> = cfg_iter!(f_hat_mles)
.map(|f_hat_row| evaluate_mles::<_, _, _, FoldingError<NTT>>(f_hat_row, r_0))
.collect::<Result<Vec<_>, _>>()?;

Ok(theta_s)
}

fn get_etas(
Mz_mles_vec: &[Vec<DenseMultilinearExtension<NTT>>],
r_0: &[NTT],
) -> Result<Vec<Vec<NTT>>, FoldingError<NTT>> {
let eta_s: Vec<Vec<NTT>> = cfg_iter!(Mz_mles_vec)
.map(|Mz_mles| evaluate_mles::<_, _, _, FoldingError<NTT>>(Mz_mles, r_0))
.collect::<Result<Vec<_>, _>>()?;

Ok(eta_s)
}

fn compute_f_0(rho_s: &[NTT], w_s: &[Witness<NTT>]) -> Vec<NTT> {
rho_s
.iter()
.zip(w_s)
.fold(vec![NTT::ZERO; w_s[0].f.len()], |acc, (&rho_i, w_i)| {
acc.into_iter()
.zip(w_i.f.iter())
.map(|(acc_j, w_ij)| acc_j + rho_i * w_ij)
.collect()
})
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingProver<NTT, T>
for LFFoldingProver<NTT, T>
{
Expand Down Expand Up @@ -222,6 +130,144 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingProver<NTT
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingVerifier<NTT, T>
for LFFoldingVerifier<NTT, T>
{
fn verify<const C: usize, P: DecompositionParams>(
cm_i_s: &[LCCCS<C, NTT>],
proof: &FoldingProof<NTT>,
transcript: &mut impl TranscriptWithShortChallenges<NTT>,
ccs: &CCS<NTT>,
) -> Result<LCCCS<C, NTT>, FoldingError<NTT>> {
sanity_check::<NTT, P>(ccs)?;

// Step 1: Generate alpha, zeta, mu, beta challenges and validate input
let (alpha_s, beta_s, zeta_s, mu_s) = transcript.squeeze_alpha_beta_zeta_mu::<P>(ccs.s);

// Calculate claims for sumcheck verification
let (claim_g1, claim_g3) = Self::calculate_claims(&alpha_s, &zeta_s, cm_i_s);

let nvars = ccs.s;
let degree = 2 * P::B_SMALL;

//Step 2: The sumcheck.
let (r_0, expected_evaluation) =
Self::verify_sumcheck_proof(transcript, nvars, degree, claim_g1 + claim_g3, proof)?;

// Verify evaluation claim
Self::verify_evaluation::<C, P>(
&alpha_s,
&beta_s,
&mu_s,
&zeta_s,
&r_0,
expected_evaluation,
proof,
cm_i_s,
)?;

// Step 5
proof
.theta_s
.iter()
.for_each(|thetas| transcript.absorb_slice(thetas));
proof
.eta_s
.iter()
.for_each(|etas| transcript.absorb_slice(etas));
let (rho_s_coeff, rho_s) = get_rhos::<_, _, P>(transcript);

// Step 6
let (v_0, cm_0, u_0, x_0) = compute_v0_u0_x0_cm_0(
&rho_s_coeff,
&rho_s,
&proof.theta_s,
cm_i_s,
&proof.eta_s,
ccs,
);

// Step 7: Compute f0 and Witness_0

let h = x_0.last().copied().ok_or(FoldingError::IncorrectLength)?;
Ok(prepare_public_output(r_0, v_0, cm_0, u_0, x_0, h))
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingProver<NTT, T> {
fn setup_f_hat_mles(w_s: &mut [Witness<NTT>]) -> Vec<Vec<DenseMultilinearExtension<NTT>>> {
cfg_iter_mut!(w_s)
.map(|w| w.take_f_hat())
.collect::<Vec<Vec<DenseMultilinearExtension<NTT>>>>()
}

fn get_ris<const C: usize>(cm_i_s: &[LCCCS<C, NTT>]) -> Vec<Vec<NTT>> {
cm_i_s.iter().map(|cm_i| cm_i.r.clone()).collect::<Vec<_>>()
}

fn calculate_challenged_mz_mle(
Mz_mles_vec: &[Vec<DenseMultilinearExtension<NTT>>],
zeta_s: &[NTT],
) -> Result<DenseMultilinearExtension<NTT>, FoldingError<NTT>> {
let mut combined_mle: DenseMultilinearExtension<NTT> = DenseMultilinearExtension::zero();

zeta_s
.iter()
.zip(Mz_mles_vec)
.for_each(|(zeta_i, Mz_mles)| {
let mut mle: DenseMultilinearExtension<NTT> = DenseMultilinearExtension::zero();
for M in Mz_mles.iter().rev() {
mle += M;
mle *= *zeta_i;
}
combined_mle += mle;
});
Ok(combined_mle)
}

fn get_sumcheck_randomness(sumcheck_prover_state: ProverState<NTT>) -> Vec<NTT> {
sumcheck_prover_state
.randomness
.into_iter()
.map(|x| x.into())
.collect::<Vec<NTT>>()
}

fn get_thetas(
f_hat_mles: &[Vec<DenseMultilinearExtension<NTT>>],
r_0: &[NTT],
) -> Result<Vec<Vec<NTT>>, FoldingError<NTT>> {
let theta_s: Vec<Vec<NTT>> = cfg_iter!(f_hat_mles)
.map(|f_hat_row| evaluate_mles::<_, _, _, FoldingError<NTT>>(f_hat_row, r_0))
.collect::<Result<Vec<_>, _>>()?;

Ok(theta_s)
}

fn get_etas(
Mz_mles_vec: &[Vec<DenseMultilinearExtension<NTT>>],
r_0: &[NTT],
) -> Result<Vec<Vec<NTT>>, FoldingError<NTT>> {
let eta_s: Vec<Vec<NTT>> = cfg_iter!(Mz_mles_vec)
.map(|Mz_mles| evaluate_mles::<_, _, _, FoldingError<NTT>>(Mz_mles, r_0))
.collect::<Result<Vec<_>, _>>()?;

Ok(eta_s)
}

fn compute_f_0(rho_s: &[NTT], w_s: &[Witness<NTT>]) -> Vec<NTT> {
rho_s
.iter()
.zip(w_s)
.fold(vec![NTT::ZERO; w_s[0].f.len()], |acc, (&rho_i, w_i)| {
acc.into_iter()
.zip(w_i.f.iter())
.map(|(acc_j, w_ij)| acc_j + rho_i * w_ij)
.collect()
})
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingVerifier<NTT, T> {
#[allow(clippy::too_many_arguments)]
fn verify_evaluation<const C: usize, P: DecompositionParams>(
Expand Down Expand Up @@ -327,70 +373,6 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingVerifier
}
}

impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingVerifier<NTT, T>
for LFFoldingVerifier<NTT, T>
{
fn verify<const C: usize, P: DecompositionParams>(
cm_i_s: &[LCCCS<C, NTT>],
proof: &FoldingProof<NTT>,
transcript: &mut impl TranscriptWithShortChallenges<NTT>,
ccs: &CCS<NTT>,
) -> Result<LCCCS<C, NTT>, FoldingError<NTT>> {
sanity_check::<NTT, P>(ccs)?;

// Step 1: Generate alpha, zeta, mu, beta challenges and validate input
let (alpha_s, beta_s, zeta_s, mu_s) = transcript.squeeze_alpha_beta_zeta_mu::<P>(ccs.s);

// Calculate claims for sumcheck verification
let (claim_g1, claim_g3) = Self::calculate_claims(&alpha_s, &zeta_s, cm_i_s);

let nvars = ccs.s;
let degree = 2 * P::B_SMALL;

//Step 2: The sumcheck.
let (r_0, expected_evaluation) =
Self::verify_sumcheck_proof(transcript, nvars, degree, claim_g1 + claim_g3, proof)?;

// Verify evaluation claim
Self::verify_evaluation::<C, P>(
&alpha_s,
&beta_s,
&mu_s,
&zeta_s,
&r_0,
expected_evaluation,
proof,
cm_i_s,
)?;

// Step 5
proof
.theta_s
.iter()
.for_each(|thetas| transcript.absorb_slice(thetas));
proof
.eta_s
.iter()
.for_each(|etas| transcript.absorb_slice(etas));
let (rho_s_coeff, rho_s) = get_rhos::<_, _, P>(transcript);

// Step 6
let (v_0, cm_0, u_0, x_0) = compute_v0_u0_x0_cm_0(
&rho_s_coeff,
&rho_s,
&proof.theta_s,
cm_i_s,
&proof.eta_s,
ccs,
);

// Step 7: Compute f0 and Witness_0

let h = x_0.last().copied().ok_or(FoldingError::IncorrectLength)?;
Ok(prepare_public_output(r_0, v_0, cm_0, u_0, x_0, h))
}
}

fn sanity_check<NTT: SuitableRing, DP: DecompositionParams>(
ccs: &CCS<NTT>,
) -> Result<(), FoldingError<NTT>> {
Expand All @@ -400,3 +382,21 @@ fn sanity_check<NTT: SuitableRing, DP: DecompositionParams>(

Ok(())
}

fn prepare_public_output<const C: usize, NTT: SuitableRing>(
r_0: Vec<NTT>,
v_0: Vec<NTT>,
cm_0: Commitment<C, NTT>,
u_0: Vec<NTT>,
x_0: Vec<NTT>,
h: NTT,
) -> LCCCS<C, NTT> {
LCCCS {
r: r_0,
v: v_0,
cm: cm_0,
u: u_0,
x_w: x_0[0..x_0.len() - 1].to_vec(),
h,
}
}

0 comments on commit 1b1a3c7

Please sign in to comment.