From aa5b90b7a0a627c8fdabdd047fcea42763d8f58d Mon Sep 17 00:00:00 2001
From: tohrnii <100405913+tohrnii@users.noreply.github.com>
Date: Thu, 13 Oct 2022 11:52:49 +0000
Subject: [PATCH 1/5] feat(core): define capacity_col_range and rate_col_range
 for hasher

---
 core/src/chiplets/hasher.rs | 16 ++++++++++++++--
 core/src/chiplets/mod.rs    | 11 +++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/core/src/chiplets/hasher.rs b/core/src/chiplets/hasher.rs
index 0de6ec15c0..318f3fc17e 100644
--- a/core/src/chiplets/hasher.rs
+++ b/core/src/chiplets/hasher.rs
@@ -37,11 +37,23 @@ pub const ROW_COL_IDX: usize = NUM_SELECTORS;
 /// The hasher state portion of the execution trace, located in columns 4 .. 16.
 pub const STATE_COL_RANGE: Range<usize> = create_range(ROW_COL_IDX + 1, STATE_WIDTH);
 
+/// Number of field elements in the capacity portion of the hasher's state.
+pub const CAPACITY_LEN: usize = STATE_WIDTH - RATE_LEN;
+
+/// The capacity portion of the hasher state in the execution trace, located in columns 4 .. 8.
+pub const CAPACITY_COL_RANGE: Range<usize> = Range {
+    start: STATE_COL_RANGE.start,
+    end: STATE_COL_RANGE.start + CAPACITY_LEN,
+};
+
 /// Number of field elements in the rate portion of the hasher's state.
 pub const RATE_LEN: usize = 8;
 
-/// Number of field elements in the capacity portion of the hasher's state.
-pub const CAPACITY_LEN: usize = STATE_WIDTH - RATE_LEN;
+/// The rate portion of the hasher state in the execution trace, located in columns 8 .. 16.
+pub const RATE_COL_RANGE: Range<usize> = Range {
+    start: CAPACITY_COL_RANGE.end,
+    end: CAPACITY_COL_RANGE.end + RATE_LEN,
+};
 
 // The length of the output portion of the hash state.
 pub const DIGEST_LEN: usize = 4;
diff --git a/core/src/chiplets/mod.rs b/core/src/chiplets/mod.rs
index d0d3c51f75..7ec3aba4ab 100644
--- a/core/src/chiplets/mod.rs
+++ b/core/src/chiplets/mod.rs
@@ -36,6 +36,17 @@ pub const HASHER_STATE_COL_RANGE: Range<usize> = Range {
     start: HASHER_TRACE_OFFSET + hasher::STATE_COL_RANGE.start,
     end: HASHER_TRACE_OFFSET + hasher::STATE_COL_RANGE.end,
 };
+/// The range of columns in the execution trace that contains the capacity portion of the hasher
+/// state.
+pub const HASHER_CAPACITY_COL_RANGE: Range<usize> = Range {
+    start: HASHER_TRACE_OFFSET + hasher::CAPACITY_COL_RANGE.start,
+    end: HASHER_TRACE_OFFSET + hasher::CAPACITY_COL_RANGE.end,
+};
+/// The range of columns in the execution trace that contains the rate portion of the hasher state.
+pub const HASHER_RATE_COL_RANGE: Range<usize> = Range {
+    start: HASHER_TRACE_OFFSET + hasher::RATE_COL_RANGE.start,
+    end: HASHER_TRACE_OFFSET + hasher::RATE_COL_RANGE.end,
+};
 /// The index of the hasher's node index column in the execution trace.
pub const HASHER_NODE_INDEX_COL_IDX: usize = HASHER_STATE_COL_RANGE.end; From bdc6eca62f961f3698c8bf3675becd4546daa61f Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Thu, 13 Oct 2022 11:53:54 +0000 Subject: [PATCH 2/5] refactor(proc): Remove HasherState from HasherLookup --- processor/src/chiplets/bitwise/mod.rs | 8 +- processor/src/chiplets/bus/aux_trace.rs | 10 +- processor/src/chiplets/bus/mod.rs | 15 ++- processor/src/chiplets/hasher/aux_trace.rs | 12 ++- processor/src/chiplets/hasher/lookups.rs | 105 +++++++++++++++------ processor/src/chiplets/hasher/trace.rs | 1 - processor/src/chiplets/memory/mod.rs | 3 +- processor/src/decoder/aux_hints.rs | 19 +++- processor/src/lib.rs | 2 + processor/src/range/aux_trace.rs | 12 ++- processor/src/range/request.rs | 34 +++++-- processor/src/stack/aux_trace.rs | 2 +- processor/src/stack/overflow.rs | 7 +- processor/src/trace/decoder/mod.rs | 16 +++- processor/src/trace/decoder/tests.rs | 105 ++++++++++++--------- processor/src/trace/tests/hasher.rs | 6 +- processor/src/trace/tests/stack.rs | 8 +- processor/src/trace/utils.rs | 15 ++- 18 files changed, 257 insertions(+), 123 deletions(-) diff --git a/processor/src/chiplets/bitwise/mod.rs b/processor/src/chiplets/bitwise/mod.rs index 86f936320e..6bd5327aaf 100644 --- a/processor/src/chiplets/bitwise/mod.rs +++ b/processor/src/chiplets/bitwise/mod.rs @@ -2,7 +2,7 @@ use super::{ ChipletsBus, ExecutionError, Felt, FieldElement, LookupTableRow, StarkField, TraceFragment, Vec, BITWISE_AND_LABEL, BITWISE_XOR_LABEL, }; -use crate::utils::get_trace_len; +use crate::{utils::get_trace_len, Matrix}; use vm_core::chiplets::bitwise::{ A_COL_IDX, A_COL_RANGE, BITWISE_AND, BITWISE_XOR, B_COL_IDX, B_COL_RANGE, OP_CYCLE_LEN, OUTPUT_COL_IDX, PREV_OUTPUT_COL_IDX, TRACE_WIDTH, @@ -261,7 +261,11 @@ impl BitwiseLookup { impl LookupTableRow for BitwiseLookup { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 5 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { alphas[0] + alphas[1].mul_base(self.op_id) + alphas[2].mul_base(self.a) diff --git a/processor/src/chiplets/bus/aux_trace.rs b/processor/src/chiplets/bus/aux_trace.rs index b24e507e2f..72669e7530 100644 --- a/processor/src/chiplets/bus/aux_trace.rs +++ b/processor/src/chiplets/bus/aux_trace.rs @@ -1,9 +1,8 @@ use super::{ChipletsLookup, ChipletsLookupRow, Felt, FieldElement}; use crate::{ trace::{build_lookup_table_row_values, AuxColumnBuilder, LookupTableRow}, - Vec, + Matrix, Vec, }; -use winterfell::Matrix; // AUXILIARY TRACE BUILDER // ================================================================================================ @@ -76,7 +75,7 @@ impl AuxColumnBuilder for AuxTraceBuilde /// requests. Since responses are grouped by chiplet, the operation order for the requests and /// responses will be permutations of each other rather than sharing the same order. Therefore, /// the `row_values` and `inv_row_values` must be built separately. 
- fn build_row_values(&self, _main_trace: &Matrix, alphas: &[E]) -> (Vec, Vec) + fn build_row_values(&self, main_trace: &Matrix, alphas: &[E]) -> (Vec, Vec) where E: FieldElement, { @@ -84,10 +83,11 @@ impl AuxColumnBuilder for AuxTraceBuilde let row_values = self .response_rows .iter() - .map(|response| response.to_value(alphas)) + .map(|response| response.to_value(main_trace, alphas)) .collect(); // get the inverse values from the request rows - let (_, inv_row_values) = build_lookup_table_row_values(&self.request_rows, alphas); + let (_, inv_row_values) = + build_lookup_table_row_values(&self.request_rows, main_trace, alphas); (row_values, inv_row_values) } diff --git a/processor/src/chiplets/bus/mod.rs b/processor/src/chiplets/bus/mod.rs index 85a13f222a..e345c50033 100644 --- a/processor/src/chiplets/bus/mod.rs +++ b/processor/src/chiplets/bus/mod.rs @@ -2,6 +2,7 @@ use super::{ hasher::HasherLookup, BTreeMap, BitwiseLookup, Felt, FieldElement, LookupTableRow, MemoryLookup, Vec, }; +use crate::Matrix; mod aux_trace; pub use aux_trace::AuxTraceBuilder; @@ -210,14 +211,18 @@ pub(super) enum ChipletsLookupRow { } impl LookupTableRow for ChipletsLookupRow { - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + main_trace: &Matrix, + alphas: &[E], + ) -> E { match self { ChipletsLookupRow::HasherMulti(lookups) => lookups .iter() - .fold(E::ONE, |acc, row| acc * row.to_value(alphas)), - ChipletsLookupRow::Hasher(row) => row.to_value(alphas), - ChipletsLookupRow::Bitwise(row) => row.to_value(alphas), - ChipletsLookupRow::Memory(row) => row.to_value(alphas), + .fold(E::ONE, |acc, row| acc * row.to_value(main_trace, alphas)), + ChipletsLookupRow::Hasher(row) => row.to_value(main_trace, alphas), + ChipletsLookupRow::Bitwise(row) => row.to_value(main_trace, alphas), + ChipletsLookupRow::Memory(row) => row.to_value(main_trace, alphas), } } } diff --git a/processor/src/chiplets/hasher/aux_trace.rs b/processor/src/chiplets/hasher/aux_trace.rs index f54753fb4d..79e8d36454 100644 --- a/processor/src/chiplets/hasher/aux_trace.rs +++ b/processor/src/chiplets/hasher/aux_trace.rs @@ -1,6 +1,8 @@ use super::{Felt, FieldElement, StarkField, Vec, Word}; -use crate::trace::{AuxColumnBuilder, LookupTableRow}; -use winterfell::Matrix; +use crate::{ + trace::{AuxColumnBuilder, LookupTableRow}, + Matrix, +}; // AUXILIARY TRACE BUILDER // ================================================================================================ @@ -121,7 +123,11 @@ impl SiblingTableRow { impl LookupTableRow for SiblingTableRow { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 6 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { // when the least significant bit of the index is 0, the sibling will be in the 3rd word // of the hasher state, and when the least significant bit is 1, it will be in the 2nd // word. 
we compute the value in this way to make constraint evaluation a bit easier since diff --git a/processor/src/chiplets/hasher/lookups.rs b/processor/src/chiplets/hasher/lookups.rs index 7ad1cdcebe..f59a24406c 100644 --- a/processor/src/chiplets/hasher/lookups.rs +++ b/processor/src/chiplets/hasher/lookups.rs @@ -1,7 +1,17 @@ -use super::{Felt, FieldElement, HasherState, LookupTableRow, StarkField}; -use vm_core::chiplets::hasher::{ - CAPACITY_LEN, DIGEST_RANGE, LINEAR_HASH_LABEL, MP_VERIFY_LABEL, MR_UPDATE_NEW_LABEL, - MR_UPDATE_OLD_LABEL, RETURN_HASH_LABEL, RETURN_STATE_LABEL, STATE_WIDTH, +use core::ops::Range; + +use super::{Felt, FieldElement, LookupTableRow, StarkField}; +use crate::Matrix; +use vm_core::{ + chiplets::{ + hasher::{ + CAPACITY_LEN, DIGEST_LEN, DIGEST_RANGE, LINEAR_HASH_LABEL, MP_VERIFY_LABEL, + MR_UPDATE_NEW_LABEL, MR_UPDATE_OLD_LABEL, RATE_LEN, RETURN_HASH_LABEL, + RETURN_STATE_LABEL, STATE_WIDTH, + }, + HASHER_RATE_COL_RANGE, HASHER_STATE_COL_RANGE, + }, + utils::collections::Vec, }; // CONSTANTS @@ -17,8 +27,7 @@ const NUM_HEADER_ALPHAS: usize = 4; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum HasherLookupContext { Start, - // TODO: benchmark removing this and getting it from the trace instead - Absorb(HasherState), + Absorb, Return, } @@ -27,9 +36,6 @@ pub enum HasherLookupContext { pub struct HasherLookup { // unique label identifying the hash operation label: u8, - // TODO: benchmark removing this and getting it from the trace instead - // hasher state - state: HasherState, // row address in the Hasher table addr: u32, // node index @@ -40,16 +46,9 @@ pub struct HasherLookup { impl HasherLookup { /// Creates a new HasherLookup. - pub(super) fn new( - label: u8, - state: HasherState, - addr: u32, - index: Felt, - context: HasherLookupContext, - ) -> Self { + pub(super) fn new(label: u8, addr: u32, index: Felt, context: HasherLookupContext) -> Self { Self { label, - state, addr, index, context, @@ -81,7 +80,11 @@ impl HasherLookup { impl LookupTableRow for HasherLookup { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 16 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + main_trace: &Matrix, + alphas: &[E], + ) -> E { let header = self.get_header_value(&alphas[..NUM_HEADER_ALPHAS]); // computing the rest of the value requires an alpha for each element in the [HasherState] let alphas = &alphas[NUM_HEADER_ALPHAS..(NUM_HEADER_ALPHAS + STATE_WIDTH)]; @@ -90,8 +93,14 @@ impl LookupTableRow for HasherLookup { HasherLookupContext::Start => { if self.label == LINEAR_HASH_LABEL { // include the entire state when initializing a linear hash. - header + build_value(alphas, &self.state) + header + + build_value( + alphas, + &get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH), + ) } else { + let state = + &get_hasher_state_at(self.addr, main_trace, CAPACITY_LEN..STATE_WIDTH); assert!( self.label == MR_UPDATE_OLD_LABEL || self.label == MR_UPDATE_NEW_LABEL @@ -103,37 +112,45 @@ impl LookupTableRow for HasherLookup { // by the index bit will be the leaf node, and the value must be computed in the // same way in both cases. 
                let bit = (self.index.as_int() >> 1) & 1;
-                let left_word = build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE]);
-                let right_word =
-                    build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE.end..]);
+                let left_word = build_value(&alphas[DIGEST_RANGE], &state[..DIGEST_LEN]);
+                let right_word = build_value(&alphas[DIGEST_RANGE], &state[DIGEST_LEN..]);
 
                 header + E::from(1 - bit).mul(left_word) + E::from(bit).mul(right_word)
                 }
             }
-            HasherLookupContext::Absorb(next_state) => {
+            HasherLookupContext::Absorb => {
                 assert!(
                     self.label == LINEAR_HASH_LABEL,
                     "unrecognized hash operation"
                 );
+                let (curr_hasher_rate, next_hasher_rate) =
+                    get_adjacent_hasher_rates(self.addr, main_trace);
                 // build the value from the delta of the hasher state's rate before and after the
                 // absorption of new elements.
-                let next_state_value =
-                    build_value(&alphas[CAPACITY_LEN..], &next_state[CAPACITY_LEN..]);
-                let state_value = build_value(&alphas[CAPACITY_LEN..], &self.state[CAPACITY_LEN..]);
+                let next_state_value = build_value(&alphas[CAPACITY_LEN..], &next_hasher_rate);
+                let state_value = build_value(&alphas[CAPACITY_LEN..], &curr_hasher_rate);
 
                 header + next_state_value - state_value
             }
             HasherLookupContext::Return => {
                 if self.label == RETURN_STATE_LABEL {
                     // build the value from the result, which is the entire state
-                    header + build_value(alphas, &self.state)
+                    header
+                        + build_value(
+                            alphas,
+                            &get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH),
+                        )
                 } else {
                     assert!(
                         self.label == RETURN_HASH_LABEL,
                         "unrecognized hash operation"
                     );
                     // build the value from the result, which is the digest portion of the state
-                    header + build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE])
+                    header
+                        + build_value(
+                            &alphas[DIGEST_RANGE],
+                            &get_hasher_state_at(self.addr, main_trace, DIGEST_RANGE),
+                        )
                 }
             }
         }
     }
@@ -153,3 +170,35 @@ fn build_value<E: FieldElement<BaseField = Felt>>(alphas: &[E], elements: &[Felt]) -> E {
     }
     value
 }
+
+/// Returns the portion of the hasher state at the provided address that is within the provided
+/// column range.
+fn get_hasher_state_at(addr: u32, main_trace: &Matrix<Felt>, col_range: Range<usize>) -> Vec<Felt> {
+    let row = get_row_from_addr(addr);
+    col_range
+        .map(|col| main_trace.get(HASHER_STATE_COL_RANGE.start + col, row))
+        .collect::<Vec<Felt>>()
+}
+
+/// Returns the rate portion of the hasher state for the provided row and the next row.
+fn get_adjacent_hasher_rates(
+    addr: u32,
+    main_trace: &Matrix<Felt>,
+) -> ([Felt; RATE_LEN], [Felt; RATE_LEN]) {
+    let row = get_row_from_addr(addr);
+
+    let mut current = [Felt::ZERO; RATE_LEN];
+    let mut next = [Felt::ZERO; RATE_LEN];
+    for (idx, col_idx) in HASHER_RATE_COL_RANGE.enumerate() {
+        let column = main_trace.get_column(col_idx);
+        current[idx] = column[row];
+        next[idx] = column[row + 1];
+    }
+
+    (current, next)
+}
+
+/// Gets the row index from the specified row address.
+fn get_row_from_addr(addr: u32) -> usize { + addr as usize - 1 +} diff --git a/processor/src/chiplets/hasher/trace.rs b/processor/src/chiplets/hasher/trace.rs index 3a1b42a125..1698d889c2 100644 --- a/processor/src/chiplets/hasher/trace.rs +++ b/processor/src/chiplets/hasher/trace.rs @@ -122,7 +122,6 @@ impl HasherTrace { self.append_row(selectors, &hasher_state, node_index); } - // TODO: remove this after the state is removed from the HasherLookup struct // copy the latest hasher state to the provided state slice for (state_col, hasher_col) in state.iter_mut().zip(hasher_state.iter()) { *state_col = *hasher_col diff --git a/processor/src/chiplets/memory/mod.rs b/processor/src/chiplets/memory/mod.rs index 3c47892775..a8527534f6 100644 --- a/processor/src/chiplets/memory/mod.rs +++ b/processor/src/chiplets/memory/mod.rs @@ -5,6 +5,7 @@ use crate::{ range::RangeChecker, trace::LookupTableRow, utils::{split_element_u32_into_u16, split_u32_into_u16}, + Matrix, }; use vm_core::chiplets::memory::{ ADDR_COL_IDX, CLK_COL_IDX, CTX_COL_IDX, D0_COL_IDX, D1_COL_IDX, D_INV_COL_IDX, V_COL_RANGE, @@ -333,7 +334,7 @@ impl MemoryLookup { impl LookupTableRow for MemoryLookup { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 9 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>(&self, _main_trace: &Matrix, alphas: &[E]) -> E { let word_value = self .word .iter() diff --git a/processor/src/decoder/aux_hints.rs b/processor/src/decoder/aux_hints.rs index 0e9a7b5791..199737eec8 100644 --- a/processor/src/decoder/aux_hints.rs +++ b/processor/src/decoder/aux_hints.rs @@ -2,6 +2,7 @@ use super::{ super::trace::LookupTableRow, get_num_groups_in_next_batch, BlockInfo, Felt, FieldElement, StarkField, Vec, Word, ONE, ZERO, }; +use crate::Matrix; // AUXILIARY TRACE HINTS // ================================================================================================ @@ -343,7 +344,11 @@ impl BlockStackTableRow { impl LookupTableRow for BlockStackTableRow { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 8 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { let is_loop = if self.is_loop { ONE } else { ZERO }; alphas[0] + alphas[1].mul_base(self.block_id) @@ -421,7 +426,11 @@ impl BlockHashTableRow { impl LookupTableRow for BlockHashTableRow { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 8 alpha values. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { let is_first_child = if self.is_first_child { ONE } else { ZERO }; let is_loop_body = if self.is_loop_body { ONE } else { ZERO }; alphas[0] @@ -461,7 +470,11 @@ impl OpGroupTableRow { impl LookupTableRow for OpGroupTableRow { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 4 alpha values. 
- fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { alphas[0] + alphas[1].mul_base(self.batch_id) + alphas[2].mul_base(self.group_pos) diff --git a/processor/src/lib.rs b/processor/src/lib.rs index c9dde373bf..21480a1d88 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -19,6 +19,8 @@ use vm_core::{ ONE, RANGE_CHECK_TRACE_WIDTH, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, ZERO, }; +use winterfell::Matrix; + mod decorators; mod operations; diff --git a/processor/src/range/aux_trace.rs b/processor/src/range/aux_trace.rs index bea223b64f..e47aa81078 100644 --- a/processor/src/range/aux_trace.rs +++ b/processor/src/range/aux_trace.rs @@ -1,8 +1,10 @@ use vm_core::{range::V_COL_IDX, utils::uninit_vector}; -use winterfell::Matrix; use super::{BTreeMap, CycleRangeChecks, Felt, FieldElement, RangeCheckFlag, Vec}; -use crate::trace::{build_lookup_table_row_values, NUM_RAND_ROWS}; +use crate::{ + trace::{build_lookup_table_row_values, NUM_RAND_ROWS}, + Matrix, +}; // AUXILIARY TRACE BUILDER // ================================================================================================ @@ -143,7 +145,7 @@ impl AuxTraceBuilder { ) -> (Vec, Vec) { // compute the inverses for range checks performed by operations. let (_, inv_row_values) = - build_lookup_table_row_values(&self.cycle_range_check_values(), alphas); + build_lookup_table_row_values(&self.cycle_range_check_values(), main_trace, alphas); // allocate memory for the running product column and set the initial value to ONE let mut q = unsafe { uninit_vector(main_trace.num_rows()) }; @@ -175,7 +177,7 @@ impl AuxTraceBuilder { p1_idx = clk + 1; // update the intermediate values in the q column. - q[clk] = range_checks.to_stack_value(alphas); + q[clk] = range_checks.to_stack_value(alphas, main_trace); // include the operation lookups in the running product. p1[p1_idx] = p1[clk] * inv_row_values[rc_user_op_idx]; @@ -206,7 +208,7 @@ impl AuxTraceBuilder { if let Some(range_check) = self.cycle_range_checks.get(&(row_idx as u32)) { // update the intermediate values in the q column. - q[row_idx] = range_check.to_stack_value(alphas); + q[row_idx] = range_check.to_stack_value(alphas, main_trace); // include the operation lookups in the running product. p1[p1_idx] *= inv_row_values[rc_user_op_idx]; diff --git a/processor/src/range/request.rs b/processor/src/range/request.rs index 97b12e2750..5d8e7b1f4e 100644 --- a/processor/src/range/request.rs +++ b/processor/src/range/request.rs @@ -1,5 +1,5 @@ use super::{Felt, FieldElement}; -use crate::trace::LookupTableRow; +use crate::{trace::LookupTableRow, Matrix}; // PROCESSOR RANGE CHECKS // ================================================================================================ @@ -58,11 +58,15 @@ impl CycleRangeChecks { /// Reduces all range checks requested at this cycle by the Stack processor to a single field /// element in the field specified by E. - pub fn to_stack_value>(&self, alphas: &[E]) -> E { + pub fn to_stack_value>( + &self, + alphas: &[E], + main_trace: &Matrix, + ) -> E { let mut value = E::ONE; if let Some(stack_checks) = &self.stack { - value *= stack_checks.to_value(alphas); + value *= stack_checks.to_value(main_trace, alphas); } value @@ -70,11 +74,15 @@ impl CycleRangeChecks { /// Reduces all range checks requested at this cycle by the Memory processor to a single field /// element in the field specified by E. 
- fn to_mem_value>(&self, alphas: &[E]) -> E { + fn to_mem_value>( + &self, + alphas: &[E], + main_trace: &Matrix, + ) -> E { let mut value = E::ONE; if let Some(mem_checks) = &self.memory { - value = mem_checks.to_value(alphas); + value = mem_checks.to_value(main_trace, alphas) } value @@ -84,9 +92,13 @@ impl CycleRangeChecks { impl LookupTableRow for CycleRangeChecks { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 1 alpha value. Includes all values included at this cycle from all processors. - fn to_value>(&self, alphas: &[E]) -> E { - let stack_value = self.to_stack_value(alphas); - let mem_value = self.to_mem_value(alphas); + fn to_value>( + &self, + main_trace: &Matrix, + alphas: &[E], + ) -> E { + let stack_value = self.to_stack_value(alphas, main_trace); + let mem_value = self.to_mem_value(alphas, main_trace); if stack_value != E::ONE { stack_value * mem_value @@ -107,7 +119,11 @@ enum RangeCheckRequest { impl LookupTableRow for RangeCheckRequest { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 1 alpha value. - fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { let alpha: E = alphas[0]; let mut value = E::ONE; diff --git a/processor/src/stack/aux_trace.rs b/processor/src/stack/aux_trace.rs index 5d72fe46f0..00415055b1 100644 --- a/processor/src/stack/aux_trace.rs +++ b/processor/src/stack/aux_trace.rs @@ -1,7 +1,7 @@ use super::{ super::trace::AuxColumnBuilder, Felt, FieldElement, OverflowTableRow, OverflowTableUpdate, Vec, }; -use winterfell::Matrix; +use crate::Matrix; // AUXILIARY TRACE BUILDER // ================================================================================================ diff --git a/processor/src/stack/overflow.rs b/processor/src/stack/overflow.rs index 70e42e32c4..eda2afc056 100644 --- a/processor/src/stack/overflow.rs +++ b/processor/src/stack/overflow.rs @@ -1,6 +1,7 @@ use super::{ super::trace::LookupTableRow, AuxTraceBuilder, BTreeMap, Felt, FieldElement, Vec, ZERO, }; +use crate::Matrix; use vm_core::{utils::uninit_vector, StarkField}; // OVERFLOW TABLE @@ -273,7 +274,11 @@ impl OverflowTableRow { impl LookupTableRow for OverflowTableRow { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 4 alpha values. 
- fn to_value>(&self, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { alphas[0] + alphas[1].mul_base(self.clk) + alphas[2].mul_base(self.val) diff --git a/processor/src/trace/decoder/mod.rs b/processor/src/trace/decoder/mod.rs index 75406697e8..a153029199 100644 --- a/processor/src/trace/decoder/mod.rs +++ b/processor/src/trace/decoder/mod.rs @@ -22,7 +22,12 @@ pub fn build_aux_columns>( ) -> Vec> { let p1 = build_aux_col_p1(main_trace, aux_trace_hints, rand_elements); let p2 = build_aux_col_p2(main_trace, aux_trace_hints, rand_elements); - let p3 = build_aux_col_p3(main_trace.num_rows(), aux_trace_hints, rand_elements); + let p3 = build_aux_col_p3( + main_trace.num_rows(), + aux_trace_hints, + rand_elements, + main_trace, + ); vec![p1, p2, p3] } @@ -38,7 +43,8 @@ fn build_aux_col_p1>( ) -> Vec { // compute row values and their inverses for all rows that were added to the block stack table let table_rows = aux_trace_hints.block_stack_table_rows(); - let (row_values, inv_row_values) = build_lookup_table_row_values(table_rows, alphas); + let (row_values, inv_row_values) = + build_lookup_table_row_values(table_rows, main_trace, alphas); // allocate memory for the running product column and set the initial value to ONE let mut result = unsafe { uninit_vector(main_trace.num_rows()) }; @@ -121,7 +127,8 @@ fn build_aux_col_p2>( ) -> Vec { // compute row values and their inverses for all rows that were added to the block hash table let table_rows = aux_trace_hints.block_hash_table_rows(); - let (row_values, inv_row_values) = build_lookup_table_row_values(table_rows, alphas); + let (row_values, inv_row_values) = + build_lookup_table_row_values(table_rows, main_trace, alphas); // initialize memory for the running product column, and set the first value in the column to // the value of the first row (which represents an entry for the root block of the program) @@ -227,6 +234,7 @@ fn build_aux_col_p3>( trace_len: usize, aux_trace_hints: &AuxTraceHints, alphas: &[E], + main_trace: &Matrix, ) -> Vec { // allocate memory for the column and set the starting value to ONE let mut result = unsafe { uninit_vector(trace_len) }; @@ -234,7 +242,7 @@ fn build_aux_col_p3>( // compute row values and their inverses for all rows which were added to the op group table let (row_values, inv_row_values) = - build_lookup_table_row_values(aux_trace_hints.op_group_table_rows(), alphas); + build_lookup_table_row_values(aux_trace_hints.op_group_table_rows(), main_trace, alphas); // keep track of indexes into the list of op group table rows separately for inserted and // removed rows diff --git a/processor/src/trace/decoder/tests.rs b/processor/src/trace/decoder/tests.rs index 0cfbc7e28a..12a0f86f5a 100644 --- a/processor/src/trace/decoder/tests.rs +++ b/processor/src/trace/decoder/tests.rs @@ -27,8 +27,9 @@ fn decoder_p1_span_with_respan() { let p1 = aux_columns.get_column(P1_COL_IDX); let row_values = [ - BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&alphas), - BlockStackTableRow::new_test(Felt::new(9), ZERO, false).to_value(&alphas), + BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(Felt::new(9), ZERO, false) + .to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE @@ -75,9 +76,9 @@ fn decoder_p1_join() { let a_9 = Felt::new(9); let a_17 = Felt::new(17); let row_values = [ - BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&alphas), - BlockStackTableRow::new_test(a_9, 
ONE, false).to_value(&alphas), - BlockStackTableRow::new_test(a_17, ONE, false).to_value(&alphas), + BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_9, ONE, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_17, ONE, false).to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE @@ -134,8 +135,8 @@ fn decoder_p1_split() { let a_9 = Felt::new(9); let row_values = [ - BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&alphas), - BlockStackTableRow::new_test(a_9, ONE, false).to_value(&alphas), + BlockStackTableRow::new_test(ONE, ZERO, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_9, ONE, false).to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE @@ -187,13 +188,13 @@ fn decoder_p1_loop_with_repeat() { let a_41 = Felt::new(41); // address of the first SPAN block in the second iteration let a_49 = Felt::new(49); // address of the second SPAN block in the second iteration let row_values = [ - BlockStackTableRow::new_test(ONE, ZERO, true).to_value(&alphas), - BlockStackTableRow::new_test(a_9, ONE, false).to_value(&alphas), - BlockStackTableRow::new_test(a_17, a_9, false).to_value(&alphas), - BlockStackTableRow::new_test(a_25, a_9, false).to_value(&alphas), - BlockStackTableRow::new_test(a_33, ONE, false).to_value(&alphas), - BlockStackTableRow::new_test(a_41, a_33, false).to_value(&alphas), - BlockStackTableRow::new_test(a_49, a_33, false).to_value(&alphas), + BlockStackTableRow::new_test(ONE, ZERO, true).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_9, ONE, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_17, a_9, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_25, a_9, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_33, ONE, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_41, a_33, false).to_value(&trace.main_trace, &alphas), + BlockStackTableRow::new_test(a_49, a_33, false).to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE @@ -294,8 +295,10 @@ fn decoder_p2_span_with_respan() { let aux_columns = trace.build_aux_segment(&[], &alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); - let row_values = - [BlockHashTableRow::new_test(ZERO, span.hash().into(), false, false).to_value(&alphas)]; + let row_values = [ + BlockHashTableRow::new_test(ZERO, span.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + ]; // make sure the first entry is initialized to program hash let mut expected_value = row_values[0]; @@ -327,9 +330,12 @@ fn decoder_p2_join() { let p2 = aux_columns.get_column(P2_COL_IDX); let row_values = [ - BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).to_value(&alphas), - BlockHashTableRow::new_test(ONE, span1.hash().into(), true, false).to_value(&alphas), - BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false).to_value(&alphas), + BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(ONE, span1.hash().into(), true, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is initialized to program hash @@ -380,8 +386,10 @@ fn decoder_p2_split_true() { let p2 = 
aux_columns.get_column(P2_COL_IDX); let row_values = [ - BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).to_value(&alphas), - BlockHashTableRow::new_test(ONE, span1.hash().into(), false, false).to_value(&alphas), + BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(ONE, span1.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is initialized to program hash @@ -424,8 +432,10 @@ fn decoder_p2_split_false() { let p2 = aux_columns.get_column(P2_COL_IDX); let row_values = [ - BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).to_value(&alphas), - BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false).to_value(&alphas), + BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is initialized to program hash @@ -471,12 +481,18 @@ fn decoder_p2_loop_with_repeat() { let a_9 = Felt::new(9); // address of the JOIN block in the first iteration let a_33 = Felt::new(33); // address of the JOIN block in the second iteration let row_values = [ - BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).to_value(&alphas), - BlockHashTableRow::new_test(ONE, body.hash().into(), false, true).to_value(&alphas), - BlockHashTableRow::new_test(a_9, span1.hash().into(), true, false).to_value(&alphas), - BlockHashTableRow::new_test(a_9, span2.hash().into(), false, false).to_value(&alphas), - BlockHashTableRow::new_test(a_33, span1.hash().into(), true, false).to_value(&alphas), - BlockHashTableRow::new_test(a_33, span2.hash().into(), false, false).to_value(&alphas), + BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(ONE, body.hash().into(), false, true) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(a_9, span1.hash().into(), true, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(a_9, span2.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(a_33, span1.hash().into(), true, false) + .to_value(&trace.main_trace, &alphas), + BlockHashTableRow::new_test(a_33, span2.hash().into(), false, false) + .to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is initialized to program hash @@ -602,10 +618,12 @@ fn decoder_p3_trace_one_batch() { // make sure 3 groups were inserted at clock cycle 1; these entries are for the two immediate // values and the second operation group consisting of [SWAP, MUL, ADD] - let g1_value = OpGroupTableRow::new(ONE, Felt::new(3), ONE).to_value(&alphas); - let g2_value = OpGroupTableRow::new(ONE, Felt::new(2), Felt::new(2)).to_value(&alphas); - let g3_value = - OpGroupTableRow::new(ONE, Felt::new(1), build_op_group(&ops[9..])).to_value(&alphas); + let g1_value = + OpGroupTableRow::new(ONE, Felt::new(3), ONE).to_value(&trace.main_trace, &alphas); + let g2_value = + OpGroupTableRow::new(ONE, Felt::new(2), Felt::new(2)).to_value(&trace.main_trace, &alphas); + let g3_value = OpGroupTableRow::new(ONE, Felt::new(1), build_op_group(&ops[9..])) + .to_value(&trace.main_trace, &alphas); let expected_value = g1_value * g2_value * g3_value; assert_eq!(expected_value, p3[1]); @@ -656,13 +674,13 @@ fn 
decoder_p3_trace_two_batches() { // --- first batch ---------------------------------------------------------------------------- // make sure entries for 7 groups were inserted at clock cycle 1 let b0_values = [ - OpGroupTableRow::new(ONE, Felt::new(11), iv[0]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(10), iv[1]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(9), iv[2]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(8), iv[3]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(7), iv[4]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(6), iv[5]).to_value(&alphas), - OpGroupTableRow::new(ONE, Felt::new(5), iv[6]).to_value(&alphas), + OpGroupTableRow::new(ONE, Felt::new(11), iv[0]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(10), iv[1]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(9), iv[2]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(8), iv[3]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(7), iv[4]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(6), iv[5]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(ONE, Felt::new(5), iv[6]).to_value(&trace.main_trace, &alphas), ]; let mut expected_value: Felt = b0_values.iter().fold(ONE, |acc, &val| acc * val); assert_eq!(expected_value, p3[1]); @@ -685,9 +703,10 @@ fn decoder_p3_trace_two_batches() { let batch1_addr = ONE + Felt::new(8); let op_group3 = build_op_group(&[Operation::Drop; 2]); let b1_values = [ - OpGroupTableRow::new(batch1_addr, Felt::new(3), iv[7]).to_value(&alphas), - OpGroupTableRow::new(batch1_addr, Felt::new(2), iv[8]).to_value(&alphas), - OpGroupTableRow::new(batch1_addr, Felt::new(1), op_group3).to_value(&alphas), + OpGroupTableRow::new(batch1_addr, Felt::new(3), iv[7]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(batch1_addr, Felt::new(2), iv[8]).to_value(&trace.main_trace, &alphas), + OpGroupTableRow::new(batch1_addr, Felt::new(1), op_group3) + .to_value(&trace.main_trace, &alphas), ]; let mut expected_value: Felt = b1_values.iter().fold(ONE, |acc, &val| acc * val); assert_eq!(expected_value, p3[10]); diff --git a/processor/src/trace/tests/hasher.rs b/processor/src/trace/tests/hasher.rs index e617d8c36a..2e96dfb685 100644 --- a/processor/src/trace/tests/hasher.rs +++ b/processor/src/trace/tests/hasher.rs @@ -66,9 +66,9 @@ fn hasher_p1_mr_update() { let p1 = aux_columns.get_column(P1_COL_IDX); let row_values = [ - SiblingTableRow::new(Felt::new(index), path[0]).to_value(&alphas), - SiblingTableRow::new(Felt::new(index >> 1), path[1]).to_value(&alphas), - SiblingTableRow::new(Felt::new(index >> 2), path[2]).to_value(&alphas), + SiblingTableRow::new(Felt::new(index), path[0]).to_value(&trace.main_trace, &alphas), + SiblingTableRow::new(Felt::new(index >> 1), path[1]).to_value(&trace.main_trace, &alphas), + SiblingTableRow::new(Felt::new(index >> 2), path[2]).to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE diff --git a/processor/src/trace/tests/stack.rs b/processor/src/trace/tests/stack.rs index 2eb6b60117..69fac920fa 100644 --- a/processor/src/trace/tests/stack.rs +++ b/processor/src/trace/tests/stack.rs @@ -37,10 +37,10 @@ fn p1_trace() { let p1 = aux_columns.get_column(P1_COL_IDX); let row_values = [ - OverflowTableRow::new(2, ONE, ZERO).to_value(&alphas), - OverflowTableRow::new(3, TWO, TWO).to_value(&alphas), - OverflowTableRow::new(6, TWO, TWO).to_value(&alphas), - 
OverflowTableRow::new(10, ZERO, ZERO).to_value(&alphas), + OverflowTableRow::new(2, ONE, ZERO).to_value(&trace.main_trace, &alphas), + OverflowTableRow::new(3, TWO, TWO).to_value(&trace.main_trace, &alphas), + OverflowTableRow::new(6, TWO, TWO).to_value(&trace.main_trace, &alphas), + OverflowTableRow::new(10, ZERO, ZERO).to_value(&trace.main_trace, &alphas), ]; // make sure the first entry is ONE diff --git a/processor/src/trace/utils.rs b/processor/src/trace/utils.rs index 640b52e404..7a9713ffe6 100644 --- a/processor/src/trace/utils.rs +++ b/processor/src/trace/utils.rs @@ -74,7 +74,11 @@ impl<'a> TraceFragment<'a> { pub trait LookupTableRow { /// Returns a single element representing the row in the field defined by E. The value is /// computed using the provided random values. - fn to_value>(&self, rand_values: &[E]) -> E; + fn to_value>( + &self, + main_trace: &Matrix, + rand_values: &[E], + ) -> E; } /// Computes values as well as inverse value for all specified lookup table rows. @@ -85,7 +89,8 @@ pub trait LookupTableRow { /// computationally infeasible. pub fn build_lookup_table_row_values, R: LookupTableRow>( rows: &[R], - rand_values: &[E], + main_trace: &Matrix, + rand_values: &[E], ) -> (Vec, Vec) { let mut row_values = unsafe { uninit_vector(rows.len()) }; let mut inv_row_values = unsafe { uninit_vector(rows.len()) }; @@ -98,7 +103,7 @@ pub fn build_lookup_table_row_values, R: Looku .zip(inv_row_values.iter_mut()) { *inv_value = acc; - *value = row.to_value(rand_values); + *value = row.to_value(main_trace, rand_values); debug_assert_ne!(*value, E::ZERO, "row value cannot be ZERO"); acc *= *value; @@ -195,11 +200,11 @@ pub trait AuxColumnBuilder { /// Builds and returns row values and their inverses for all rows which were added to the /// lookup table managed by this column builder. - fn build_row_values(&self, _main_trace: &Matrix, alphas: &[E]) -> (Vec, Vec) + fn build_row_values(&self, main_trace: &Matrix, alphas: &[E]) -> (Vec, Vec) where E: FieldElement, { - build_lookup_table_row_values(self.get_table_rows(), alphas) + build_lookup_table_row_values(self.get_table_rows(), main_trace, alphas) } /// Returns the initial value in the auxiliary column. Default implementation of this method From 1154bf77ed0c9d2800df2eaf142cde624e251a52 Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Thu, 13 Oct 2022 11:54:15 +0000 Subject: [PATCH 3/5] refactor(proc): Refactor hash_span_block --- processor/src/chiplets/hasher/mod.rs | 136 ++++++++++++--------------- 1 file changed, 60 insertions(+), 76 deletions(-) diff --git a/processor/src/chiplets/hasher/mod.rs b/processor/src/chiplets/hasher/mod.rs index ccba986a6b..bf0a003b70 100644 --- a/processor/src/chiplets/hasher/mod.rs +++ b/processor/src/chiplets/hasher/mod.rs @@ -77,9 +77,7 @@ pub struct Hasher { trace: HasherTrace, aux_trace: AuxTraceBuilder, // TODO: Investigate optimization options, since these lookups are also stored in the bus. - // 1. HasherLookup can be lightened to reduce the cost by removing the state from it and looking - // it up in the execution trace when the lookup values are computed and to b_chip. - // 2. The Hasher could "provide" lookups immediately instead of storing them and providing them + // - The Hasher could "provide" lookups immediately instead of storing them and providing them // during `fill_trace`. // There are probably other options as well, so this should be investigated & benchmarked. 
lookups: Vec, @@ -103,13 +101,7 @@ impl Hasher { /// When starting a hash operation, it should be called before any rows are recorded in the /// trace. In all other cases, it should be called immediately after the corresponding row is /// appended to the trace, so that the address of the row is equal to the trace length. - fn append_lookup( - &mut self, - label: u8, - state: HasherState, - index: Felt, - context: HasherLookupContext, - ) { + fn append_lookup(&mut self, label: u8, index: Felt, context: HasherLookupContext) { let addr = match context { // when starting a new hash operation, lookups are added before the operation begins. HasherLookupContext::Start => self.trace.next_row_addr().as_int() as u32, @@ -118,7 +110,19 @@ impl Hasher { }; self.lookups - .push(HasherLookup::new(label, state, addr, index, context)); + .push(HasherLookup::new(label, addr, index, context)); + } + + /// Records a HasherLookup with the specified data at the specified row address. + fn append_lookup_at( + &mut self, + addr: usize, + label: u8, + index: Felt, + context: HasherLookupContext, + ) { + self.lookups + .push(HasherLookup::new(label, addr as u32, index, context)); } /// Returns the index at which the next lookup will be appended. @@ -150,14 +154,14 @@ impl Hasher { let init_lookup_idx = self.next_lookup_idx(); // add the lookup for the hash initialization. - self.append_lookup(LINEAR_HASH_LABEL, state, ZERO, HasherLookupContext::Start); + self.append_lookup(LINEAR_HASH_LABEL, ZERO, HasherLookupContext::Start); // perform the hash. self.trace .append_permutation(&mut state, LINEAR_HASH, RETURN_STATE); // add the lookup for the hash result. - self.append_lookup(RETURN_STATE_LABEL, state, ZERO, HasherLookupContext::Return); + self.append_lookup(RETURN_STATE_LABEL, ZERO, HasherLookupContext::Return); let lookups = self.get_last_lookups(init_lookup_idx); (addr, state, lookups) @@ -182,7 +186,7 @@ impl Hasher { let mut state = init_state_from_words(&h1, &h2); // add the lookup for the hash initialization. - self.append_lookup(LINEAR_HASH_LABEL, state, ZERO, HasherLookupContext::Start); + self.append_lookup(LINEAR_HASH_LABEL, ZERO, HasherLookupContext::Start); if let Some((start_row, end_row)) = self.get_memoized_trace(expected_hash) { // copy the trace of a block with same hash instead of building it again. @@ -196,7 +200,7 @@ impl Hasher { }; // add the lookup for the hash result. - self.append_lookup(RETURN_HASH_LABEL, state, ZERO, HasherLookupContext::Return); + self.append_lookup(RETURN_HASH_LABEL, ZERO, HasherLookupContext::Return); let result = get_digest(&state); let lookups = self.get_last_lookups(init_lookup_idx); @@ -212,8 +216,6 @@ impl Hasher { /// The returned tuple also contains the row address of the execution trace at which the hash /// computation started and the lookups required to verify the computation so that the correct /// requests can be sent by the caller to the Chiplets Bus. - /// TODO: Refactor to require fewer is_memoized checks. This can be done if the intermediary - /// states don't need to be passed to the lookup. pub(super) fn hash_span_block( &mut self, op_batches: &[OpBatch], @@ -240,7 +242,7 @@ impl Hasher { let mut state = init_state(op_batches[0].groups(), num_op_groups); // add the lookup for the hash initialization. 
-        self.append_lookup(START_LABEL, state, ZERO, HasherLookupContext::Start);
+        self.append_lookup(START_LABEL, ZERO, HasherLookupContext::Start);
 
         // check if a span block with the same hash has been encountered before, in which case we
         // can directly copy its trace.
@@ -252,78 +254,60 @@ impl Hasher {
         };
 
         let num_batches = op_batches.len();
-        if num_batches == 1 {
-            if is_memoized {
-                // copy trace if trace exists
-                self.trace.copy_trace(&mut state, start_row..end_row);
-            } else {
+
+        // if the span block is encountered for the first time and its trace is not memoized,
+        // we need to build the trace from scratch.
+        if !is_memoized {
+            if num_batches == 1 {
                 // if there is only one batch to hash, we need only one permutation
                 self.trace.append_permutation(&mut state, START, RETURN);
-            }
-        } else {
-            let mut row = start_row;
-            // if there is more than one batch, we need to process the first, the last, and the
-            // middle permutations a bit differently. Specifically, selector flags for the
-            // permutations need to be set as follows:
-            // - first permutation: init linear hash on the first row, and absorb the next
-            //   operation batch on the last row.
-            // - middle permutations: continue hashing on the first row, and absorb the next
-            //   operation batch on the last row.
-            // - last permutation: continue hashing on the first row, and return the result
-            //   on the last row.
-            if is_memoized {
-                // copy trace if trace exists
-                self.trace
-                    .copy_trace(&mut state, row..(row + HASH_CYCLE_LEN));
-                row += HASH_CYCLE_LEN;
             } else {
+                // if there is more than one batch, we need to process the first, the last, and the
+                // middle permutations a bit differently. Specifically, selector flags for the
+                // permutations need to be set as follows:
+                // - first permutation: init linear hash on the first row, and absorb the next
+                //   operation batch on the last row.
+                // - middle permutations: continue hashing on the first row, and absorb the next
+                //   operation batch on the last row.
+                // - last permutation: continue hashing on the first row, and return the result
+                //   on the last row.
                 self.trace.append_permutation(&mut state, START, ABSORB);
-            }
-            let mut last_state = state;
+                for batch in op_batches.iter().take(num_batches - 1).skip(1) {
+                    absorb_into_state(&mut state, batch.groups());
+                    // add the lookup for absorbing the next operation batch.
+                    self.append_lookup(ABSORB_LABEL, ZERO, HasherLookupContext::Absorb);
+                    self.trace.append_permutation(&mut state, CONTINUE, ABSORB);
+                }
 
-            for batch in op_batches.iter().take(num_batches - 1).skip(1) {
-                absorb_into_state(&mut state, batch.groups());
-                // add the lookup for absorbing the next operation batch.
-                self.append_lookup(
+                absorb_into_state(&mut state, op_batches[num_batches - 1].groups());
+                // add the lookup for absorbing the final operation batch.
+                self.append_lookup(ABSORB_LABEL, ZERO, HasherLookupContext::Absorb);
+                self.trace.append_permutation(&mut state, CONTINUE, RETURN);
+            }
+            self.insert_to_memoized_trace_map(addr, expected_hash);
+        } else if num_batches == 1 {
+            self.trace.copy_trace(&mut state, start_row..end_row);
+        } else {
+            for i in 1..num_batches - 1 {
+                // add the lookup for absorbing the next operation batch. Here we add the
+                // lookups before actually copying the memoized trace.
+ self.append_lookup_at( + self.trace_len() + i * HASH_CYCLE_LEN, ABSORB_LABEL, - last_state, ZERO, - HasherLookupContext::Absorb(state), + HasherLookupContext::Absorb, ); - - if is_memoized { - self.trace - .copy_trace(&mut state, row..(row + HASH_CYCLE_LEN)); - row += HASH_CYCLE_LEN; - } else { - self.trace.append_permutation(&mut state, CONTINUE, ABSORB); - } - - last_state = state; } - absorb_into_state(&mut state, op_batches[num_batches - 1].groups()); + self.trace.copy_trace(&mut state, start_row..end_row); // add the lookup for absorbing the final operation batch. - self.append_lookup( - ABSORB_LABEL, - last_state, - ZERO, - HasherLookupContext::Absorb(state), - ); - if is_memoized { - self.trace.copy_trace(&mut state, row..end_row); - } else { - self.trace.append_permutation(&mut state, CONTINUE, RETURN); - } + self.append_lookup(ABSORB_LABEL, ZERO, HasherLookupContext::Absorb); } // add the lookup for the hash result. - self.append_lookup(RETURN_LABEL, state, ZERO, HasherLookupContext::Return); + self.append_lookup(RETURN_LABEL, ZERO, HasherLookupContext::Return); - if !is_memoized { - self.insert_to_memoized_trace_map(addr, expected_hash); - } let result = get_digest(&state); let lookups = self.get_last_lookups(init_lookup_idx); @@ -502,7 +486,7 @@ impl Hasher { // add the lookup for the hash initialization if this is the beginning. let context = HasherLookupContext::Start; if let Some(label) = get_selector_context_label(init_selectors, context) { - self.append_lookup(label, state, Felt::new(*index), context); + self.append_lookup(label, Felt::new(*index), context); } // determine values for the node index column for this permutation. if the first selector @@ -530,7 +514,7 @@ impl Hasher { // add the lookup for the hash result if this is the end. let context = HasherLookupContext::Return; if let Some(label) = get_selector_context_label(final_selectors, context) { - self.append_lookup(label, state, Felt::new(*index), context); + self.append_lookup(label, Felt::new(*index), context); } get_digest(&state) From eae291c290f9f5cc73b9f5067ac19355946db30a Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Thu, 13 Oct 2022 11:55:02 +0000 Subject: [PATCH 4/5] fix: pacify rustfmt --- processor/src/chiplets/memory/mod.rs | 6 +++++- processor/src/trace/utils.rs | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/processor/src/chiplets/memory/mod.rs b/processor/src/chiplets/memory/mod.rs index a8527534f6..cf2aeada04 100644 --- a/processor/src/chiplets/memory/mod.rs +++ b/processor/src/chiplets/memory/mod.rs @@ -334,7 +334,11 @@ impl MemoryLookup { impl LookupTableRow for MemoryLookup { /// Reduces this row to a single field element in the field specified by E. This requires /// at least 9 alpha values. 
- fn to_value>(&self, _main_trace: &Matrix, alphas: &[E]) -> E { + fn to_value>( + &self, + _main_trace: &Matrix, + alphas: &[E], + ) -> E { let word_value = self .word .iter() diff --git a/processor/src/trace/utils.rs b/processor/src/trace/utils.rs index 7a9713ffe6..576544d306 100644 --- a/processor/src/trace/utils.rs +++ b/processor/src/trace/utils.rs @@ -90,7 +90,7 @@ pub trait LookupTableRow { pub fn build_lookup_table_row_values, R: LookupTableRow>( rows: &[R], main_trace: &Matrix, - rand_values: &[E], + rand_values: &[E], ) -> (Vec, Vec) { let mut row_values = unsafe { uninit_vector(rows.len()) }; let mut inv_row_values = unsafe { uninit_vector(rows.len()) }; From 756f3401ba97512c8f7b18deff688e19d038eac2 Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Fri, 14 Oct 2022 18:19:51 +0000 Subject: [PATCH 5/5] chore: reorder parameters for to_stack_value and to_mem_value --- processor/src/range/aux_trace.rs | 4 ++-- processor/src/range/request.rs | 8 ++++---- processor/src/trace/decoder/mod.rs | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/processor/src/range/aux_trace.rs b/processor/src/range/aux_trace.rs index e47aa81078..512de23734 100644 --- a/processor/src/range/aux_trace.rs +++ b/processor/src/range/aux_trace.rs @@ -177,7 +177,7 @@ impl AuxTraceBuilder { p1_idx = clk + 1; // update the intermediate values in the q column. - q[clk] = range_checks.to_stack_value(alphas, main_trace); + q[clk] = range_checks.to_stack_value(main_trace, alphas); // include the operation lookups in the running product. p1[p1_idx] = p1[clk] * inv_row_values[rc_user_op_idx]; @@ -208,7 +208,7 @@ impl AuxTraceBuilder { if let Some(range_check) = self.cycle_range_checks.get(&(row_idx as u32)) { // update the intermediate values in the q column. - q[row_idx] = range_check.to_stack_value(alphas, main_trace); + q[row_idx] = range_check.to_stack_value(main_trace, alphas); // include the operation lookups in the running product. p1[p1_idx] *= inv_row_values[rc_user_op_idx]; diff --git a/processor/src/range/request.rs b/processor/src/range/request.rs index 5d8e7b1f4e..58253e228b 100644 --- a/processor/src/range/request.rs +++ b/processor/src/range/request.rs @@ -60,8 +60,8 @@ impl CycleRangeChecks { /// element in the field specified by E. pub fn to_stack_value>( &self, - alphas: &[E], main_trace: &Matrix, + alphas: &[E], ) -> E { let mut value = E::ONE; @@ -76,8 +76,8 @@ impl CycleRangeChecks { /// element in the field specified by E. 
     fn to_mem_value<E: FieldElement<BaseField = Felt>>(
         &self,
-        alphas: &[E],
         main_trace: &Matrix<Felt>,
+        alphas: &[E],
     ) -> E {
         let mut value = E::ONE;
 
@@ -97,8 +97,8 @@ impl LookupTableRow for CycleRangeChecks {
         main_trace: &Matrix<Felt>,
         alphas: &[E],
     ) -> E {
-        let stack_value = self.to_stack_value(alphas, main_trace);
-        let mem_value = self.to_mem_value(alphas, main_trace);
+        let stack_value = self.to_stack_value(main_trace, alphas);
+        let mem_value = self.to_mem_value(main_trace, alphas);
 
         if stack_value != E::ONE {
             stack_value * mem_value
diff --git a/processor/src/trace/decoder/mod.rs b/processor/src/trace/decoder/mod.rs
index a153029199..1c4c03cdac 100644
--- a/processor/src/trace/decoder/mod.rs
+++ b/processor/src/trace/decoder/mod.rs
@@ -23,10 +23,10 @@ pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
     let p1 = build_aux_col_p1(main_trace, aux_trace_hints, rand_elements);
     let p2 = build_aux_col_p2(main_trace, aux_trace_hints, rand_elements);
     let p3 = build_aux_col_p3(
+        main_trace,
         main_trace.num_rows(),
         aux_trace_hints,
         rand_elements,
-        main_trace,
     );
     vec![p1, p2, p3]
 }
@@ -231,10 +231,10 @@ fn build_aux_col_p2<E: FieldElement<BaseField = Felt>>(
 /// Builds the execution trace of the decoder's `p3` column which describes the state of the op
 /// group table via multiset checks.
 fn build_aux_col_p3<E: FieldElement<BaseField = Felt>>(
+    main_trace: &Matrix<Felt>,
     trace_len: usize,
     aux_trace_hints: &AuxTraceHints,
     alphas: &[E],
-    main_trace: &Matrix<Felt>,
 ) -> Vec<E> {
     // allocate memory for the column and set the starting value to ONE
     let mut result = unsafe { uninit_vector(trace_len) };
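
---

Editor's note on the pattern these patches introduce: `LookupTableRow::to_value` now takes the main trace as a parameter, so a lookup row no longer carries a copy of the 12-element hasher state — it reads the state back out of the trace (via the 1-based row address) when the lookup value is computed. Below is a minimal, self-contained sketch of that flow. It is illustrative only: `Felt` is modeled as `u64` with wrapping arithmetic instead of the VM's field type, `Matrix` is a hand-rolled column-major stand-in rather than winterfell's `Matrix`, and `HASHER_STATE_COL_START`/`STATE_WIDTH` are assumed layout constants standing in for `HASHER_STATE_COL_RANGE`.

```rust
// Sketch only: u64 stands in for the real field element type.
type Felt = u64;

/// Column-major trace stand-in: `get(col, row)` mirrors the accessor used in
/// `get_hasher_state_at` above.
struct Matrix {
    columns: Vec<Vec<Felt>>,
}

impl Matrix {
    fn get(&self, col: usize, row: usize) -> Felt {
        self.columns[col][row]
    }
}

// Assumed layout constants (the real code uses HASHER_STATE_COL_RANGE).
const HASHER_STATE_COL_START: usize = 4;
const STATE_WIDTH: usize = 12;

/// Hasher row addresses are 1-based, so address `addr` lives at trace row
/// `addr - 1` — exactly what `get_row_from_addr` does in patch 2.
fn get_row_from_addr(addr: u32) -> usize {
    addr as usize - 1
}

/// Reads the hasher state for a lookup at `addr` out of the main trace. This
/// is what replaces the `state: HasherState` field removed from `HasherLookup`.
fn get_hasher_state_at(addr: u32, main_trace: &Matrix) -> Vec<Felt> {
    let row = get_row_from_addr(addr);
    (0..STATE_WIDTH)
        .map(|col| main_trace.get(HASHER_STATE_COL_START + col, row))
        .collect()
}

/// Reduces a slice of elements with one random alpha per element, in the
/// spirit of `build_value` (wrapping u64 math stands in for field ops).
fn build_value(alphas: &[Felt], elements: &[Felt]) -> Felt {
    elements
        .iter()
        .zip(alphas)
        .fold(0, |acc, (e, a)| acc.wrapping_add(a.wrapping_mul(*e)))
}
```

This is the trade-off the TODO in patch 3 alludes to: storing the full state in every `HasherLookup` duplicated data that already exists in the execution trace, while the new approach re-reads a handful of trace cells at aux-column build time, when the complete main trace is guaranteed to be available.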