Skip to content

Commit

Permalink
refactor: Deref MainTrace type to ColMatrix (#1214)
Browse files Browse the repository at this point in the history
  • Loading branch information
iammadab authored and bobbinth committed Feb 8, 2024
1 parent 3c09bdf commit 74c2032
Show file tree
Hide file tree
Showing 17 changed files with 79 additions and 60 deletions.
1 change: 1 addition & 0 deletions air/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ harness = false
[features]
default = ["std"]
std = ["vm-core/std", "winter-air/std"]
internals = []

[dependencies]
vm-core = { package = "miden-core", path = "../core", version = "0.8", default-features = false }
Expand Down
28 changes: 23 additions & 5 deletions air/src/trace/main_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use super::{
CHIPLETS_OFFSET, CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, FMP_COL_IDX, FN_HASH_OFFSET,
STACK_TRACE_OFFSET,
};
use core::ops::Range;
use core::ops::{Deref, Range};
#[cfg(any(test, feature = "internals"))]
use vm_core::utils::collections::Vec;
use vm_core::{utils::range, Felt, ONE, ZERO};

// CONSTANTS
Expand All @@ -28,12 +30,20 @@ const DECODER_HASHER_RANGE: Range<usize> =
// HELPER STRUCT AND METHODS
// ================================================================================================

pub struct MainTrace<'a> {
columns: &'a ColMatrix<Felt>,
pub struct MainTrace {
columns: ColMatrix<Felt>,
}

impl<'a> MainTrace<'a> {
pub fn new(main_trace: &'a ColMatrix<Felt>) -> Self {
impl Deref for MainTrace {
type Target = ColMatrix<Felt>;

fn deref(&self) -> &Self::Target {
&self.columns
}
}

impl MainTrace {
pub fn new(main_trace: ColMatrix<Felt>) -> Self {
Self {
columns: main_trace,
}
Expand All @@ -43,6 +53,14 @@ impl<'a> MainTrace<'a> {
self.columns.num_rows()
}

#[cfg(any(test, feature = "internals"))]
pub fn get_column_range(&self, range: Range<usize>) -> Vec<Vec<Felt>> {
range.fold(vec![], |mut acc, col_idx| {
acc.push(self.get_column(col_idx).to_vec());
acc
})
}

// SYSTEM COLUMNS
// --------------------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion processor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ doctest = false
[features]
concurrent = ["std", "winter-prover/concurrent"]
default = ["std"]
internals = []
internals = ["miden-air/internals"]
std = ["tracing/attributes", "vm-core/std", "winter-prover/std"]

[dependencies]
Expand Down
4 changes: 2 additions & 2 deletions processor/src/chiplets/aux_trace/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{super::trace::AuxColumnBuilder, ColMatrix, Felt, FieldElement, StarkField, Vec};
use super::{super::trace::AuxColumnBuilder, Felt, FieldElement, StarkField, Vec};

use miden_air::trace::{
chiplets::{
Expand Down Expand Up @@ -56,7 +56,7 @@ impl AuxTraceBuilder {
/// provided by chiplets in the Chiplets module.
pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
rand_elements: &[E],
) -> Vec<Vec<E>> {
let v_table_col_builder = ChipletsVTableColBuilder::default();
Expand Down
5 changes: 2 additions & 3 deletions processor/src/chiplets/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::system::ContextId;

use super::{
crypto::MerklePath, utils, BTreeMap, ChipletsTrace, ColMatrix, ExecutionError, Felt,
FieldElement, RangeChecker, StarkField, TraceFragment, Vec, Word, CHIPLETS_WIDTH, EMPTY_WORD,
ONE, ZERO,
crypto::MerklePath, utils, BTreeMap, ChipletsTrace, ExecutionError, Felt, FieldElement,
RangeChecker, StarkField, TraceFragment, Vec, Word, CHIPLETS_WIDTH, EMPTY_WORD, ONE, ZERO,
};
use miden_air::trace::chiplets::hasher::{Digest, HasherState};
use vm_core::{code_blocks::OpBatch, Kernel};
Expand Down
10 changes: 5 additions & 5 deletions processor/src/chiplets/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
utils::get_trace_len, CodeBlock, DefaultHost, ExecutionOptions, ExecutionTrace, Kernel,
Operation, Process, StackInputs, Vec,
CodeBlock, DefaultHost, ExecutionOptions, ExecutionTrace, Kernel, Operation, Process,
StackInputs, Vec,
};
use miden_air::trace::{
chiplets::{
Expand Down Expand Up @@ -117,11 +117,11 @@ fn build_trace(
process.execute_code_block(&program, &CodeBlockTable::default()).unwrap();

let (trace, _, _) = ExecutionTrace::test_finalize_trace(process);
let trace_len = get_trace_len(&trace) - ExecutionTrace::NUM_RAND_ROWS;
let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS;

(
trace[CHIPLETS_RANGE]
.to_vec()
trace
.get_column_range(CHIPLETS_RANGE)
.try_into()
.expect("failed to convert vector to array"),
trace_len,
Expand Down
3 changes: 1 addition & 2 deletions processor/src/decoder/auxiliary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use miden_air::trace::{
};

use vm_core::{crypto::hash::RpoDigest, FieldElement, Operation};
use winter_prover::matrix::ColMatrix;

// CONSTANTS
// ================================================================================================
Expand Down Expand Up @@ -39,7 +38,7 @@ impl AuxTraceBuilder {
/// stack, block hash, and op group tables respectively.
pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
rand_elements: &[E],
) -> Vec<Vec<E>> {
let block_stack_column_builder = BlockStackColumnBuilder::default();
Expand Down
25 changes: 12 additions & 13 deletions processor/src/decoder/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::{
super::{
utils::get_trace_len, ExecutionOptions, ExecutionTrace, Felt, Kernel, Operation, Process,
StackInputs, Word,
ExecutionOptions, ExecutionTrace, Felt, Kernel, Operation, Process, StackInputs, Word,
},
build_op_group,
};
Expand Down Expand Up @@ -1202,11 +1201,11 @@ fn build_trace(stack_inputs: &[u64], program: &CodeBlock) -> (DecoderTrace, usiz
process.execute_code_block(program, &CodeBlockTable::default()).unwrap();

let (trace, _, _) = ExecutionTrace::test_finalize_trace(process);
let trace_len = get_trace_len(&trace) - ExecutionTrace::NUM_RAND_ROWS;
let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS;

(
trace[DECODER_TRACE_RANGE]
.to_vec()
trace
.get_column_range(DECODER_TRACE_RANGE)
.try_into()
.expect("failed to convert vector to array"),
trace_len,
Expand All @@ -1230,11 +1229,11 @@ fn build_dyn_trace(
process.execute_code_block(program, &cb_table).unwrap();

let (trace, _, _) = ExecutionTrace::test_finalize_trace(process);
let trace_len = get_trace_len(&trace) - ExecutionTrace::NUM_RAND_ROWS;
let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS;

(
trace[DECODER_TRACE_RANGE]
.to_vec()
trace
.get_column_range(DECODER_TRACE_RANGE)
.try_into()
.expect("failed to convert vector to array"),
trace_len,
Expand Down Expand Up @@ -1264,15 +1263,15 @@ fn build_call_trace(
process.execute_code_block(program, &cb_table).unwrap();

let (trace, _, _) = ExecutionTrace::test_finalize_trace(process);
let trace_len = get_trace_len(&trace) - ExecutionTrace::NUM_RAND_ROWS;
let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS;

let sys_trace = trace[SYS_TRACE_RANGE]
.to_vec()
let sys_trace = trace
.get_column_range(SYS_TRACE_RANGE)
.try_into()
.expect("failed to convert vector to array");

let decoder_trace = trace[DECODER_TRACE_RANGE]
.to_vec()
let decoder_trace = trace
.get_column_range(DECODER_TRACE_RANGE)
.try_into()
.expect("failed to convert vector to array");

Expand Down
7 changes: 4 additions & 3 deletions processor/src/range/aux_trace.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{uninit_vector, BTreeMap, ColMatrix, Felt, FieldElement, Vec, NUM_RAND_ROWS};
use super::{uninit_vector, BTreeMap, Felt, FieldElement, Vec, NUM_RAND_ROWS};
use miden_air::trace::main_trace::MainTrace;
use miden_air::trace::range::{M_COL_IDX, V_COL_IDX};
use vm_core::StarkField;

Expand Down Expand Up @@ -42,7 +43,7 @@ impl AuxTraceBuilder {
/// requested by the Stack and Memory processors.
pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
rand_elements: &[E],
) -> Vec<Vec<E>> {
let b_range = self.build_aux_col_b_range(main_trace, rand_elements);
Expand All @@ -53,7 +54,7 @@ impl AuxTraceBuilder {
/// check lookups performed by user operations match those executed by the Range Checker.
fn build_aux_col_b_range<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
rand_elements: &[E],
) -> Vec<E> {
// run batch inversion on the lookup values
Expand Down
4 changes: 2 additions & 2 deletions processor/src/range/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
trace::NUM_RAND_ROWS, utils::uninit_vector, BTreeMap, ColMatrix, Felt, FieldElement,
RangeCheckTrace, Vec, ZERO,
trace::NUM_RAND_ROWS, utils::uninit_vector, BTreeMap, Felt, FieldElement, RangeCheckTrace, Vec,
ZERO,
};

mod aux_trace;
Expand Down
8 changes: 4 additions & 4 deletions processor/src/range/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl CycleRangeChecks {
/// element in the field specified by E.
pub fn to_stack_value<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
alphas: &[E],
) -> E {
let mut value = E::ONE;
Expand All @@ -70,7 +70,7 @@ impl CycleRangeChecks {
/// element in the field specified by E.
fn to_mem_value<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
alphas: &[E],
) -> E {
let mut value = E::ONE;
Expand All @@ -88,7 +88,7 @@ impl LookupTableRow for CycleRangeChecks {
/// at least 1 alpha value. Includes all values included at this cycle from all processors.
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
alphas: &[E],
) -> E {
let stack_value = self.to_stack_value(main_trace, alphas);
Expand All @@ -115,7 +115,7 @@ impl LookupTableRow for RangeCheckRequest {
/// at least 1 alpha value.
fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
_main_trace: &ColMatrix<Felt>,
_main_trace: &MainTrace,
alphas: &[E],
) -> E {
let alpha: E = alphas[0];
Expand Down
4 changes: 2 additions & 2 deletions processor/src/stack/aux_trace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::trace::AuxColumnBuilder;

use super::{ColMatrix, Felt, FieldElement, OverflowTableRow, Vec};
use super::{Felt, FieldElement, OverflowTableRow, Vec};
use miden_air::trace::main_trace::MainTrace;

// AUXILIARY TRACE BUILDER
Expand All @@ -20,7 +20,7 @@ impl AuxTraceBuilder {
/// column p1 describing states of the stack overflow table.
pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
&self,
main_trace: &ColMatrix<Felt>,
main_trace: &MainTrace,
rand_elements: &[E],
) -> Vec<Vec<E>> {
let p1 = self.build_aux_column(main_trace, rand_elements);
Expand Down
3 changes: 1 addition & 2 deletions processor/src/stack/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::{
BTreeMap, ColMatrix, Felt, FieldElement, StackInputs, StackOutputs, Vec, ONE,
STACK_TRACE_WIDTH, ZERO,
BTreeMap, Felt, FieldElement, StackInputs, StackOutputs, Vec, ONE, STACK_TRACE_WIDTH, ZERO,
};
use core::cmp;
use vm_core::{stack::STACK_TOP_SIZE, Word, WORD_SIZE};
Expand Down
13 changes: 8 additions & 5 deletions processor/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::{
stack::AuxTraceBuilder as StackAuxTraceBuilder, ColMatrix, Digest, Felt, FieldElement, Host,
Process, StackTopState, Vec,
};
use miden_air::trace::main_trace::MainTrace;
use miden_air::trace::{
decoder::{NUM_USER_OP_HELPERS, USER_OP_HELPERS_OFFSET},
AUX_TRACE_RAND_ELEMENTS, AUX_TRACE_WIDTH, DECODER_TRACE_OFFSET, MIN_TRACE_LEN,
Expand Down Expand Up @@ -50,7 +51,7 @@ pub struct AuxTraceBuilders {
pub struct ExecutionTrace {
meta: Vec<u8>,
layout: TraceLayout,
main_trace: ColMatrix<Felt>,
main_trace: MainTrace,
aux_trace_builders: AuxTraceBuilders,
program_info: ProgramInfo,
stack_outputs: StackOutputs,
Expand Down Expand Up @@ -86,8 +87,8 @@ impl ExecutionTrace {
Self {
meta: Vec::new(),
layout: TraceLayout::new(TRACE_WIDTH, [AUX_TRACE_WIDTH], [AUX_TRACE_RAND_ELEMENTS]),
main_trace: ColMatrix::new(main_trace),
aux_trace_builders: aux_trace_hints,
main_trace,
program_info,
stack_outputs,
trace_len_summary,
Expand Down Expand Up @@ -173,7 +174,7 @@ impl ExecutionTrace {
#[cfg(test)]
pub fn test_finalize_trace<H>(
process: Process<H>,
) -> (Vec<Vec<Felt>>, AuxTraceBuilders, TraceLenSummary)
) -> (MainTrace, AuxTraceBuilders, TraceLenSummary)
where
H: Host,
{
Expand Down Expand Up @@ -276,7 +277,7 @@ impl Trace for ExecutionTrace {
fn finalize_trace<H>(
process: Process<H>,
mut rng: RpoRandomCoin,
) -> (Vec<Vec<Felt>>, AuxTraceBuilders, TraceLenSummary)
) -> (MainTrace, AuxTraceBuilders, TraceLenSummary)
where
H: Host,
{
Expand Down Expand Up @@ -341,5 +342,7 @@ where
chiplets: chiplets_trace.aux_builder,
};

(trace, aux_trace_hints, trace_len_summary)
let main_trace = MainTrace::new(ColMatrix::new(trace));

(main_trace, aux_trace_hints, trace_len_summary)
}
5 changes: 3 additions & 2 deletions processor/src/trace/tests/hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use super::{
ZERO,
};

use crate::{ColMatrix, StackInputs};
use crate::StackInputs;
use miden_air::trace::main_trace::MainTrace;
use miden_air::trace::{chiplets::hasher::P1_COL_IDX, AUX_TRACE_RAND_ELEMENTS};
use vm_core::{
crypto::merkle::{MerkleStore, MerkleTree, NodeIndex},
Expand Down Expand Up @@ -188,7 +189,7 @@ impl SiblingTableRow {
/// at least 6 alpha values.
pub fn to_value<E: FieldElement<BaseField = Felt>>(
&self,
_main_trace: &ColMatrix<Felt>,
_main_trace: &MainTrace,
alphas: &[E],
) -> E {
// when the least significant bit of the index is 0, the sibling will be in the 3rd word
Expand Down
12 changes: 5 additions & 7 deletions processor/src/trace/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use vm_core::{utils::uninit_vector, FieldElement};

#[cfg(test)]
use vm_core::{utils::ToElements, Operation};
use winter_prover::matrix::ColMatrix;

// TRACE FRAGMENT
// ================================================================================================
Expand Down Expand Up @@ -221,19 +220,18 @@ pub trait AuxColumnBuilder<E: FieldElement<BaseField = Felt>> {
}

/// Builds the chiplets bus auxiliary trace column.
fn build_aux_column(&self, main_trace: &ColMatrix<Felt>, alphas: &[E]) -> Vec<E> {
let main_trace = MainTrace::new(main_trace);
fn build_aux_column(&self, main_trace: &MainTrace, alphas: &[E]) -> Vec<E> {
let mut responses_prod: Vec<E> = unsafe { uninit_vector(main_trace.num_rows()) };
let mut requests: Vec<E> = unsafe { uninit_vector(main_trace.num_rows()) };

responses_prod[0] = self.init_responses(&main_trace, alphas);
requests[0] = self.init_requests(&main_trace, alphas);
responses_prod[0] = self.init_responses(main_trace, alphas);
requests[0] = self.init_requests(main_trace, alphas);

let mut requests_running_prod = E::ONE;
for row_idx in 0..main_trace.num_rows() - 1 {
responses_prod[row_idx + 1] =
responses_prod[row_idx] * self.get_responses_at(&main_trace, alphas, row_idx);
requests[row_idx + 1] = self.get_requests_at(&main_trace, alphas, row_idx);
responses_prod[row_idx] * self.get_responses_at(main_trace, alphas, row_idx);
requests[row_idx + 1] = self.get_requests_at(main_trace, alphas, row_idx);
requests_running_prod *= requests[row_idx + 1];
}

Expand Down
Loading

0 comments on commit 74c2032

Please sign in to comment.