Skip to content

Commit

Permalink
chore: split up decoder aux column builders into separate files (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth committed Feb 8, 2024
1 parent 0df5861 commit a7b5fc3
Show file tree
Hide file tree
Showing 6 changed files with 511 additions and 540 deletions.
204 changes: 204 additions & 0 deletions processor/src/decoder/aux_trace/block_hash_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
use super::{
AuxColumnBuilder, Felt, FieldElement, MainTrace, StarkField, DYN, END, HALT, JOIN, LOOP, ONE,
REPEAT, SPLIT,
};

// BLOCK HASH TABLE COLUMN BUILDER
// ================================================================================================

/// Builds the execution trace of the decoder's `p2` column which describes the state of the block
/// hash table via multiset checks.
#[derive(Default)]
pub struct BlockHashTableColumnBuilder {}

impl<E: FieldElement<BaseField = Felt>> AuxColumnBuilder<E> for BlockHashTableColumnBuilder {
fn init_responses(&self, main_trace: &MainTrace, alphas: &[E]) -> E {
let row_index = (0..main_trace.num_rows())
.find(|row| main_trace.get_op_code(*row) == Felt::from(HALT))
.expect("execution trace must include at least one occurrence of HALT");
let program_hash = main_trace.decoder_hasher_state_first_half(row_index);

// Computes the initialization value for the block hash table.
alphas[0]
+ alphas[2].mul_base(program_hash[0])
+ alphas[3].mul_base(program_hash[1])
+ alphas[4].mul_base(program_hash[2])
+ alphas[5].mul_base(program_hash[3])
}

/// Removes a row from the block hash table.
fn get_requests_at(&self, main_trace: &MainTrace, alphas: &[E], i: usize) -> E {
let op_code_felt = main_trace.get_op_code(i);
let op_code = op_code_felt.as_int() as u8;

let op_code_felt_next = main_trace.get_op_code(i + 1);
let op_code_next = op_code_felt_next.as_int() as u8;

match op_code {
END => get_block_hash_table_removal_multiplicand(main_trace, i, alphas, op_code_next),
_ => E::ONE,
}
}

/// Adds a row to the block hash table.
fn get_responses_at(&self, main_trace: &MainTrace, alphas: &[E], i: usize) -> E {
let op_code_felt = main_trace.get_op_code(i);
let op_code = op_code_felt.as_int() as u8;

match op_code {
JOIN => get_block_hash_table_inclusion_multiplicand_join(main_trace, i, alphas),
SPLIT => get_block_hash_table_inclusion_multiplicand_split(main_trace, i, alphas),
LOOP => get_block_hash_table_inclusion_multiplicand_loop(main_trace, i, alphas),
REPEAT => get_block_hash_table_inclusion_multiplicand_repeat(main_trace, i, alphas),
DYN => get_block_hash_table_inclusion_multiplicand_dyn(main_trace, i, alphas),
_ => E::ONE,
}
}
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the multiplicand representing the removal of a row from the block hash table.
fn get_block_hash_table_removal_multiplicand<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
op_code_next: u8,
) -> E {
let a = main_trace.addr(i + 1);
let digest = main_trace.decoder_hasher_state_first_half(i);
let is_loop_body = main_trace.is_loop_body_flag(i);
let next_end_or_repeat =
if op_code_next == END || op_code_next == REPEAT || op_code_next == HALT {
E::ZERO
} else {
alphas[6]
};

alphas[0]
+ alphas[1].mul_base(a)
+ alphas[2].mul_base(digest[0])
+ alphas[3].mul_base(digest[1])
+ alphas[4].mul_base(digest[2])
+ alphas[5].mul_base(digest[3])
+ alphas[7].mul_base(is_loop_body)
+ next_end_or_repeat
}

/// Computes the multiplicand representing the inclusion of a new row representing a JOIN block
/// to the block hash table.
fn get_block_hash_table_inclusion_multiplicand_join<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
) -> E {
let a_prime = main_trace.addr(i + 1);
let state = main_trace.decoder_hasher_state(i);
let ch1 = alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[0])
+ alphas[3].mul_base(state[1])
+ alphas[4].mul_base(state[2])
+ alphas[5].mul_base(state[3]);
let ch2 = alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[4])
+ alphas[3].mul_base(state[5])
+ alphas[4].mul_base(state[6])
+ alphas[5].mul_base(state[7]);

(ch1 + alphas[6]) * ch2
}

/// Computes the multiplicand representing the inclusion of a new row representing a SPLIT block
/// to the block hash table.
fn get_block_hash_table_inclusion_multiplicand_split<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
) -> E {
let s0 = main_trace.stack_element(0, i);
let a_prime = main_trace.addr(i + 1);
let state = main_trace.decoder_hasher_state(i);

if s0 == ONE {
alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[0])
+ alphas[3].mul_base(state[1])
+ alphas[4].mul_base(state[2])
+ alphas[5].mul_base(state[3])
} else {
alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[4])
+ alphas[3].mul_base(state[5])
+ alphas[4].mul_base(state[6])
+ alphas[5].mul_base(state[7])
}
}

/// Computes the multiplicand representing the inclusion of a new row representing a LOOP block
/// to the block hash table.
fn get_block_hash_table_inclusion_multiplicand_loop<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
) -> E {
let s0 = main_trace.stack_element(0, i);

if s0 == ONE {
let a_prime = main_trace.addr(i + 1);
let state = main_trace.decoder_hasher_state(i);
alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[0])
+ alphas[3].mul_base(state[1])
+ alphas[4].mul_base(state[2])
+ alphas[5].mul_base(state[3])
+ alphas[7]
} else {
E::ONE
}
}

/// Computes the multiplicand representing the inclusion of a new row representing a REPEAT
/// to the block hash table.
fn get_block_hash_table_inclusion_multiplicand_repeat<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
) -> E {
let a_prime = main_trace.addr(i + 1);
let state = main_trace.decoder_hasher_state_first_half(i);

alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(state[0])
+ alphas[3].mul_base(state[1])
+ alphas[4].mul_base(state[2])
+ alphas[5].mul_base(state[3])
+ alphas[7]
}

/// Computes the multiplicand representing the inclusion of a new row representing a DYN block
/// to the block hash table.
fn get_block_hash_table_inclusion_multiplicand_dyn<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
) -> E {
let a_prime = main_trace.addr(i + 1);
let s0 = main_trace.stack_element(0, i);
let s1 = main_trace.stack_element(1, i);
let s2 = main_trace.stack_element(2, i);
let s3 = main_trace.stack_element(3, i);

alphas[0]
+ alphas[1].mul_base(a_prime)
+ alphas[2].mul_base(s3)
+ alphas[3].mul_base(s2)
+ alphas[4].mul_base(s1)
+ alphas[5].mul_base(s0)
}
150 changes: 150 additions & 0 deletions processor/src/decoder/aux_trace/block_stack_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use super::{
AuxColumnBuilder, Felt, FieldElement, MainTrace, StarkField, CALL, DYN, END, JOIN, LOOP, ONE,
RESPAN, SPAN, SPLIT, SYSCALL, ZERO,
};

// BLOCK STACK TABLE COLUMN BUILDER
// ================================================================================================

/// Builds the execution trace of the decoder's `p1` column which describes the state of the block
/// stack table via multiset checks.
#[derive(Default)]
pub struct BlockStackColumnBuilder {}

impl<E: FieldElement<BaseField = Felt>> AuxColumnBuilder<E> for BlockStackColumnBuilder {
/// Removes a row from the block stack table.
fn get_requests_at(&self, main_trace: &MainTrace, alphas: &[E], i: usize) -> E {
let op_code_felt = main_trace.get_op_code(i);
let op_code = op_code_felt.as_int() as u8;

match op_code {
RESPAN => get_block_stack_table_removal_multiplicand(main_trace, i, true, alphas),
END => get_block_stack_table_removal_multiplicand(main_trace, i, false, alphas),
_ => E::ONE,
}
}

/// Adds a row to the block stack table.
fn get_responses_at(&self, main_trace: &MainTrace, alphas: &[E], i: usize) -> E {
let op_code_felt = main_trace.get_op_code(i);
let op_code = op_code_felt.as_int() as u8;

match op_code {
JOIN | SPLIT | SPAN | DYN | LOOP | RESPAN | CALL | SYSCALL => {
get_block_stack_table_inclusion_multiplicand(main_trace, i, alphas, op_code)
}
_ => E::ONE,
}
}
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the multiplicand representing the removal of a row from the block stack table.
fn get_block_stack_table_removal_multiplicand<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
is_respan: bool,
alphas: &[E],
) -> E {
let block_id = main_trace.addr(i);
let parent_id = if is_respan {
main_trace.decoder_hasher_state_element(1, i + 1)
} else {
main_trace.addr(i + 1)
};
let is_loop = main_trace.is_loop_flag(i);

let elements = if main_trace.is_call_flag(i) == ONE || main_trace.is_syscall_flag(i) == ONE {
let parent_ctx = main_trace.ctx(i + 1);
let parent_fmp = main_trace.fmp(i + 1);
let parent_stack_depth = main_trace.stack_depth(i + 1);
let parent_next_overflow_addr = main_trace.parent_overflow_address(i + 1);
let parent_fn_hash = main_trace.fn_hash(i);

[
ONE,
block_id,
parent_id,
is_loop,
parent_ctx,
parent_fmp,
parent_stack_depth,
parent_next_overflow_addr,
parent_fn_hash[0],
parent_fn_hash[1],
parent_fn_hash[2],
parent_fn_hash[0],
]
} else {
let mut result = [ZERO; 12];
result[0] = ONE;
result[1] = block_id;
result[2] = parent_id;
result[3] = is_loop;
result
};

let mut value = E::ZERO;

for (&alpha, &element) in alphas.iter().zip(elements.iter()) {
value += alpha.mul_base(element);
}
value
}

/// Computes the multiplicand representing the inclusion of a new row to the block stack table.
fn get_block_stack_table_inclusion_multiplicand<E: FieldElement<BaseField = Felt>>(
main_trace: &MainTrace,
i: usize,
alphas: &[E],
op_code: u8,
) -> E {
let block_id = main_trace.addr(i + 1);
let parent_id = if op_code == RESPAN {
main_trace.decoder_hasher_state_element(1, i + 1)
} else {
main_trace.addr(i)
};
let is_loop = if op_code == LOOP {
main_trace.stack_element(0, i)
} else {
ZERO
};
let elements = if op_code == CALL || op_code == SYSCALL {
let parent_ctx = main_trace.ctx(i);
let parent_fmp = main_trace.fmp(i);
let parent_stack_depth = main_trace.stack_depth(i);
let parent_next_overflow_addr = main_trace.parent_overflow_address(i);
let parent_fn_hash = main_trace.decoder_hasher_state_first_half(i);
[
ONE,
block_id,
parent_id,
is_loop,
parent_ctx,
parent_fmp,
parent_stack_depth,
parent_next_overflow_addr,
parent_fn_hash[0],
parent_fn_hash[1],
parent_fn_hash[2],
parent_fn_hash[3],
]
} else {
let mut result = [ZERO; 12];
result[0] = ONE;
result[1] = block_id;
result[2] = parent_id;
result[3] = is_loop;
result
};

let mut value = E::ZERO;

for (&alpha, &element) in alphas.iter().zip(elements.iter()) {
value += alpha.mul_base(element);
}
value
}
Loading

0 comments on commit a7b5fc3

Please sign in to comment.