Skip to content

Commit

Permalink
Move to Jolt-based sumcheck (#129)
Browse files Browse the repository at this point in the history
Moves from our original product-based sumcheck implementation, to Jolt's
`comb_fn`-based implementation.

Handling the `comb_fn` is still a bit finicky. I've further explored
moving it to the `VirtualPolynomial` as a field, however using either
`dyn`amic dispatch or being `<`templated`>` is a bit awkward, mainly
because we are using it as a closure. However using it as a closure is
useful for us due to their versatility and advantage in capturing stuff.
Currently, as it stands provides the lowest impact in readability and
performance. Let me know if you guys have any other ideas on how to
handle this.

Also removes `VirtualPolynomial`.
  • Loading branch information
v0-e authored Dec 5, 2024
1 parent b7f2997 commit 181a6f1
Show file tree
Hide file tree
Showing 14 changed files with 473 additions and 1,123 deletions.
1 change: 0 additions & 1 deletion latticefold/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ parallel = [
"lattirust-ring/parallel",
]
getrandom = [ "ark-std/getrandom" ]
jolt-sumcheck = []

# dev-only
dhat-heap = []
Expand Down
80 changes: 12 additions & 68 deletions latticefold/src/nifs/folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ use super::error::FoldingError;
use crate::ark_base::*;
use crate::transcript::TranscriptWithShortChallenges;
use crate::utils::mle_helpers::evaluate_mles;
use crate::utils::sumcheck::{
virtual_polynomial::{eq_eval, VPAuxInfo},
MLSumcheck,
SumCheckError::SumCheckFailed,
};
use crate::utils::sumcheck::{utils::eq_eval, MLSumcheck, SumCheckError::SumCheckFailed};
use crate::{
arith::{error::CSError, Witness, CCS, LCCCS},
decomposition_parameters::DecompositionParams,
Expand All @@ -31,9 +27,6 @@ use crate::utils::sumcheck::prover::ProverState;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

#[cfg(feature = "jolt-sumcheck")]
use lattirust_ring::PolyRing;

#[cfg(test)]
mod tests;

Expand Down Expand Up @@ -167,7 +160,7 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingProver<NTT
Self::calculate_challenged_mz_mle(&mz_mles[0..P::K], &zeta_s[0..P::K])?;
let prechallenged_Ms_2 =
Self::calculate_challenged_mz_mle(&mz_mles[P::K..2 * P::K], &zeta_s[P::K..2 * P::K])?;
let g = create_sumcheck_polynomial::<_, P>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, P>(
log_m,
&f_hat_mles,
&alpha_s,
Expand All @@ -178,63 +171,11 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingProver<NTT
&mu_s,
)?;

#[cfg(feature = "jolt-sumcheck")]
let comb_fn = |_: &ProverState<NTT>, vals: &[NTT]| -> NTT {
let extension_degree = NTT::CoefficientRepresentation::dimension() / NTT::dimension();

// Add eq_r * g1 * g3 for first k
let mut result = vals[0] * vals[1];

// Add eq_r * g1 * g3 for second k
result += vals[2] * vals[3];

// We have k * extension degree mles of b
// each one consists of (2 * small_b) -1 extensions
// We start at index 5
// Multiply each group of (2 * small_b) -1 extensions
// Then multiply by the eq_beta evaluation at index 4
for (k, mu) in mu_s.iter().enumerate() {
let mut inter_result = NTT::zero();
for d in (0..extension_degree).rev() {
let i = k * extension_degree + d;

let f_i = vals[5 + i];

if f_i.is_zero() {
inter_result *= mu;
continue;
}

// start with eq_b
let mut eval = vals[4];

let f_i_squared = f_i * f_i;

for b in 1..P::B_SMALL {
let multiplicand = f_i_squared - NTT::from(b as u128 * b as u128);
if multiplicand.is_zero() {
eval = NTT::zero();
break;
}
eval *= multiplicand
}
eval *= f_i;
inter_result += eval;
inter_result *= mu
}
result += inter_result;
}

result
};
let comb_fn = |vals: &[NTT]| -> NTT { sumcheck_polynomial_comb_fn::<NTT, P>(vals, &mu_s) };

// Step 5: Run sum check prover
let (sum_check_proof, prover_state) = MLSumcheck::prove_as_subprotocol(
transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
comb_fn,
);
let (sum_check_proof, prover_state) =
MLSumcheck::prove_as_subprotocol(transcript, &g_mles, log_m, g_degree, comb_fn);

let r_0 = Self::get_sumcheck_randomness(prover_state);

Expand Down Expand Up @@ -353,15 +294,17 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> LFFoldingVerifier

fn verify_sumcheck_proof(
transcript: &mut impl TranscriptWithShortChallenges<NTT>,
poly_info: &VPAuxInfo<NTT>,
nvars: usize,
degree: usize,
total_claim: NTT,
proof: &FoldingProof<NTT>,
) -> Result<(Vec<NTT>, NTT), FoldingError<NTT>> {
//Step 2: The sumcheck.
// Verify the sumcheck proof.
let sub_claim = MLSumcheck::verify_as_subprotocol(
transcript,
poly_info,
nvars,
degree,
total_claim,
&proof.pointshift_sumcheck_proof,
)?;
Expand Down Expand Up @@ -393,11 +336,12 @@ impl<NTT: SuitableRing, T: TranscriptWithShortChallenges<NTT>> FoldingVerifier<N
// Calculate claims for sumcheck verification
let (claim_g1, claim_g3) = Self::calculate_claims(&alpha_s, &zeta_s, cm_i_s);

let poly_info = VPAuxInfo::new(ccs.s, 2 * P::B_SMALL);
let nvars = ccs.s;
let degree = 2 * P::B_SMALL;

//Step 2: The sumcheck.
let (r_0, expected_evaluation) =
Self::verify_sumcheck_proof(transcript, &poly_info, claim_g1 + claim_g3, proof)?;
Self::verify_sumcheck_proof(transcript, nvars, degree, claim_g1 + claim_g3, proof)?;

// Verify evaluation claim
Self::verify_evaluation::<C, P>(
Expand Down
87 changes: 39 additions & 48 deletions latticefold/src/nifs/folding/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::decomposition_parameters::test_params::{
use crate::nifs::folding::utils::SqueezeAlphaBetaZetaMu;
use crate::nifs::folding::{
prepare_public_output,
utils::{compute_v0_u0_x0_cm_0, create_sumcheck_polynomial, get_rhos},
utils::{
compute_v0_u0_x0_cm_0, create_sumcheck_polynomial, get_rhos, sumcheck_polynomial_comb_fn,
},
FoldingProver, FoldingVerifier, LFFoldingProver, LFFoldingVerifier,
};
use crate::nifs::FoldingProof;
use crate::transcript::{Transcript, TranscriptWithShortChallenges};
#[cfg(feature = "jolt-sumcheck")]
use crate::utils::sumcheck::prover::ProverState;
use crate::utils::sumcheck::{virtual_polynomial::VPAuxInfo, MLSumcheck};
use crate::utils::sumcheck::MLSumcheck;
use crate::{
arith::{r1cs::get_test_z_ntt_split, tests::get_test_ccs, Witness, CCCS},
commitment::AjtaiCommitmentScheme,
Expand Down Expand Up @@ -257,7 +257,7 @@ fn test_get_sumcheck_randomness() {
&zeta_s[DP::K..2 * DP::K],
)
.unwrap();
let g = create_sumcheck_polynomial::<_, DP>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>(
ccs.s,
&f_hat_mles,
&alpha_s,
Expand All @@ -269,24 +269,19 @@ fn test_get_sumcheck_randomness() {
)
.unwrap();

let comb_fn =
|vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::<RqNTT, DP>(vals, &mu_s) };

// Compute sumcheck proof
let (_, prover_state) = MLSumcheck::prove_as_subprotocol(
&mut transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
ProverState::combine_product,
);
let (_, prover_state) =
MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn);
// Derive randomness
let r_0 = LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::get_sumcheck_randomness(
prover_state,
);

// Validate - Check dimensions
assert_eq!(
r_0.len(),
g.aux_info.num_variables,
"Randomness r_0 has the wrong length"
);
assert_eq!(r_0.len(), ccs.s, "Randomness r_0 has the wrong length");
}

#[test]
Expand Down Expand Up @@ -317,7 +312,7 @@ fn test_get_thetas() {
&zeta_s[DP::K..2 * DP::K],
)
.unwrap();
let g = create_sumcheck_polynomial::<_, DP>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>(
ccs.s,
&f_hat_mles,
&alpha_s,
Expand All @@ -329,13 +324,11 @@ fn test_get_thetas() {
)
.unwrap();

let (_, prover_state) = MLSumcheck::prove_as_subprotocol(
&mut transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
ProverState::combine_product,
);
let comb_fn =
|vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::<RqNTT, DP>(vals, &mu_s) };

let (_, prover_state) =
MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn);
let r_0 = LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::get_sumcheck_randomness(
prover_state,
);
Expand Down Expand Up @@ -391,7 +384,7 @@ fn test_get_etas() {
&zeta_s[DP::K..2 * DP::K],
)
.unwrap();
let g = create_sumcheck_polynomial::<_, DP>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>(
ccs.s,
&f_hat_mles,
&alpha_s,
Expand All @@ -403,13 +396,11 @@ fn test_get_etas() {
)
.unwrap();

let (_, prover_state) = MLSumcheck::prove_as_subprotocol(
&mut transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
ProverState::combine_product,
);
let comb_fn =
|vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::<RqNTT, DP>(vals, &mu_s) };

let (_, prover_state) =
MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn);
let r_0 = LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::get_sumcheck_randomness(
prover_state,
);
Expand Down Expand Up @@ -493,7 +484,7 @@ fn test_prepare_public_output() {
&zeta_s[DP::K..2 * DP::K],
)
.unwrap();
let g = create_sumcheck_polynomial::<_, DP>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>(
ccs.s,
&f_hat_mles,
&alpha_s,
Expand All @@ -505,13 +496,11 @@ fn test_prepare_public_output() {
)
.unwrap();

let (_, prover_state) = MLSumcheck::prove_as_subprotocol(
&mut transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
ProverState::combine_product,
);
let comb_fn =
|vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::<RqNTT, DP>(vals, &mu_s) };

let (_, prover_state) =
MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn);
let r_0 = LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::get_sumcheck_randomness(
prover_state,
);
Expand Down Expand Up @@ -576,7 +565,7 @@ fn test_compute_f_0() {
&zeta_s[DP::K..2 * DP::K],
)
.unwrap();
let g = create_sumcheck_polynomial::<_, DP>(
let (g_mles, g_degree) = create_sumcheck_polynomial::<_, DP>(
ccs.s,
&f_hat_mles,
&alpha_s,
Expand All @@ -588,13 +577,11 @@ fn test_compute_f_0() {
)
.unwrap();

let (_, prover_state) = MLSumcheck::prove_as_subprotocol(
&mut transcript,
&g,
#[cfg(feature = "jolt-sumcheck")]
ProverState::combine_product,
);
let comb_fn =
|vals: &[RqNTT]| -> RqNTT { sumcheck_polynomial_comb_fn::<RqNTT, DP>(vals, &mu_s) };

let (_, prover_state) =
MLSumcheck::prove_as_subprotocol(&mut transcript, &g_mles, ccs.s, g_degree, comb_fn);
let r_0 = LFFoldingProver::<RqNTT, PoseidonTranscript<RqNTT, CS>>::get_sumcheck_randomness(
prover_state,
);
Expand Down Expand Up @@ -694,7 +681,8 @@ fn test_verify_evaluation() {

let (alpha_s, beta_s, zeta_s, mu_s) = transcript.squeeze_alpha_beta_zeta_mu::<DP>(ccs.s);

let poly_info = VPAuxInfo::new(ccs.s, 2 * DP::B_SMALL);
let nvars = ccs.s;
let degree = 2 * DP::B_SMALL;

let (claim_g1, claim_g3) =
LFFoldingVerifier::<RqNTT, PoseidonTranscript<RqNTT, CS>>::calculate_claims::<C>(
Expand All @@ -704,7 +692,8 @@ fn test_verify_evaluation() {
let (r_0, expected_evaluation) =
LFFoldingVerifier::<RqNTT, PoseidonTranscript<RqNTT, CS>>::verify_sumcheck_proof(
&mut transcript,
&poly_info,
nvars,
degree,
claim_g1 + claim_g3,
&proof,
)
Expand Down Expand Up @@ -773,7 +762,8 @@ fn test_verify_sumcheck_proof() {

let (alpha_s, _, zeta_s, _) = transcript.squeeze_alpha_beta_zeta_mu::<DP>(ccs.s);

let poly_info = VPAuxInfo::new(ccs.s, 2 * DP::B_SMALL);
let nvars = ccs.s;
let degree = 2 * DP::B_SMALL;

let (claim_g1, claim_g3) =
LFFoldingVerifier::<RqNTT, PoseidonTranscript<RqNTT, CS>>::calculate_claims::<C>(
Expand All @@ -782,7 +772,8 @@ fn test_verify_sumcheck_proof() {

let result = LFFoldingVerifier::<RqNTT, PoseidonTranscript<RqNTT, CS>>::verify_sumcheck_proof(
&mut transcript,
&poly_info,
nvars,
degree,
claim_g1 + claim_g3,
&proof,
);
Expand Down
Loading

0 comments on commit 181a6f1

Please sign in to comment.