Skip to content

Commit

Permalink
feat: move block hash table to LogUp-GKR
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Sep 23, 2024
1 parent 5436359 commit a18f2b4
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 632 deletions.
170 changes: 161 additions & 9 deletions air/src/logup_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use winter_air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle};
use crate::{
decoder::{
DECODER_ADDR_COL_IDX, DECODER_GROUP_COUNT_COL_IDX, DECODER_HASHER_STATE_OFFSET,
DECODER_IN_SPAN_COL_IDX, DECODER_OP_BATCH_FLAGS_OFFSET, DECODER_OP_BITS_EXTRA_COLS_OFFSET,
DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET,
DECODER_IN_SPAN_COL_IDX, DECODER_IS_LOOP_BODY_FLAG_COL_IDX, DECODER_OP_BATCH_FLAGS_OFFSET,
DECODER_OP_BITS_EXTRA_COLS_OFFSET, DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET,
},
trace::{
chiplets::{MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX},
Expand All @@ -29,8 +29,13 @@ pub const RANGE_CHECKER_NUM_RAND_VALUES: usize = 1;
pub const OP_GROUP_TABLE_RAND_VALUES_OFFSET: usize = 1;
pub const OP_GROUP_TABLE_NUM_RAND_VALUES: usize = 4;

pub const TOTAL_NUM_RAND_VALUES: usize =
pub const BLOCK_HASH_TABLE_RAND_VALUES_OFFSET: usize =
OP_GROUP_TABLE_RAND_VALUES_OFFSET + OP_GROUP_TABLE_NUM_RAND_VALUES;
pub const BLOCK_HASH_TABLE_NUM_RAND_VALUES: usize = 8;

pub const TOTAL_NUM_RAND_VALUES: usize =
BLOCK_HASH_TABLE_RAND_VALUES_OFFSET + BLOCK_HASH_TABLE_NUM_RAND_VALUES;

// Fractions

pub const RANGE_CHECKER_FRACTIONS_OFFSET: usize = 0;
Expand All @@ -40,8 +45,12 @@ pub const OP_GROUP_TABLE_FRACTIONS_OFFSET: usize =
RANGE_CHECKER_FRACTIONS_OFFSET + RANGE_CHECKER_NUM_FRACTIONS;
pub const OP_GROUP_TABLE_NUM_FRACTIONS: usize = 12;

pub const PADDING_FRACTIONS_OFFSET: usize =
pub const BLOCK_HASH_TABLE_FRACTIONS_OFFSET: usize =
OP_GROUP_TABLE_FRACTIONS_OFFSET + OP_GROUP_TABLE_NUM_FRACTIONS;
pub const BLOCK_HASH_TABLE_NUM_FRACTIONS: usize = 8;

pub const PADDING_FRACTIONS_OFFSET: usize =
BLOCK_HASH_TABLE_FRACTIONS_OFFSET + BLOCK_HASH_TABLE_NUM_FRACTIONS;
pub const PADDING_NUM_FRACTIONS: usize = TOTAL_NUM_FRACTIONS - PADDING_FRACTIONS_OFFSET;

pub const TOTAL_NUM_FRACTIONS: usize = 32;
Expand Down Expand Up @@ -118,6 +127,7 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval<Felt> {
let query_next = &query[TRACE_WIDTH..];

let op_flags_current = LogUpOpFlags::new(query_current);
let op_flags_next = LogUpOpFlags::new(query_next);

range_checker(
query_current,
Expand All @@ -134,23 +144,41 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval<Felt> {
&mut numerator[range(OP_GROUP_TABLE_FRACTIONS_OFFSET, OP_GROUP_TABLE_NUM_FRACTIONS)],
&mut denominator[range(OP_GROUP_TABLE_FRACTIONS_OFFSET, OP_GROUP_TABLE_NUM_FRACTIONS)],
);
block_hash_table(
query_current,
query_next,
&op_flags_current,
&op_flags_next,
&rand_values
[range(BLOCK_HASH_TABLE_RAND_VALUES_OFFSET, BLOCK_HASH_TABLE_NUM_RAND_VALUES)],
&mut numerator
[range(BLOCK_HASH_TABLE_FRACTIONS_OFFSET, BLOCK_HASH_TABLE_NUM_FRACTIONS)],
&mut denominator
[range(BLOCK_HASH_TABLE_FRACTIONS_OFFSET, BLOCK_HASH_TABLE_NUM_FRACTIONS)],
);
padding(
&mut numerator[range(PADDING_FRACTIONS_OFFSET, PADDING_NUM_FRACTIONS)],
&mut denominator[range(PADDING_FRACTIONS_OFFSET, PADDING_NUM_FRACTIONS)],
);
}

fn compute_claim<E>(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E
fn compute_claim<E>(&self, inputs: &Self::PublicInputs, rand_values: &[E]) -> E
where
E: FieldElement<BaseField = Self::BaseField>,
{
E::ZERO
// block hash table
let block_hash_table_claim = {
let alphas = &rand_values
[range(BLOCK_HASH_TABLE_RAND_VALUES_OFFSET, BLOCK_HASH_TABLE_NUM_RAND_VALUES)];
let program_hash = inputs.program_info.program_hash();

-(alphas[0] + inner_product(&alphas[2..6], program_hash.as_elements()))
};

block_hash_table_claim.inv()
}
}

// HELPERS
// -----------------------------------------------------------------------------------------------

/// TODO(plafer): docs
#[inline(always)]
fn range_checker<F, E>(
Expand Down Expand Up @@ -289,6 +317,78 @@ fn op_group_table<F, E>(
denominator[11] = v7;
}

#[inline(always)]
fn block_hash_table<F, E>(
query_current: &[F],
query_next: &[F],
op_flags_current: &LogUpOpFlags<F>,
op_flags_next: &LogUpOpFlags<F>,
alphas: &[E],
numerator: &mut [E],
denominator: &mut [E],
) where
F: FieldElement,
E: FieldElement + ExtensionOf<F>,
{
let stack_0 = query_current[STACK_TRACE_OFFSET + STACK_TOP_OFFSET];

// numerators
let f_join: E = op_flags_current.f_join().into();

numerator[0] = op_flags_current.f_end().into();
numerator[1] = f_join;
numerator[2] = f_join;
numerator[3] = op_flags_current.f_split().into();
numerator[4] = (op_flags_current.f_loop() * stack_0).into();
numerator[5] = op_flags_current.f_repeat().into();
numerator[6] = op_flags_current.f_dyn().into();
// TODO(plafer): update docs (no mention of call or syscall)
numerator[7] = (op_flags_current.f_call() + op_flags_current.f_syscall()).into();

// denominators
let addr_next = query_next[DECODER_ADDR_COL_IDX];
let h0_to_3 = &query_current[range(DECODER_HASHER_STATE_OFFSET, 4)];
let h4_to_7 = &query_current[range(DECODER_HASHER_STATE_OFFSET + 4, 4)];
let stack_1 = query_current[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + 1];
let stack_2 = query_current[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + 2];
let stack_3 = query_current[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + 3];
// TODO(plafer): update docs (this is h4 in docs)
let f_is_loop_body = query_current[DECODER_IS_LOOP_BODY_FLAG_COL_IDX];
let child1 = alphas[0] + alphas[1].mul_base(addr_next) + inner_product(&alphas[2..6], h0_to_3);
let child2 = alphas[0] + alphas[1].mul_base(addr_next) + inner_product(&alphas[2..6], h4_to_7);

let u_end = {
// TODO(plafer): update docs (f_halt missing)
let is_first_child =
F::ONE - (op_flags_next.f_end() + op_flags_next.f_repeat() + op_flags_next.f_halt());

// TODO(plafer): Double check addr_next; docs inconsistent with BlockHashTableRow
alphas[0]
+ alphas[1].mul_base(addr_next)
+ inner_product(&alphas[2..6], h0_to_3)
+ alphas[6].mul_base(is_first_child)
+ alphas[7].mul_base(f_is_loop_body)
};

let v_join_1 = child1 + alphas[6];
let v_join_2 = child2;
let v_split = child1.mul_base(stack_0) + child2.mul_base(F::ONE - stack_0);
let v_loop = child1 + alphas[7];
let v_repeat = child1 + alphas[7];
let v_dyn = alphas[0]
+ alphas[1].mul_base(addr_next)
+ inner_product(&alphas[2..6], &[stack_3, stack_2, stack_1, stack_0]);

denominator[0] = -u_end;
denominator[1] = v_join_1;
denominator[2] = v_join_2;
denominator[3] = v_split;
denominator[4] = v_loop;
denominator[5] = v_repeat;
denominator[6] = v_dyn;
denominator[7] = child1;
}

/// TODO(plafer): docs
fn padding<E>(numerator: &mut [E], denominator: &mut [E])
where
Expand All @@ -308,6 +408,7 @@ struct LogUpOpFlags<F: FieldElement> {
b5: F,
b6: F,
e0: F,
e1: F,
}

impl<F: FieldElement> LogUpOpFlags<F> {
Expand All @@ -321,6 +422,7 @@ impl<F: FieldElement> LogUpOpFlags<F> {
b5: query[DECODER_OP_BITS_OFFSET + 5],
b6: query[DECODER_OP_BITS_OFFSET + 6],
e0: query[DECODER_OP_BITS_EXTRA_COLS_OFFSET],
e1: query[DECODER_OP_BITS_EXTRA_COLS_OFFSET + 1],
}
}

Expand Down Expand Up @@ -349,4 +451,54 @@ impl<F: FieldElement> LogUpOpFlags<F> {
pub fn f_range_check(&self) -> F {
(F::ONE - self.b4) * (F::ONE - self.b5) * self.b6
}

pub fn f_join(&self) -> F {
self.e0 * (F::ONE - self.b3) * self.b2 * self.b1 * self.b0
}

pub fn f_split(&self) -> F {
self.e0 * (F::ONE - self.b3) * self.b2 * (F::ONE - self.b1) * (F::ONE - self.b0)
}

pub fn f_loop(&self) -> F {
self.e0 * (F::ONE - self.b3) * self.b2 * (F::ONE - self.b1) * self.b0
}

pub fn f_dyn(&self) -> F {
self.e0 * self.b3 * (F::ONE - self.b2) * (F::ONE - self.b1) * (F::ONE - self.b0)
}

pub fn f_repeat(&self) -> F {
self.e1 * self.b4 * (F::ONE - self.b3) * self.b2
}

pub fn f_end(&self) -> F {
self.e1 * self.b4 * (F::ONE - self.b3) * (F::ONE - self.b2)
}

pub fn f_syscall(&self) -> F {
self.e1 * (F::ONE - self.b4) * self.b3 * (F::ONE - self.b2)
}

pub fn f_call(&self) -> F {
self.e1 * (F::ONE - self.b4) * self.b3 * self.b2
}

pub fn f_halt(&self) -> F {
self.e1 * self.b4 * self.b3 * self.b2
}
}

// HELPERS
// -----------------------------------------------------------------------------------------------

fn inner_product<F, E>(alphas: &[E], eles: &[F]) -> E
where
F: FieldElement,
E: FieldElement + ExtensionOf<F>,
{
alphas
.iter()
.zip(eles.iter())
.fold(E::ZERO, |acc, (alpha, ele)| acc + alpha.mul_base(*ele))
}
5 changes: 2 additions & 3 deletions air/src/trace/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ pub const IS_SYSCALL_FLAG_COL_IDX: usize = HASHER_STATE_RANGE.start + 7;
/// Running product column representing block stack table.
pub const P1_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET;

/// Running product column representing block hash table
pub const P2_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 1;

// --- GLOBALLY-INDEXED DECODER COLUMN ACCESSORS --------------------------------------------------
pub const DECODER_ADDR_COL_IDX: usize = super::DECODER_TRACE_OFFSET + ADDR_COL_IDX;
pub const DECODER_OP_BITS_OFFSET: usize = super::DECODER_TRACE_OFFSET + OP_BITS_OFFSET;
Expand All @@ -109,3 +106,5 @@ pub const DECODER_OP_BATCH_FLAGS_OFFSET: usize =
super::DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET;
pub const DECODER_OP_BITS_EXTRA_COLS_OFFSET: usize =
super::DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET;
pub const DECODER_IS_LOOP_BODY_FLAG_COL_IDX: usize =
super::DECODER_TRACE_OFFSET + IS_LOOP_BODY_FLAG_COL_IDX;
4 changes: 2 additions & 2 deletions air/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ pub const TRACE_WIDTH: usize = CHIPLETS_OFFSET + CHIPLETS_WIDTH;
// ------------------------------------------------------------------------------------------------

// decoder stack hasher chiplets
// (2 columns) (1 column) (1 column) (1 column)
// (1 columns) (1 column) (1 column) (1 column)
// ├───────────────┴──────────────┴─────────────────┴───────────────┤

// Decoder auxiliary columns
pub const DECODER_AUX_TRACE_OFFSET: usize = 0;
pub const DECODER_AUX_TRACE_WIDTH: usize = 2;
pub const DECODER_AUX_TRACE_WIDTH: usize = 1;
pub const DECODER_AUX_TRACE_RANGE: Range<usize> =
range(DECODER_AUX_TRACE_OFFSET, DECODER_AUX_TRACE_WIDTH);

Expand Down
Loading

0 comments on commit a18f2b4

Please sign in to comment.