Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for periodic columns in LogUp-GKR #307

Merged
merged 27 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8ac25c7
feat: math utilities needed for sum-check protocol
Al-Kindi-0 Aug 6, 2024
5e06378
feat: add sum-check prover and verifier
Al-Kindi-0 Aug 6, 2024
16389d6
tests: add sanity tests for utils
Al-Kindi-0 Aug 6, 2024
380aa1a
doc: document sumcheck_round
Al-Kindi-0 Aug 6, 2024
7a1a99e
feat: use SmallVec
Al-Kindi-0 Aug 7, 2024
1901066
docs: improve documentation of sum-check
Al-Kindi-0 Aug 8, 2024
8a57216
feat: add remaining functions for sum-check verifier
Al-Kindi-0 Aug 9, 2024
ff9e6fa
chore: move prover into sub-mod
Al-Kindi-0 Aug 9, 2024
7e24f8f
chore: remove utils mod
Al-Kindi-0 Aug 9, 2024
23044e8
chore: remove utils mod
Al-Kindi-0 Aug 9, 2024
ad0497d
chore: move logup evaluator trait to separate file
Al-Kindi-0 Aug 9, 2024
a0272ea
feat: add GKR backend for LogUp-GKR
Al-Kindi-0 Aug 9, 2024
7b8caff
chore: remove old way of handling Lagrange kernel
Al-Kindi-0 Aug 12, 2024
e2b8c12
wip: add s-column constraints
Al-Kindi-0 Aug 12, 2024
b813916
chore: correct header
Al-Kindi-0 Aug 12, 2024
492f247
wip
Al-Kindi-0 Aug 12, 2024
0d664e0
wip: add support for periodic columns in gkr backend
Al-Kindi-0 Aug 14, 2024
8617308
Merge branch 'logup-gkr' into al-gkr-periodic
Al-Kindi-0 Sep 3, 2024
ed781d8
chore: fix post merge issues
Al-Kindi-0 Sep 3, 2024
98c0e71
chore: fix issues
Al-Kindi-0 Sep 3, 2024
807aba1
doc: add comment about periodic values table
Al-Kindi-0 Sep 3, 2024
4e6d3ab
chore: address feedback
Al-Kindi-0 Sep 4, 2024
c93ec35
chore: fix concurrent portion
Al-Kindi-0 Sep 4, 2024
873345a
chore: address feedback
Al-Kindi-0 Sep 5, 2024
d94794a
chore: address feedback
Al-Kindi-0 Sep 9, 2024
b484e04
chore: remove unnecessary mut
Al-Kindi-0 Sep 10, 2024
dec6589
chore: remove unnecessary mut
Al-Kindi-0 Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions air/src/air/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
) -> SColumnConstraint<E> {
SColumnConstraint::new(gkr_data, composition_coefficient)
}

/// Returns the periodic values used in the LogUp-GKR statement, either as base field element
/// during circuit evaluation or as extension field element during the run of sum-check for
/// the input layer.
fn build_periodic_values<E>(&self) -> PeriodicTable<E>
where
E: FieldElement<BaseField = Self::BaseField>,
{
let table = self
.get_oracles()
.iter()
.filter_map(|oracle| {
if let LogUpGkrOracle::PeriodicValue(values) = oracle {
Some(values.iter().map(|x| E::from(*x)).collect())
} else {
None
}
})
.collect();

PeriodicTable { table }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I would rewrite this as
let table: Vec<Vec<F>> = self
    .get_oracles()
    .iter()
    .filter_map(|oracle| {
        if let LogUpGkrOracle::PeriodicValue(values) = oracle {
            Some(values.into_iter().copied().map(F::from).collect())
        } else {
            None
        }
    })
    .collect();

PeriodicTable{ table }
  1. E is not used, and it is redundant, right? It seems like we could always make F::BaseField the basefield, and F the extension field (or F::BaseField and F both be the base field).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, switched to it
Indeed, the generic is redundant now, it is a leftover from a previous iteration. Removed now

}
}

#[derive(Clone, Default)]
Expand Down Expand Up @@ -229,3 +251,56 @@ pub enum LogUpGkrOracle<B: StarkField> {
/// must be a power of 2.
PeriodicValue(Vec<B>),
}

// PERIODIC COLUMNS FOR LOGUP
// =================================================================================================

/// Stores the periodic columns used in a LogUp-GKR statement.
///
/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column
/// with the given periodic values. Due to the periodic nature of the values, storing, binding of
/// an argument and evaluating the said multi-linear extension can be all done linearly in the size
/// of the smallest cycle defining the periodic values. Hence we only store the values of this
/// smallest cycle. The cycle is assumed throughout to be a power of 2.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct PeriodicTable<E: FieldElement> {
pub table: Vec<Vec<E>>,
}

impl<E> PeriodicTable<E>
where
E: FieldElement,
{
pub fn new(table: Vec<Vec<E::BaseField>>) -> Self {
let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect();

Self { table }
}

pub fn num_columns(&self) -> usize {
self.table.len()
}

pub fn table(&self) -> &[Vec<E>] {
&self.table
}

pub fn fill_periodic_values_at(&self, row: usize, values: &mut [E]) {
self.table
.iter()
.zip(values.iter_mut())
.for_each(|(col, value)| *value = col[row % col.len()])
}

pub fn bind_least_significant_variable(&mut self, round_challenge: E) {
for col in self.table.iter_mut() {
if col.len() > 1 {
let num_evals = col.len() >> 1;
for i in 0..num_evals {
col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]);
}
col.truncate(num_evals)
}
}
}
}
2 changes: 1 addition & 1 deletion air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use logup_gkr::PhantomLogUpGkrEval;
pub use logup_gkr::{
LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame,
LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator,
LogUpGkrOracle,
LogUpGkrOracle, PeriodicTable,
};

mod coefficients;
Expand Down
4 changes: 2 additions & 2 deletions air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub use air::{
DeepCompositionCoefficients, EvaluationFrame, GkrData,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo,
TransitionConstraintDegree, TransitionConstraints,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable,
TraceInfo, TransitionConstraintDegree, TransitionConstraints,
};
7 changes: 5 additions & 2 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,21 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
log_up_randomness: &[E],
) -> CircuitLayer<E> {
let num_fractions = evaluator.get_num_fractions();
let periodic_values = evaluator.build_periodic_values();

let mut input_layer_wires =
Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions);
let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols());

let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()];
let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()];
let mut numerators = vec![E::ZERO; num_fractions];
let mut denominators = vec![E::ZERO; num_fractions];
for i in 0..main_trace.main_segment().num_rows() {
let wires_from_trace_row = {
main_trace.read_main_frame(i, &mut main_frame);

evaluator.build_query(&main_frame, &[], &mut query);
periodic_values.fill_periodic_values_at(i, &mut periodic_values_row);
evaluator.build_query(&main_frame, &periodic_values_row, &mut query);

evaluator.evaluate_query(
&query,
Expand Down
24 changes: 17 additions & 7 deletions prover/src/logup_gkr/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use air::{LogUpGkrEvaluator, LogUpGkrOracle};
use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use crypto::{ElementHasher, RandomCoin};
use math::FieldElement;
use sumcheck::{
Expand Down Expand Up @@ -75,11 +75,18 @@ pub fn prove_gkr<E: FieldElement>(
let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?;

// build the MLEs of the relevant main trace columns
let main_trace_mls =
let (main_trace_mls, periodic_table) =
build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?;

let final_layer_proof =
prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?;
// run the GKR prover for the input layer
let final_layer_proof = prove_input_layer(
evaluator,
logup_randomness,
main_trace_mls,
periodic_table,
gkr_claim,
public_coin,
)?;

Ok(GkrCircuitProof {
circuit_outputs: CircuitOutput { numerators, denominators },
Expand All @@ -97,6 +104,7 @@ fn prove_input_layer<
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: Vec<E>,
multi_linear_ext_polys: Vec<MultiLinearPoly<E>>,
periodic_table: PeriodicTable<E>,
claim: GkrClaim<E>,
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
transcript: &mut C,
) -> Result<FinalLayerProof<E>, GkrProverError> {
Expand All @@ -114,6 +122,7 @@ fn prove_input_layer<
r_batch,
log_up_randomness,
multi_linear_ext_polys,
periodic_table,
transcript,
)?;

Expand All @@ -125,8 +134,9 @@ fn prove_input_layer<
fn build_mls_from_main_trace_segment<E: FieldElement>(
oracles: &[LogUpGkrOracle<E::BaseField>],
main_trace: &ColMatrix<<E as FieldElement>::BaseField>,
) -> Result<Vec<MultiLinearPoly<E>>, GkrProverError> {
) -> Result<(Vec<MultiLinearPoly<E>>, PeriodicTable<E>), GkrProverError> {
let mut mls = vec![];
let mut periodic_values = vec![];

for oracle in oracles {
match oracle {
Expand All @@ -146,10 +156,10 @@ fn build_mls_from_main_trace_segment<E: FieldElement>(
let ml = MultiLinearPoly::from_evaluations(values);
mls.push(ml)
},
LogUpGkrOracle::PeriodicValue(_) => unimplemented!(),
LogUpGkrOracle::PeriodicValue(values) => periodic_values.push(values.to_vec()),
};
}
Ok(mls)
Ok((mls, PeriodicTable::new(periodic_values)))
}

/// Proves all GKR layers except for input layer.
Expand Down
15 changes: 12 additions & 3 deletions sumcheck/benches/sum_check_high_degree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::{marker::PhantomData, time::Duration};

use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle};
use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin};
use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField};
Expand Down Expand Up @@ -37,7 +37,7 @@ fn sum_check_high_degree(c: &mut Criterion) {
)
},
|(
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)),
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table),
evaluator,
logup_randomness,
transcript,
Expand All @@ -52,6 +52,7 @@ fn sum_check_high_degree(c: &mut Criterion) {
r_batch,
logup_randomness,
mls,
periodic_table,
&mut transcript,
)
},
Expand All @@ -76,21 +77,29 @@ fn setup_sum_check<E: FieldElement>(
MultiLinearPoly<E>,
MultiLinearPoly<E>,
),
PeriodicTable<E>,
) {
let n = 1 << log_size;
let table = MultiLinearPoly::from_evaluations(rand_vector(n));
let multiplicity = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n));
let periodic_table = PeriodicTable::default();

// this will not generate the correct claim with overwhelming probability but should be fine
// for benchmarking
let rand_pt: Vec<E> = rand_vector(log_size + 2);
let r_batch: E = rand_value();
let claim: E = rand_value();

(claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2))
(
claim,
r_batch,
rand_pt,
(table, multiplicity, values_0, values_1, values_2),
periodic_table,
)
}

#[derive(Clone, Default)]
Expand Down
60 changes: 44 additions & 16 deletions sumcheck/src/prover/high_degree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use alloc::vec::Vec;

use air::LogUpGkrEvaluator;
use air::{LogUpGkrEvaluator, PeriodicTable};
use crypto::{ElementHasher, RandomCoin};
use math::FieldElement;
#[cfg(feature = "concurrent")]
Expand Down Expand Up @@ -160,6 +160,7 @@ pub fn sum_check_prove_higher_degree<
r_sum_check: E,
log_up_randomness: Vec<E>,
mut mls: Vec<MultiLinearPoly<E>>,
mut periodic_table: PeriodicTable<E>,
coin: &mut impl RandomCoin<Hasher = H, BaseField = E::BaseField>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
coin: &mut impl RandomCoin<Hasher = H, BaseField = E::BaseField>,
mut periodic_table: PeriodicTable<E>,

We can pass this to sumcheck_round() as &mut periodic_table

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

) -> Result<SumCheckProof<E>, SumCheckProverError> {
let num_rounds = mls[0].num_variables();
Expand All @@ -176,8 +177,15 @@ pub fn sum_check_prove_higher_degree<
let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim };

// run the first round of the protocol
let round_poly_evals =
sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check);
let round_poly_evals = sumcheck_round(
&eq_mu,
evaluator,
&eq_nu,
&mls,
&mut periodic_table,
&log_up_randomness,
r_sum_check,
);
let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim);

// reseed with the s_0 polynomial
Expand All @@ -198,10 +206,20 @@ pub fn sum_check_prove_higher_degree<
.for_each(|ml| ml.bind_least_significant_variable(round_challenge));
eq_nu.bind_least_significant_variable(round_challenge);

// fold each periodic multi-linear using the round challenge
periodic_table.bind_least_significant_variable(round_challenge);

// run the i-th round of the protocol using the folded multi-linears for the new reduced
// claim. This basically computes the s_i polynomial.
let round_poly_evals =
sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check);
let round_poly_evals = sumcheck_round(
&eq_mu,
evaluator,
&eq_nu,
&mls,
&mut periodic_table,
&log_up_randomness,
r_sum_check,
);

// update the claim
current_round_claim = new_round_claim;
Expand Down Expand Up @@ -280,21 +298,23 @@ fn sumcheck_round<E: FieldElement>(
evaluator: &impl LogUpGkrEvaluator<BaseField = <E as FieldElement>::BaseField>,
eq_ml: &MultiLinearPoly<E>,
mls: &[MultiLinearPoly<E>],
periodic_table: &mut PeriodicTable<E>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be &mut PeriodicTable<E>? Doesn't seem like we are mutating the table in this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Fixed

log_up_randomness: &[E],
r_sum_check: E,
) -> CompressedUnivariatePolyEvals<E> {
let num_ml = mls.len();
let num_mls = mls.len();
let num_oracles = num_mls + periodic_table.num_columns();
let num_vars = mls[0].num_variables();
let num_rounds = num_vars - 1;

#[cfg(not(feature = "concurrent"))]
let evaluations = {
let mut evals_one = vec![E::ZERO; num_ml];
let mut evals_zero = vec![E::ZERO; num_ml];
let mut evals_x = vec![E::ZERO; num_ml];
let mut evals_one = vec![E::ZERO; num_oracles];
let mut evals_zero = vec![E::ZERO; num_oracles];
let mut evals_x = vec![E::ZERO; num_oracles];
let mut eq_x = E::ZERO;

let mut deltas = vec![E::ZERO; num_ml];
let mut deltas = vec![E::ZERO; num_oracles];
let mut eq_delta = E::ZERO;

let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()];
Expand All @@ -311,6 +331,10 @@ fn sumcheck_round<E: FieldElement>(
let eq_at_zero = eq_ml.evaluations()[2 * i];
let eq_at_one = eq_ml.evaluations()[2 * i + 1];

// add evaluation of periodic columns
periodic_table.fill_periodic_values_at(2 * i, &mut evals_zero[num_mls..]);
periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_one[num_mls..]);

// compute the evaluation at 1
evaluator.evaluate_query(
&evals_one,
Expand All @@ -327,7 +351,7 @@ fn sumcheck_round<E: FieldElement>(
);

// compute the evaluations at 2, ..., d_max points
for i in 0..num_ml {
for i in 0..num_oracles {
deltas[i] = evals_one[i] - evals_zero[i];
evals_x[i] = evals_one[i];
}
Expand Down Expand Up @@ -371,13 +395,13 @@ fn sumcheck_round<E: FieldElement>(
.fold(
|| {
(
vec![E::ZERO; num_ml],
vec![E::ZERO; num_ml],
vec![E::ZERO; num_ml],
vec![E::ZERO; num_oracles],
vec![E::ZERO; num_oracles],
vec![E::ZERO; num_oracles],
vec![E::ZERO; evaluator.max_degree()],
vec![E::ZERO; evaluator.get_num_fractions()],
vec![E::ZERO; evaluator.get_num_fractions()],
vec![E::ZERO; num_ml],
vec![E::ZERO; num_oracles],
)
},
|(
Expand All @@ -398,6 +422,10 @@ fn sumcheck_round<E: FieldElement>(
let eq_at_zero = eq_ml.evaluations()[2 * i];
let eq_at_one = eq_ml.evaluations()[2 * i + 1];

// add evaluation of periodic columns
periodic_table.fill_periodic_values_at(2 * i, &mut evals_zero[num_mls..]);
periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_one[num_mls..]);

// compute the evaluation at 1
evaluator.evaluate_query(
&evals_one,
Expand All @@ -414,7 +442,7 @@ fn sumcheck_round<E: FieldElement>(
);

// compute the evaluations at 2, ..., d_max points
for i in 0..num_ml {
for i in 0..num_oracles {
deltas[i] = evals_one[i] - evals_zero[i];
evals_x[i] = evals_one[i];
}
Expand Down
Loading
Loading