-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add snippet for calculating a sum of a BFE list
- Loading branch information
1 parent
b0799d6
commit 9346159
Showing
5 changed files
with
372 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[ | ||
{ | ||
"name": "tasm_list_unsafeimplu32_sum_bfe", | ||
"clock_cycle_count": 357, | ||
"hash_table_height": 54, | ||
"u32_table_height": 12, | ||
"case": "CommonCase" | ||
}, | ||
{ | ||
"name": "tasm_list_unsafeimplu32_sum_bfe", | ||
"clock_cycle_count": 2877, | ||
"hash_table_height": 54, | ||
"u32_table_height": 15, | ||
"case": "WorstCase" | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,295 @@ | ||
use crate::data_type::DataType; | ||
use crate::library::Library; | ||
use crate::traits::basic_snippet::BasicSnippet; | ||
use triton_vm::prelude::*; | ||
|
||
use super::ListType; | ||
|
||
/// Calculate the sum of the `BFieldElement`s in a list | ||
struct SumOfBfes { | ||
list_type: ListType, | ||
} | ||
|
||
impl BasicSnippet for SumOfBfes { | ||
fn inputs(&self) -> Vec<(DataType, String)> { | ||
vec![( | ||
// For naming the input argument, I just follow what `Rust` calls this argument | ||
DataType::List(Box::new(DataType::Bfe)), | ||
"self".to_owned(), | ||
)] | ||
} | ||
|
||
fn outputs(&self) -> Vec<(DataType, String)> { | ||
vec![(DataType::Bfe, "sum".to_owned())] | ||
} | ||
|
||
fn entrypoint(&self) -> String { | ||
format!( | ||
"tasm_list_{}_sum_{}", | ||
self.list_type, | ||
DataType::Bfe.label_friendly_name() | ||
) | ||
} | ||
|
||
fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> { | ||
let entrypoint = self.entrypoint(); | ||
let accumulate_five_elements_loop_label = format!("{entrypoint}_acc_5_elements"); | ||
|
||
let accumulate_five_elements_loop = triton_asm!( | ||
// Invariant: _ *end_loop *element acc | ||
{accumulate_five_elements_loop_label}: | ||
|
||
dup 2 | ||
dup 2 | ||
eq | ||
skiz | ||
return | ||
// _ *end_loop *element acc | ||
|
||
dup 1 | ||
read_mem 5 | ||
// _ *end_loop *element acc [elements] (*element - 5) | ||
|
||
swap 7 | ||
pop 1 | ||
// _ *end_loop (*element - 5) acc [elements] | ||
// _ *end_loop *element' acc [elements] | ||
|
||
add | ||
add | ||
add | ||
add | ||
add | ||
// _ *end_loop *element' acc' | ||
|
||
recurse | ||
); | ||
|
||
let accumulate_one_element_loop_label = format!("{entrypoint}_acc_1_element"); | ||
let accumulate_one_element_loop = triton_asm!( | ||
// Invariant: _ *end_loop *element acc | ||
{accumulate_one_element_loop_label}: | ||
dup 2 | ||
dup 2 | ||
eq | ||
skiz | ||
return | ||
// _ *end_loop *element acc | ||
|
||
dup 1 | ||
read_mem 1 | ||
swap 3 | ||
pop 1 | ||
// _ *end_loop (*element - 1) acc element | ||
|
||
add | ||
// _ *end_loop *element' acc' | ||
|
||
recurse | ||
); | ||
|
||
let adjust_for_metadata = match self.list_type.metadata_size() { | ||
1 => triton_asm!(), | ||
2 => triton_asm!(push 1 add), | ||
n => panic!("Unhandled metadata size. Got: {n}"), | ||
}; | ||
|
||
let set_loop_1_end_condition = match self.list_type.metadata_size() { | ||
1 => triton_asm!(), | ||
2 => triton_asm!( | ||
// *list *next_element sum | ||
swap 2 | ||
push 1 | ||
add | ||
swap 2 | ||
// (*list + 1) *next_element sum | ||
), | ||
n => panic!("Unhandled metadata size. Got: {n}"), | ||
}; | ||
|
||
triton_asm!( | ||
{entrypoint}: | ||
// _ *list | ||
|
||
// Get pointer to last element | ||
dup 0 | ||
read_mem 1 | ||
// _ *list length (*list - 1) | ||
|
||
pop 1 | ||
// _ *list length | ||
|
||
dup 1 | ||
dup 1 | ||
add | ||
{&adjust_for_metadata} | ||
// _ *list length *last_element | ||
|
||
// Get pointer to *end_loop that is the loop termination condition | ||
|
||
push 5 | ||
dup 2 | ||
// _ *list length *last_element 5 length | ||
|
||
div_mod | ||
// _ *list length *last_element (length / 5) (length % 5) | ||
|
||
swap 1 | ||
pop 1 | ||
// _ *list length *last_element (length % 5) | ||
|
||
dup 3 | ||
add | ||
{&adjust_for_metadata} | ||
// _ *list length *last_element *element[length % 5] | ||
// _ *list length *last_element *end_loop | ||
|
||
swap 1 | ||
push 0 | ||
// _ *list length *end_loop *last_element 0 | ||
|
||
call {accumulate_five_elements_loop_label} | ||
// _ *list length *end_loop *next_element sum | ||
|
||
swap 1 | ||
// _ *list length *end_loop sum *next_element | ||
|
||
swap 3 | ||
// _ *list *next_element *end_loop sum length | ||
|
||
pop 1 | ||
// _ *list *next_element *end_loop sum | ||
|
||
swap 1 | ||
pop 1 | ||
// _ *list *next_element sum | ||
|
||
{&set_loop_1_end_condition} | ||
|
||
call {accumulate_one_element_loop_label} | ||
// _ *list *list sum | ||
|
||
swap 2 | ||
pop 2 | ||
// _ sum | ||
|
||
return | ||
|
||
{&accumulate_five_elements_loop} | ||
{&accumulate_one_element_loop} | ||
) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::collections::HashMap; | ||
|
||
use itertools::Itertools; | ||
use num_traits::Zero; | ||
use rand::rngs::StdRng; | ||
use rand::Rng; | ||
use rand::SeedableRng; | ||
|
||
use super::*; | ||
use crate::snippet_bencher::BenchmarkCase; | ||
use crate::traits::function::Function; | ||
use crate::traits::function::FunctionInitialState; | ||
use crate::traits::function::ShadowedFunction; | ||
use crate::traits::rust_shadow::RustShadow; | ||
|
||
impl Function for SumOfBfes { | ||
fn rust_shadow( | ||
&self, | ||
stack: &mut Vec<BFieldElement>, | ||
memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>, | ||
) { | ||
const BFIELDELEMENT_SIZE: usize = 1; | ||
let list_pointer = stack.pop().unwrap(); | ||
let list = self | ||
.list_type | ||
.rust_shadowing_load_list_with_copy_element::<BFIELDELEMENT_SIZE>( | ||
list_pointer, | ||
memory, | ||
); | ||
|
||
let sum: BFieldElement = list.into_iter().map(|x| x[0]).sum(); | ||
stack.push(sum); | ||
} | ||
|
||
fn pseudorandom_initial_state( | ||
&self, | ||
seed: [u8; 32], | ||
bench_case: Option<crate::snippet_bencher::BenchmarkCase>, | ||
) -> FunctionInitialState { | ||
let mut rng: StdRng = SeedableRng::from_seed(seed); | ||
let list_pointer = BFieldElement::new(rng.gen()); | ||
let list_length = match bench_case { | ||
Some(BenchmarkCase::CommonCase) => 104, | ||
Some(BenchmarkCase::WorstCase) => 1004, | ||
None => rng.gen_range(0..200), | ||
}; | ||
self.prepare_state(list_pointer, list_length) | ||
} | ||
|
||
fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> { | ||
(0..13) | ||
.map(|len| self.prepare_state(BFieldElement::zero(), len)) | ||
.collect_vec() | ||
} | ||
} | ||
|
||
impl SumOfBfes { | ||
fn prepare_state( | ||
&self, | ||
list_pointer: BFieldElement, | ||
list_length: usize, | ||
) -> FunctionInitialState { | ||
let mut memory = HashMap::default(); | ||
self.list_type.rust_shadowing_insert_random_list( | ||
&DataType::Bfe, | ||
list_pointer, | ||
list_length, | ||
&mut memory, | ||
); | ||
|
||
let mut init_stack = self.init_stack_for_isolated_run(); | ||
init_stack.push(list_pointer); | ||
FunctionInitialState { | ||
stack: init_stack, | ||
memory, | ||
} | ||
} | ||
} | ||
|
||
#[test] | ||
fn sum_bfes_pbt_unsafe_list() { | ||
ShadowedFunction::new(SumOfBfes { | ||
list_type: ListType::Unsafe, | ||
}) | ||
.test() | ||
} | ||
|
||
#[test] | ||
fn sum_bfes_pbt_safe_list() { | ||
ShadowedFunction::new(SumOfBfes { | ||
list_type: ListType::Safe, | ||
}) | ||
.test() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod benches { | ||
use super::*; | ||
use crate::traits::function::ShadowedFunction; | ||
use crate::traits::rust_shadow::RustShadow; | ||
|
||
#[test] | ||
fn sum_bfes_bench_unsafe_lists() { | ||
ShadowedFunction::new(SumOfBfes { | ||
list_type: ListType::Unsafe, | ||
}) | ||
.bench(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.