Skip to content

Commit

Permalink
perf: Combine constraints for group keep_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Jun 12, 2024
1 parent 4c30ffb commit 2ce5ff1
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 72 deletions.
2 changes: 1 addition & 1 deletion specification/src/arithmetization-overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ In order to gauge the runtime cost for this step, the following table provides e
<!-- auto-gen info start tasm_air_evaluation_cost -->
| Processor | Op Stack | RAM |
|----------:|---------:|------:|
| 35891 | 66815 | 23667 |
| 34409 | 63859 | 22590 |
<!-- auto-gen info stop tasm_air_evaluation_cost -->
158 changes: 87 additions & 71 deletions triton-vm/src/table/processor_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1123,21 +1123,38 @@ impl ExtProcessorTable {
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
n: usize,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let curr_base_row = |col: ProcessorBaseTableColumn| {
assert!(n <= NUM_OP_STACK_REGISTERS);

let curr_row = |col: ProcessorBaseTableColumn| {
circuit_builder.input(CurrentBaseRow(col.master_base_table_index()))
};
let next_base_row = |col: ProcessorBaseTableColumn| {
let next_row = |col: ProcessorBaseTableColumn| {
circuit_builder.input(NextBaseRow(col.master_base_table_index()))
};

let all_but_n_top_elements_remain = (n..NUM_OP_STACK_REGISTERS)
let stack = (0..OpStackElement::COUNT)
.map(ProcessorTable::op_stack_column_by_index)
.map(|sti| next_base_row(sti) - curr_base_row(sti))
.collect_vec();
let op_stack_perm_arg_remains =
Self::instruction_group_keep_op_stack_height(circuit_builder);
let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec();
let curr_stack = stack.iter().map(|&st| curr_row(st)).collect_vec();

let compress_stack_except_top_n = |stack: Vec<_>| -> ConstraintCircuitMonad<_> {
assert_eq!(NUM_OP_STACK_REGISTERS, stack.len());
let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i));
stack
.into_iter()
.enumerate()
.skip(n)
.map(|(i, st)| weight(i) * st)
.sum()
};

[all_but_n_top_elements_remain, op_stack_perm_arg_remains].concat()
let all_but_n_top_elements_remain =
compress_stack_except_top_n(next_stack) - compress_stack_except_top_n(curr_stack);

let mut constraints = Self::instruction_group_keep_op_stack_height(circuit_builder);
constraints.push(all_but_n_top_elements_remain);
constraints
}

/// Op stack does not change, _i.e._, all stack elements persist
Expand All @@ -1153,24 +1170,20 @@ impl ExtProcessorTable {
fn instruction_group_keep_op_stack_height(
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let curr_base_row = |col: ProcessorBaseTableColumn| {
circuit_builder.input(CurrentBaseRow(col.master_base_table_index()))
};
let next_base_row = |col: ProcessorBaseTableColumn| {
circuit_builder.input(NextBaseRow(col.master_base_table_index()))
};
let curr_ext_row = |col: ProcessorExtTableColumn| {
circuit_builder.input(CurrentExtRow(col.master_ext_table_index()))
};
let next_ext_row = |col: ProcessorExtTableColumn| {
circuit_builder.input(NextExtRow(col.master_ext_table_index()))
};
vec![
// permutation argument accumulator does not change
next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg),
// op stack pointer does not change
next_base_row(OpStackPointer) - curr_base_row(OpStackPointer),
]
let op_stack_pointer_curr =
circuit_builder.input(CurrentBaseRow(OpStackPointer.master_base_table_index()));
let op_stack_pointer_next =
circuit_builder.input(NextBaseRow(OpStackPointer.master_base_table_index()));
let osp_remains_unchanged = op_stack_pointer_next - op_stack_pointer_curr;

let op_stack_table_perm_arg_curr =
circuit_builder.input(CurrentExtRow(OpStackTablePermArg.master_ext_table_index()));
let op_stack_table_perm_arg_next =
circuit_builder.input(NextExtRow(OpStackTablePermArg.master_ext_table_index()));
let perm_arg_remains_unchanged =
op_stack_table_perm_arg_next - op_stack_table_perm_arg_curr;

vec![osp_remains_unchanged, perm_arg_remains_unchanged]
}

fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained(
Expand Down Expand Up @@ -1510,46 +1523,6 @@ impl ExtProcessorTable {
.concat()
}

/// Compute the randomly-weighted linear combination of the supplied stack
/// elements using the first `stack.len()` [challenges] as weights.
///
/// # Panics
///
/// Panics if the supplied stack is larger than [`OpStackElement::COUNT`].
///
/// [challenges]: StackWeight0
fn compress_stack(
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
stack: Vec<ConstraintCircuitMonad<DualRowIndicator>>,
) -> ConstraintCircuitMonad<DualRowIndicator> {
assert!(stack.len() <= OpStackElement::COUNT);
let challenges = [
StackWeight0,
StackWeight1,
StackWeight2,
StackWeight3,
StackWeight4,
StackWeight5,
StackWeight6,
StackWeight7,
StackWeight8,
StackWeight9,
StackWeight10,
StackWeight11,
StackWeight12,
StackWeight13,
StackWeight14,
StackWeight15,
]
.map(|ch| circuit_builder.challenge(ch));

challenges
.into_iter()
.zip(stack)
.map(|(weight, st)| weight * st)
.sum()
}

fn instruction_dup(
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
Expand Down Expand Up @@ -1597,9 +1570,11 @@ impl ExtProcessorTable {

let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec();
let curr_stack_with_swapped_i = |i| stack_with_swapped_i(i).map(curr_row).collect_vec();
let compress = |stack: Vec<_>| {
let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> {
assert_eq!(OpStackElement::COUNT, stack.len());
Self::compress_stack(circuit_builder, stack)
let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i));
let enumerated_stack = stack.into_iter().enumerate();
enumerated_stack.map(|(i, st)| weight(i) * st).sum()
};

let next_stack_is_current_stack_with_swapped_i = |i| {
Expand Down Expand Up @@ -2999,9 +2974,11 @@ impl ExtProcessorTable {
let new_stack = stack().dropping_back(n).map(next_row).collect_vec();
let old_stack_with_top_n_removed = stack().skip(n).map(curr_row).collect_vec();

let compress = |stack: Vec<_>| {
let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> {
assert_eq!(OpStackElement::COUNT - n, stack.len());
Self::compress_stack(circuit_builder, stack)
let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i));
let enumerated_stack = stack.into_iter().enumerate();
enumerated_stack.map(|(i, st)| weight(i) * st).sum()
};
let compressed_new_stack = compress(new_stack);
let compressed_old_stack = compress(old_stack_with_top_n_removed);
Expand Down Expand Up @@ -3033,9 +3010,11 @@ impl ExtProcessorTable {
let new_stack = stack().skip(n).map(next_row).collect_vec();
let old_stack_with_top_n_added = stack().map(curr_row).dropping_back(n).collect_vec();

let compress = |stack: Vec<_>| {
let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> {
assert_eq!(OpStackElement::COUNT - n, stack.len());
Self::compress_stack(circuit_builder, stack)
let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i));
let enumerated_stack = stack.into_iter().enumerate();
enumerated_stack.map(|(i, st)| weight(i) * st).sum()
};
let compressed_new_stack = compress(new_stack);
let compressed_old_stack = compress(old_stack_with_top_n_added);
Expand Down Expand Up @@ -3732,6 +3711,28 @@ impl ExtProcessorTable {
+ no_update_summand
}

fn stack_weight_by_index(i: usize) -> ChallengeId {
match i {
0 => StackWeight0,
1 => StackWeight1,
2 => StackWeight2,
3 => StackWeight3,
4 => StackWeight4,
5 => StackWeight5,
6 => StackWeight6,
7 => StackWeight7,
8 => StackWeight8,
9 => StackWeight9,
10 => StackWeight10,
11 => StackWeight11,
12 => StackWeight12,
13 => StackWeight13,
14 => StackWeight14,
15 => StackWeight15,
i => panic!("Op Stack weight index must be in [0, 15], not {i}."),
}
}

pub fn transition_constraints(
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
Expand Down Expand Up @@ -4661,6 +4662,21 @@ pub(crate) mod tests {
let _ = ProcessorTable::op_stack_column_by_index(index);
}

#[test]
fn can_get_stack_weight_for_in_range_index() {
for index in 0..OpStackElement::COUNT {
let _ = ExtProcessorTable::stack_weight_by_index(index);
}
}

#[proptest]
#[should_panic(expected = "[0, 15]")]
fn cannot_get_stack_weight_for_out_of_range_index(
#[strategy(OpStackElement::COUNT..)] index: usize,
) {
let _ = ExtProcessorTable::stack_weight_by_index(index);
}

#[proptest]
fn constructing_factor_for_op_stack_table_running_product_never_panics(
#[strategy(vec(arb(), BASE_WIDTH))] previous_row: Vec<BFieldElement>,
Expand Down

0 comments on commit 2ce5ff1

Please sign in to comment.