From 2ce5ff15da8c8452a72b207f689ec0d24e21eb74 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Fri, 7 Jun 2024 22:22:38 +0200 Subject: [PATCH] perf: Combine constraints for group `keep_stack` --- specification/src/arithmetization-overview.md | 2 +- triton-vm/src/table/processor_table.rs | 158 ++++++++++-------- 2 files changed, 88 insertions(+), 72 deletions(-) diff --git a/specification/src/arithmetization-overview.md b/specification/src/arithmetization-overview.md index 11d4cba2e..99c6b3954 100644 --- a/specification/src/arithmetization-overview.md +++ b/specification/src/arithmetization-overview.md @@ -71,5 +71,5 @@ In order to gauge the runtime cost for this step, the following table provides e | Processor | Op Stack | RAM | |----------:|---------:|------:| -| 35891 | 66815 | 23667 | +| 34409 | 63859 | 22590 | diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index 13ab6729f..357dd760d 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -1123,21 +1123,38 @@ impl ExtProcessorTable { circuit_builder: &ConstraintCircuitBuilder, n: usize, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { + assert!(n <= NUM_OP_STACK_REGISTERS); + + let curr_row = |col: ProcessorBaseTableColumn| { circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) }; - let next_base_row = |col: ProcessorBaseTableColumn| { + let next_row = |col: ProcessorBaseTableColumn| { circuit_builder.input(NextBaseRow(col.master_base_table_index())) }; - let all_but_n_top_elements_remain = (n..NUM_OP_STACK_REGISTERS) + let stack = (0..OpStackElement::COUNT) .map(ProcessorTable::op_stack_column_by_index) - .map(|sti| next_base_row(sti) - curr_base_row(sti)) .collect_vec(); - let op_stack_perm_arg_remains = - Self::instruction_group_keep_op_stack_height(circuit_builder); + let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); + let curr_stack = stack.iter().map(|&st| curr_row(st)).collect_vec(); + + let compress_stack_except_top_n = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { + assert_eq!(NUM_OP_STACK_REGISTERS, stack.len()); + let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); + stack + .into_iter() + .enumerate() + .skip(n) + .map(|(i, st)| weight(i) * st) + .sum() + }; - [all_but_n_top_elements_remain, op_stack_perm_arg_remains].concat() + let all_but_n_top_elements_remain = + compress_stack_except_top_n(next_stack) - compress_stack_except_top_n(curr_stack); + + let mut constraints = Self::instruction_group_keep_op_stack_height(circuit_builder); + constraints.push(all_but_n_top_elements_remain); + constraints } /// Op stack does not change, _i.e._, all stack elements persist @@ -1153,24 +1170,20 @@ impl ExtProcessorTable { fn instruction_group_keep_op_stack_height( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - let curr_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) - }; - let next_ext_row = |col: ProcessorExtTableColumn| { - circuit_builder.input(NextExtRow(col.master_ext_table_index())) - }; - vec![ - // permutation argument accumulator does not change - next_ext_row(OpStackTablePermArg) - curr_ext_row(OpStackTablePermArg), - // op stack pointer does not change - next_base_row(OpStackPointer) - curr_base_row(OpStackPointer), - ] + let op_stack_pointer_curr = + circuit_builder.input(CurrentBaseRow(OpStackPointer.master_base_table_index())); + let op_stack_pointer_next = + circuit_builder.input(NextBaseRow(OpStackPointer.master_base_table_index())); + let osp_remains_unchanged = op_stack_pointer_next - op_stack_pointer_curr; + + let op_stack_table_perm_arg_curr = + circuit_builder.input(CurrentExtRow(OpStackTablePermArg.master_ext_table_index())); + let op_stack_table_perm_arg_next = + circuit_builder.input(NextExtRow(OpStackTablePermArg.master_ext_table_index())); + let perm_arg_remains_unchanged = + op_stack_table_perm_arg_next - op_stack_table_perm_arg_curr; + + vec![osp_remains_unchanged, perm_arg_remains_unchanged] } fn instruction_group_grow_op_stack_and_top_two_elements_unconstrained( @@ -1510,46 +1523,6 @@ impl ExtProcessorTable { .concat() } - /// Compute the randomly-weighted linear combination of the supplied stack - /// elements using the first `stack.len()` [challenges] as weights. - /// - /// # Panics - /// - /// Panics if the supplied stack is larger than [`OpStackElement::COUNT`]. - /// - /// [challenges]: StackWeight0 - fn compress_stack( - circuit_builder: &ConstraintCircuitBuilder, - stack: Vec>, - ) -> ConstraintCircuitMonad { - assert!(stack.len() <= OpStackElement::COUNT); - let challenges = [ - StackWeight0, - StackWeight1, - StackWeight2, - StackWeight3, - StackWeight4, - StackWeight5, - StackWeight6, - StackWeight7, - StackWeight8, - StackWeight9, - StackWeight10, - StackWeight11, - StackWeight12, - StackWeight13, - StackWeight14, - StackWeight15, - ] - .map(|ch| circuit_builder.challenge(ch)); - - challenges - .into_iter() - .zip(stack) - .map(|(weight, st)| weight * st) - .sum() - } - fn instruction_dup( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { @@ -1597,9 +1570,11 @@ impl ExtProcessorTable { let next_stack = stack.iter().map(|&st| next_row(st)).collect_vec(); let curr_stack_with_swapped_i = |i| stack_with_swapped_i(i).map(curr_row).collect_vec(); - let compress = |stack: Vec<_>| { + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { assert_eq!(OpStackElement::COUNT, stack.len()); - Self::compress_stack(circuit_builder, stack) + let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() }; let next_stack_is_current_stack_with_swapped_i = |i| { @@ -2999,9 +2974,11 @@ impl ExtProcessorTable { let new_stack = stack().dropping_back(n).map(next_row).collect_vec(); let old_stack_with_top_n_removed = stack().skip(n).map(curr_row).collect_vec(); - let compress = |stack: Vec<_>| { + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { assert_eq!(OpStackElement::COUNT - n, stack.len()); - Self::compress_stack(circuit_builder, stack) + let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() }; let compressed_new_stack = compress(new_stack); let compressed_old_stack = compress(old_stack_with_top_n_removed); @@ -3033,9 +3010,11 @@ impl ExtProcessorTable { let new_stack = stack().skip(n).map(next_row).collect_vec(); let old_stack_with_top_n_added = stack().map(curr_row).dropping_back(n).collect_vec(); - let compress = |stack: Vec<_>| { + let compress = |stack: Vec<_>| -> ConstraintCircuitMonad<_> { assert_eq!(OpStackElement::COUNT - n, stack.len()); - Self::compress_stack(circuit_builder, stack) + let weight = |i| circuit_builder.challenge(Self::stack_weight_by_index(i)); + let enumerated_stack = stack.into_iter().enumerate(); + enumerated_stack.map(|(i, st)| weight(i) * st).sum() }; let compressed_new_stack = compress(new_stack); let compressed_old_stack = compress(old_stack_with_top_n_added); @@ -3732,6 +3711,28 @@ impl ExtProcessorTable { + no_update_summand } + fn stack_weight_by_index(i: usize) -> ChallengeId { + match i { + 0 => StackWeight0, + 1 => StackWeight1, + 2 => StackWeight2, + 3 => StackWeight3, + 4 => StackWeight4, + 5 => StackWeight5, + 6 => StackWeight6, + 7 => StackWeight7, + 8 => StackWeight8, + 9 => StackWeight9, + 10 => StackWeight10, + 11 => StackWeight11, + 12 => StackWeight12, + 13 => StackWeight13, + 14 => StackWeight14, + 15 => StackWeight15, + i => panic!("Op Stack weight index must be in [0, 15], not {i}."), + } + } + pub fn transition_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { @@ -4661,6 +4662,21 @@ pub(crate) mod tests { let _ = ProcessorTable::op_stack_column_by_index(index); } + #[test] + fn can_get_stack_weight_for_in_range_index() { + for index in 0..OpStackElement::COUNT { + let _ = ExtProcessorTable::stack_weight_by_index(index); + } + } + + #[proptest] + #[should_panic(expected = "[0, 15]")] + fn cannot_get_stack_weight_for_out_of_range_index( + #[strategy(OpStackElement::COUNT..)] index: usize, + ) { + let _ = ExtProcessorTable::stack_weight_by_index(index); + } + #[proptest] fn constructing_factor_for_op_stack_table_running_product_never_panics( #[strategy(vec(arb(), BASE_WIDTH))] previous_row: Vec,