Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor HasherLookup #424

Merged
merged 5 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions core/src/chiplets/hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,23 @@ pub const ROW_COL_IDX: usize = NUM_SELECTORS;
/// The hasher state portion of the execution trace, located in 4 .. 16 columns.
pub const STATE_COL_RANGE: Range<usize> = create_range(ROW_COL_IDX + 1, STATE_WIDTH);

/// Number of field elements in the capacity portion of the hasher's state.
pub const CAPACITY_LEN: usize = STATE_WIDTH - RATE_LEN;

/// The capacity portion of the hasher state in the execution trace, located in 4 .. 8 columns.
pub const CAPACITY_COL_RANGE: Range<usize> = Range {
start: STATE_COL_RANGE.start,
end: STATE_COL_RANGE.start + CAPACITY_LEN,
};

/// Number of field elements in the rate portion of the hasher's state.
pub const RATE_LEN: usize = 8;

/// Number of field elements in the capacity portion of the hasher's state.
pub const CAPACITY_LEN: usize = STATE_WIDTH - RATE_LEN;
/// The rate portion of the hasher state in the execution trace, located in 8 .. 16 columns.
pub const RATE_COL_RANGE: Range<usize> = Range {
start: CAPACITY_COL_RANGE.end,
end: CAPACITY_COL_RANGE.end + RATE_LEN,
};

// The length of the output portion of the hash state.
pub const DIGEST_LEN: usize = 4;
Expand Down
11 changes: 11 additions & 0 deletions core/src/chiplets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ pub const HASHER_STATE_COL_RANGE: Range<usize> = Range {
start: HASHER_TRACE_OFFSET + hasher::STATE_COL_RANGE.start,
end: HASHER_TRACE_OFFSET + hasher::STATE_COL_RANGE.end,
};
/// The range of columns in the execution trace that contains the capacity portion of the hasher
/// state.
pub const HASHER_CAPACITY_COL_RANGE: Range<usize> = Range {
start: HASHER_TRACE_OFFSET + hasher::CAPACITY_COL_RANGE.start,
end: HASHER_TRACE_OFFSET + hasher::CAPACITY_COL_RANGE.end,
};
/// The range of columns in the execution trace that contains the rate portion of the hasher state.
pub const HASHER_RATE_COL_RANGE: Range<usize> = Range {
start: HASHER_TRACE_OFFSET + hasher::RATE_COL_RANGE.start,
end: HASHER_TRACE_OFFSET + hasher::RATE_COL_RANGE.end,
};
/// The index of the hasher's node index column in the execution trace.
pub const HASHER_NODE_INDEX_COL_IDX: usize = HASHER_STATE_COL_RANGE.end;

Expand Down
8 changes: 6 additions & 2 deletions processor/src/chiplets/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{
ChipletsBus, ExecutionError, Felt, FieldElement, LookupTableRow, StarkField, TraceFragment,
Vec, BITWISE_AND_LABEL, BITWISE_XOR_LABEL,
};
use crate::utils::get_trace_len;
use crate::{utils::get_trace_len, Matrix};
use vm_core::chiplets::bitwise::{
A_COL_IDX, A_COL_RANGE, BITWISE_AND, BITWISE_XOR, B_COL_IDX, B_COL_RANGE, OP_CYCLE_LEN,
OUTPUT_COL_IDX, PREV_OUTPUT_COL_IDX, TRACE_WIDTH,
Expand Down Expand Up @@ -261,7 +261,11 @@ impl BitwiseLookup {
impl LookupTableRow for BitwiseLookup {
/// Reduces this row to a single field element in the field specified by E. This requires
/// at least 5 alpha values.
fn to_value<E: FieldElement<BaseField = Felt>>(&self, alphas: &[E]) -> E {
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
_main_trace: &Matrix<Felt>,
alphas: &[E],
) -> E {
alphas[0]
+ alphas[1].mul_base(self.op_id)
+ alphas[2].mul_base(self.a)
Expand Down
10 changes: 5 additions & 5 deletions processor/src/chiplets/bus/aux_trace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use super::{ChipletsLookup, ChipletsLookupRow, Felt, FieldElement};
use crate::{
trace::{build_lookup_table_row_values, AuxColumnBuilder, LookupTableRow},
Vec,
Matrix, Vec,
};
use winterfell::Matrix;

// AUXILIARY TRACE BUILDER
// ================================================================================================
Expand Down Expand Up @@ -76,18 +75,19 @@ impl AuxColumnBuilder<ChipletsLookup, ChipletsLookupRow, u32> for AuxTraceBuilde
/// requests. Since responses are grouped by chiplet, the operation order for the requests and
/// responses will be permutations of each other rather than sharing the same order. Therefore,
/// the `row_values` and `inv_row_values` must be built separately.
fn build_row_values<E>(&self, _main_trace: &Matrix<Felt>, alphas: &[E]) -> (Vec<E>, Vec<E>)
fn build_row_values<E>(&self, main_trace: &Matrix<Felt>, alphas: &[E]) -> (Vec<E>, Vec<E>)
where
E: FieldElement<BaseField = Felt>,
{
// get the row values from the resonse rows
let row_values = self
.response_rows
.iter()
.map(|response| response.to_value(alphas))
.map(|response| response.to_value(main_trace, alphas))
.collect();
// get the inverse values from the request rows
let (_, inv_row_values) = build_lookup_table_row_values(&self.request_rows, alphas);
let (_, inv_row_values) =
build_lookup_table_row_values(&self.request_rows, main_trace, alphas);

(row_values, inv_row_values)
}
Expand Down
15 changes: 10 additions & 5 deletions processor/src/chiplets/bus/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{
hasher::HasherLookup, BTreeMap, BitwiseLookup, Felt, FieldElement, LookupTableRow,
MemoryLookup, Vec,
};
use crate::Matrix;

mod aux_trace;
pub use aux_trace::AuxTraceBuilder;
Expand Down Expand Up @@ -210,14 +211,18 @@ pub(super) enum ChipletsLookupRow {
}

impl LookupTableRow for ChipletsLookupRow {
fn to_value<E: FieldElement<BaseField = Felt>>(&self, alphas: &[E]) -> E {
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &Matrix<Felt>,
alphas: &[E],
) -> E {
match self {
ChipletsLookupRow::HasherMulti(lookups) => lookups
.iter()
.fold(E::ONE, |acc, row| acc * row.to_value(alphas)),
ChipletsLookupRow::Hasher(row) => row.to_value(alphas),
ChipletsLookupRow::Bitwise(row) => row.to_value(alphas),
ChipletsLookupRow::Memory(row) => row.to_value(alphas),
.fold(E::ONE, |acc, row| acc * row.to_value(main_trace, alphas)),
ChipletsLookupRow::Hasher(row) => row.to_value(main_trace, alphas),
ChipletsLookupRow::Bitwise(row) => row.to_value(main_trace, alphas),
ChipletsLookupRow::Memory(row) => row.to_value(main_trace, alphas),
}
}
}
12 changes: 9 additions & 3 deletions processor/src/chiplets/hasher/aux_trace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::{Felt, FieldElement, StarkField, Vec, Word};
use crate::trace::{AuxColumnBuilder, LookupTableRow};
use winterfell::Matrix;
use crate::{
trace::{AuxColumnBuilder, LookupTableRow},
Matrix,
};

// AUXILIARY TRACE BUILDER
// ================================================================================================
Expand Down Expand Up @@ -121,7 +123,11 @@ impl SiblingTableRow {
impl LookupTableRow for SiblingTableRow {
/// Reduces this row to a single field element in the field specified by E. This requires
/// at least 6 alpha values.
fn to_value<E: FieldElement<BaseField = Felt>>(&self, alphas: &[E]) -> E {
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
_main_trace: &Matrix<Felt>,
alphas: &[E],
) -> E {
// when the least significant bit of the index is 0, the sibling will be in the 3rd word
// of the hasher state, and when the least significant bit is 1, it will be in the 2nd
// word. we compute the value in this way to make constraint evaluation a bit easier since
Expand Down
105 changes: 77 additions & 28 deletions processor/src/chiplets/hasher/lookups.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
use super::{Felt, FieldElement, HasherState, LookupTableRow, StarkField};
use vm_core::chiplets::hasher::{
CAPACITY_LEN, DIGEST_RANGE, LINEAR_HASH_LABEL, MP_VERIFY_LABEL, MR_UPDATE_NEW_LABEL,
MR_UPDATE_OLD_LABEL, RETURN_HASH_LABEL, RETURN_STATE_LABEL, STATE_WIDTH,
use core::ops::Range;

use super::{Felt, FieldElement, LookupTableRow, StarkField};
use crate::Matrix;
use vm_core::{
chiplets::{
hasher::{
CAPACITY_LEN, DIGEST_LEN, DIGEST_RANGE, LINEAR_HASH_LABEL, MP_VERIFY_LABEL,
MR_UPDATE_NEW_LABEL, MR_UPDATE_OLD_LABEL, RATE_LEN, RETURN_HASH_LABEL,
RETURN_STATE_LABEL, STATE_WIDTH,
},
HASHER_RATE_COL_RANGE, HASHER_STATE_COL_RANGE,
},
utils::collections::Vec,
};

// CONSTANTS
Expand All @@ -17,8 +27,7 @@ const NUM_HEADER_ALPHAS: usize = 4;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum HasherLookupContext {
Start,
// TODO: benchmark removing this and getting it from the trace instead
Absorb(HasherState),
Absorb,
tohrnii marked this conversation as resolved.
Show resolved Hide resolved
Return,
}

Expand All @@ -27,9 +36,6 @@ pub enum HasherLookupContext {
pub struct HasherLookup {
// unique label identifying the hash operation
label: u8,
// TODO: benchmark removing this and getting it from the trace instead
// hasher state
state: HasherState,
// row address in the Hasher table
addr: u32,
// node index
Expand All @@ -40,16 +46,9 @@ pub struct HasherLookup {

impl HasherLookup {
/// Creates a new HasherLookup.
pub(super) fn new(
label: u8,
state: HasherState,
addr: u32,
index: Felt,
context: HasherLookupContext,
) -> Self {
pub(super) fn new(label: u8, addr: u32, index: Felt, context: HasherLookupContext) -> Self {
Self {
label,
state,
addr,
index,
context,
Expand Down Expand Up @@ -81,7 +80,11 @@ impl HasherLookup {
impl LookupTableRow for HasherLookup {
/// Reduces this row to a single field element in the field specified by E. This requires
/// at least 16 alpha values.
fn to_value<E: FieldElement<BaseField = Felt>>(&self, alphas: &[E]) -> E {
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &Matrix<Felt>,
alphas: &[E],
) -> E {
let header = self.get_header_value(&alphas[..NUM_HEADER_ALPHAS]);
// computing the rest of the value requires an alpha for each element in the [HasherState]
let alphas = &alphas[NUM_HEADER_ALPHAS..(NUM_HEADER_ALPHAS + STATE_WIDTH)];
Expand All @@ -90,8 +93,14 @@ impl LookupTableRow for HasherLookup {
HasherLookupContext::Start => {
if self.label == LINEAR_HASH_LABEL {
// include the entire state when initializing a linear hash.
header + build_value(alphas, &self.state)
header
+ build_value(
alphas,
&get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH),
)
} else {
let state =
&get_hasher_state_at(self.addr, main_trace, CAPACITY_LEN..STATE_WIDTH);
assert!(
self.label == MR_UPDATE_OLD_LABEL
|| self.label == MR_UPDATE_NEW_LABEL
Expand All @@ -103,37 +112,45 @@ impl LookupTableRow for HasherLookup {
// by the index bit will be the leaf node, and the value must be computed in the
// same way in both cases.
let bit = (self.index.as_int() >> 1) & 1;
let left_word = build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE]);
let right_word =
build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE.end..]);
let left_word = build_value(&alphas[DIGEST_RANGE], &state[..DIGEST_LEN]);
let right_word = build_value(&alphas[DIGEST_RANGE], &state[DIGEST_LEN..]);

header + E::from(1 - bit).mul(left_word) + E::from(bit).mul(right_word)
}
}
HasherLookupContext::Absorb(next_state) => {
HasherLookupContext::Absorb => {
assert!(
self.label == LINEAR_HASH_LABEL,
"unrecognized hash operation"
);
let (curr_hasher_rate, next_hasher_rate) =
get_adjacent_hasher_rates(self.addr, main_trace);
// build the value from the delta of the hasher state's rate before and after the
// absorption of new elements.
let next_state_value =
build_value(&alphas[CAPACITY_LEN..], &next_state[CAPACITY_LEN..]);
let state_value = build_value(&alphas[CAPACITY_LEN..], &self.state[CAPACITY_LEN..]);
let next_state_value = build_value(&alphas[CAPACITY_LEN..], &next_hasher_rate);
let state_value = build_value(&alphas[CAPACITY_LEN..], &curr_hasher_rate);

header + next_state_value - state_value
}
HasherLookupContext::Return => {
if self.label == RETURN_STATE_LABEL {
// build the value from the result, which is the entire state
header + build_value(alphas, &self.state)
header
+ build_value(
alphas,
&get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH),
)
} else {
assert!(
self.label == RETURN_HASH_LABEL,
"unrecognized hash operation"
);
// build the value from the result, which is the digest portion of the state
header + build_value(&alphas[DIGEST_RANGE], &self.state[DIGEST_RANGE])
header
+ build_value(
&alphas[DIGEST_RANGE],
&get_hasher_state_at(self.addr, main_trace, DIGEST_RANGE),
)
}
}
}
Expand All @@ -153,3 +170,35 @@ fn build_value<E: FieldElement<BaseField = Felt>>(alphas: &[E], elements: &[Felt
}
value
}

/// Returns the portion of the hasher state at the provided address that is within the provided
/// column range.
fn get_hasher_state_at(addr: u32, main_trace: &Matrix<Felt>, col_range: Range<usize>) -> Vec<Felt> {
let row = get_row_from_addr(addr);
col_range
.map(|col| main_trace.get(HASHER_STATE_COL_RANGE.start + col, row))
.collect::<Vec<Felt>>()
}

/// Returns the rate portion of the hasher state for the provided row and the next row.
fn get_adjacent_hasher_rates(
addr: u32,
main_trace: &Matrix<Felt>,
) -> ([Felt; RATE_LEN], [Felt; RATE_LEN]) {
let row = get_row_from_addr(addr);

let mut current = [Felt::ZERO; RATE_LEN];
let mut next = [Felt::ZERO; RATE_LEN];
for (idx, col_idx) in HASHER_RATE_COL_RANGE.enumerate() {
let column = main_trace.get_column(col_idx);
current[idx] = column[row];
next[idx] = column[row + 1];
}

(current, next)
}

/// Gets the row index from the specified row address.
fn get_row_from_addr(addr: u32) -> usize {
addr as usize - 1
}
Loading