Skip to content

Commit

Permalink
refactor expander-dedicated mle
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyong1997 committed Feb 25, 2025
1 parent 974e3b7 commit c12cfbd
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 119 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions arith/polynomials/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ edition = "2021"

[dependencies]
arith = { path = "../" }
gkr_field_config = { path = "../../config/gkr_field_config" }
mpi_config = { path = "../../config/mpi_config" }

ark-std.workspace = true
criterion.workspace = true
Expand Down
134 changes: 132 additions & 2 deletions arith/polynomials/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use std::ops::{Index, IndexMut, Mul};
use std::{
cmp,
marker::PhantomData,
ops::{Index, IndexMut, Mul},
};

use arith::Field;
use arith::{Field, SimdField};
use ark_std::{log2, rand::RngCore};
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;

use crate::{EqPolynomial, MultilinearExtension, MutableMultilinearExtension};

Expand Down Expand Up @@ -205,3 +211,127 @@ impl<F: Field> MutableMultilinearExtension<F> for MultiLinearPoly<F> {
}
}
}

#[derive(Debug, Clone, Default)]
pub struct MultiLinearPolyExpander<C: GKRFieldConfig> {
_config: PhantomData<C>,
}

/// Some dedicated mle implementations for GKRFieldConfig
/// Take into consideration the simd challenge and the mpi challenge
///
/// This is more efficient than the generic implementation by avoiding
/// unnecessary conversions between field types
impl<C: GKRFieldConfig> MultiLinearPolyExpander<C> {
pub fn new() -> Self {
Self {
_config: PhantomData,
}
}

#[inline]
pub fn eval_circuit_vals_at_challenge(
evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
scratch: &mut [C::Field],
) -> C::Field {
assert_eq!(1 << x.len(), evals.len());
assert!(scratch.len() >= evals.len());

if x.is_empty() {
C::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = C::field_add_simd_circuit_field(
&C::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}

/// This assumes each mpi core hold their own evals, and collectively
/// compute the global evaluation.
/// Mostly used by the prover run with `mpiexec`
#[inline]
pub fn collectively_eval_circuit_vals_at_expander_challenge(
local_evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
scratch_field: &mut [C::Field],
scratch_challenge_field: &mut [C::ChallengeField],
mpi_config: &MPIConfig,
) -> C::ChallengeField {
assert!(scratch_challenge_field.len() >= 1 << cmp::max(x_simd.len(), x_mpi.len()));

let local_simd = Self::eval_circuit_vals_at_challenge(local_evals, x, scratch_field);
let local_simd_unpacked = local_simd.unpack();
let local_v = MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
scratch_challenge_field,
);

let global_v = if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![local_v], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
&x_mpi,
scratch_challenge_field,
)
} else {
mpi_config.gather_vec(&vec![local_v], &mut vec![]);
C::ChallengeField::zero()
};

global_v
}

/// This assumes only a single core holds all the evals, and evaluate it locally
/// mostly used by the verifier
#[inline]
pub fn single_core_eval_circuit_vals_at_expander_challenge(
global_vals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
) -> C::ChallengeField {
let local_poly_size = global_vals.len() >> x_mpi.len();
assert_eq!(local_poly_size, 1 << x.len());

let mut scratch_field = vec![C::Field::default(); local_poly_size];
let mut scratch_challenge_field =
vec![C::ChallengeField::default(); 1 << cmp::max(x_simd.len(), x_mpi.len())];
let local_evals = global_vals
.chunks(local_poly_size)
.map(|local_vals| {
let local_simd =
Self::eval_circuit_vals_at_challenge(local_vals, x, &mut scratch_field);
let local_simd_unpacked = local_simd.unpack();
MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
&mut scratch_challenge_field,
)
})
.collect::<Vec<C::ChallengeField>>();

let mut scratch = vec![C::ChallengeField::default(); local_evals.len()];
MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch)
}
}
33 changes: 0 additions & 33 deletions config/gkr_field_config/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,4 @@ pub trait GKRFieldConfig: Default + Debug + Clone + Send + Sync + 'static {
fn get_field_pack_size() -> usize {
Self::SimdCircuitField::PACK_SIZE
}

/// Evaluate the circuit values at the challenge
#[inline]
fn eval_circuit_vals_at_challenge(
evals: &[Self::SimdCircuitField],
x: &[Self::ChallengeField],
scratch: &mut [Self::Field],
) -> Self::Field {
assert_eq!(1 << x.len(), evals.len());

if x.is_empty() {
Self::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = Self::field_add_simd_circuit_field(
&Self::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}
}
8 changes: 6 additions & 2 deletions crosslayer_prototype/src/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arith::{Field, SimdField};
use gkr_field_config::GKRFieldConfig;
use polynomials::MultiLinearPoly;
use polynomials::{MultiLinearPoly, MultiLinearPolyExpander};
use transcript::Transcript;

use crate::sumcheck::{sumcheck_prove_gather_layer, sumcheck_prove_scatter_layer};
Expand All @@ -18,7 +18,11 @@ pub fn prove_gkr<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
.generate_challenge_field_elements(final_layer_vals.len().trailing_zeros() as usize);
let r_simd = transcript
.generate_challenge_field_elements(C::get_field_pack_size().trailing_zeros() as usize);
let output_claim = C::eval_circuit_vals_at_challenge(final_layer_vals, &rz0, &mut sp.v_evals);
let output_claim = MultiLinearPolyExpander::<C>::eval_circuit_vals_at_challenge(
final_layer_vals,
&rz0,
&mut sp.v_evals,
);
let output_claim = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&output_claim.unpack(),
&r_simd,
Expand Down
31 changes: 10 additions & 21 deletions gkr/src/prover/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arith::{Field, SimdField};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::MultiLinearPoly;
use polynomials::{MultiLinearPoly, MultiLinearPolyExpander};
use sumcheck::{sumcheck_prove_gkr_layer, ProverScratchPad};
use transcript::Transcript;
use utils::timer::Timer;
Expand Down Expand Up @@ -44,27 +44,16 @@ pub fn gkr_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
let mut alpha = None;

let output_vals = &circuit.layers.last().unwrap().output_vals;

let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals);
let claimed_v_local = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&claimed_v_simd.unpack(),
&r_simd,
&mut sp.eq_evals_at_r_simd0,
);

let claimed_v = if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.eq_evals_at_r_mpi0,
)
} else {
mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]);
C::ChallengeField::zero()
};
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

for i in (0..layer_num).rev() {
let timer = Timer::new(
Expand Down
31 changes: 11 additions & 20 deletions gkr/src/prover/gkr_square.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ark_std::{end_timer, start_timer};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::MultiLinearPoly;
use polynomials::MultiLinearPolyExpander;
use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad};
use transcript::Transcript;

Expand Down Expand Up @@ -42,26 +42,17 @@ pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
}

let output_vals = &circuit.layers.last().unwrap().output_vals;
let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals);
let claimed_v_local = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&claimed_v_simd.unpack(),
&r_simd,
&mut sp.eq_evals_at_r_simd0,
);

let claimed_v = if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.eq_evals_at_r_mpi0,
)
} else {
mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]);
C::ChallengeField::zero()
};
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

log::trace!("Claimed v: {:?}", claimed_v);

for i in (0..layer_num).rev() {
Expand Down
15 changes: 8 additions & 7 deletions poly_commit/src/orion/simd_field_agg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use gf2_128::GF2_128;
use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig};
use itertools::izip;
use mersenne31::{M31Ext3, M31x16};
use polynomials::{EqPolynomial, MultiLinearPoly};
use polynomials::{EqPolynomial, MultiLinearPoly, MultiLinearPolyExpander};
use transcript::{BytesHashTranscript, Keccak256hasher, Transcript};

use crate::{
Expand Down Expand Up @@ -165,12 +165,13 @@ where
let aggregated_proof =
orion_proof_aggregate::<C, T>(&openings, &gkr_challenge.x_mpi, &mut aggregator_transcript);

let final_expected_eval = RawExpanderGKR::<C, T>::eval(
&global_poly.coeffs,
&gkr_challenge.x,
&gkr_challenge.x_simd,
&gkr_challenge.x_mpi,
);
let final_expected_eval =
MultiLinearPolyExpander::<C>::single_core_eval_circuit_vals_at_expander_challenge(
&global_poly.coeffs,
&gkr_challenge.x,
&gkr_challenge.x_simd,
&gkr_challenge.x_mpi,
);

assert!(orion_verify_simd_field_aggregated::<C, ComPackF, T>(
num_parties,
Expand Down
41 changes: 9 additions & 32 deletions poly_commit/src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use arith::{BN254Fr, ExtensionField, Field, FieldForECC, FieldSerde, FieldSerdeR
use ethnum::U256;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::{MultiLinearPoly, MultilinearExtension};
use polynomials::{MultiLinearPoly, MultiLinearPolyExpander, MultilinearExtension};
use rand::RngCore;
use transcript::Transcript;

Expand Down Expand Up @@ -231,36 +231,13 @@ impl<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> PCSForExpanderGKR<C, T
) -> bool {
assert!(mpi_config.is_root()); // Only the root will verify
let ExpanderGKRChallenge::<C> { x, x_simd, x_mpi } = x;
Self::eval(&commitment.evals, x, x_simd, x_mpi) == v
}
}

impl<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> RawExpanderGKR<C, T> {
pub fn eval_local(
vals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
) -> C::ChallengeField {
let mut scratch = vec![C::Field::default(); vals.len()];
let y_simd = C::eval_circuit_vals_at_challenge(vals, x, &mut scratch);
let y_simd_unpacked = y_simd.unpack();
let mut scratch = vec![C::ChallengeField::default(); y_simd_unpacked.len()];
MultiLinearPoly::evaluate_with_buffer(&y_simd_unpacked, x_simd, &mut scratch)
}

pub fn eval(
vals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
) -> C::ChallengeField {
let local_poly_size = vals.len() >> x_mpi.len();
let local_evals = vals
.chunks(local_poly_size)
.map(|local_vals| Self::eval_local(local_vals, x, x_simd))
.collect::<Vec<C::ChallengeField>>();

let mut scratch = vec![C::ChallengeField::default(); local_evals.len()];
MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch)
let v_target =
MultiLinearPolyExpander::<C>::single_core_eval_circuit_vals_at_expander_challenge(
&commitment.evals,
x,
x_simd,
x_mpi,
);
v == v_target
}
}
Loading

0 comments on commit c12cfbd

Please sign in to comment.