Skip to content

Commit

Permalink
Add snippet for calculating a sum of a BFE list
Browse files Browse the repository at this point in the history
  • Loading branch information
Sword-Smith committed Feb 6, 2024
1 parent b0799d6 commit 9346159
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 1 deletion.
16 changes: 16 additions & 0 deletions tasm-lib/benchmarks/tasm_list_unsafeimplu32_sum_bfe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[
{
"name": "tasm_list_unsafeimplu32_sum_bfe",
"clock_cycle_count": 357,
"hash_table_height": 54,
"u32_table_height": 12,
"case": "CommonCase"
},
{
"name": "tasm_list_unsafeimplu32_sum_bfe",
"clock_cycle_count": 2877,
"hash_table_height": 54,
"u32_table_height": 15,
"case": "WorstCase"
}
]
16 changes: 15 additions & 1 deletion tasm-lib/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ use crate::list::unsafeimplu32::pop::UnsafePop;
use crate::list::unsafeimplu32::push::UnsafePush;
use crate::list::unsafeimplu32::set::UnsafeSet;
use crate::list::unsafeimplu32::set_length::UnsafeSetLength;
use crate::rust_shadowing_helper_functions;
use crate::rust_shadowing_helper_functions::{self, safe_list, unsafe_list};
use crate::traits::basic_snippet::BasicSnippet;

pub mod contiguous_list;
pub mod higher_order;
pub mod multiset_equality;
pub mod range;
pub mod safeimplu32;
pub mod sum_bfes;
pub mod swap_unchecked;
pub mod unsafeimplu32;

Expand Down Expand Up @@ -94,6 +95,19 @@ impl ListType {
}
}

pub fn rust_shadowing_load_list_with_copy_element<const ELEMENT_SIZE: usize>(
&self,
list_pointer: BFieldElement,
memory: &HashMap<BFieldElement, BFieldElement>,
) -> Vec<[BFieldElement; ELEMENT_SIZE]> {
match self {
ListType::Safe => safe_list::load_safe_list_with_copy_elements(list_pointer, memory),
ListType::Unsafe => {
unsafe_list::load_unsafe_list_with_copy_elements(list_pointer, memory)
}
}
}

/* Rust-shadowing helper functions */
pub fn rust_shadowing_get(
&self,
Expand Down
295 changes: 295 additions & 0 deletions tasm-lib/src/list/sum_bfes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
use crate::data_type::DataType;
use crate::library::Library;
use crate::traits::basic_snippet::BasicSnippet;
use triton_vm::prelude::*;

use super::ListType;

/// Calculate the sum of the `BFieldElement`s in a list
struct SumOfBfes {
list_type: ListType,
}

impl BasicSnippet for SumOfBfes {
fn inputs(&self) -> Vec<(DataType, String)> {
vec![(
// For naming the input argument, I just follow what `Rust` calls this argument
DataType::List(Box::new(DataType::Bfe)),
"self".to_owned(),
)]
}

fn outputs(&self) -> Vec<(DataType, String)> {
vec![(DataType::Bfe, "sum".to_owned())]
}

fn entrypoint(&self) -> String {
format!(
"tasm_list_{}_sum_{}",
self.list_type,
DataType::Bfe.label_friendly_name()
)
}

fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
let entrypoint = self.entrypoint();
let accumulate_five_elements_loop_label = format!("{entrypoint}_acc_5_elements");

let accumulate_five_elements_loop = triton_asm!(
// Invariant: _ *end_loop *element acc
{accumulate_five_elements_loop_label}:

dup 2
dup 2
eq
skiz
return
// _ *end_loop *element acc

dup 1
read_mem 5
// _ *end_loop *element acc [elements] (*element - 5)

swap 7
pop 1
// _ *end_loop (*element - 5) acc [elements]
// _ *end_loop *element' acc [elements]

add
add
add
add
add
// _ *end_loop *element' acc'

recurse
);

let accumulate_one_element_loop_label = format!("{entrypoint}_acc_1_element");
let accumulate_one_element_loop = triton_asm!(
// Invariant: _ *end_loop *element acc
{accumulate_one_element_loop_label}:
dup 2
dup 2
eq
skiz
return
// _ *end_loop *element acc

dup 1
read_mem 1
swap 3
pop 1
// _ *end_loop (*element - 1) acc element

add
// _ *end_loop *element' acc'

recurse
);

let adjust_for_metadata = match self.list_type.metadata_size() {
1 => triton_asm!(),
2 => triton_asm!(push 1 add),
n => panic!("Unhandled metadata size. Got: {n}"),
};

let set_loop_1_end_condition = match self.list_type.metadata_size() {
1 => triton_asm!(),
2 => triton_asm!(
// *list *next_element sum
swap 2
push 1
add
swap 2
// (*list + 1) *next_element sum
),
n => panic!("Unhandled metadata size. Got: {n}"),
};

triton_asm!(
{entrypoint}:
// _ *list

// Get pointer to last element
dup 0
read_mem 1
// _ *list length (*list - 1)

pop 1
// _ *list length

dup 1
dup 1
add
{&adjust_for_metadata}
// _ *list length *last_element

// Get pointer to *end_loop that is the loop termination condition

push 5
dup 2
// _ *list length *last_element 5 length

div_mod
// _ *list length *last_element (length / 5) (length % 5)

swap 1
pop 1
// _ *list length *last_element (length % 5)

dup 3
add
{&adjust_for_metadata}
// _ *list length *last_element *element[length % 5]
// _ *list length *last_element *end_loop

swap 1
push 0
// _ *list length *end_loop *last_element 0

call {accumulate_five_elements_loop_label}
// _ *list length *end_loop *next_element sum

swap 1
// _ *list length *end_loop sum *next_element

swap 3
// _ *list *next_element *end_loop sum length

pop 1
// _ *list *next_element *end_loop sum

swap 1
pop 1
// _ *list *next_element sum

{&set_loop_1_end_condition}

call {accumulate_one_element_loop_label}
// _ *list *list sum

swap 2
pop 2
// _ sum

return

{&accumulate_five_elements_loop}
{&accumulate_one_element_loop}
)
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use itertools::Itertools;
use num_traits::Zero;
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;

use super::*;
use crate::snippet_bencher::BenchmarkCase;
use crate::traits::function::Function;
use crate::traits::function::FunctionInitialState;
use crate::traits::function::ShadowedFunction;
use crate::traits::rust_shadow::RustShadow;

impl Function for SumOfBfes {
fn rust_shadow(
&self,
stack: &mut Vec<BFieldElement>,
memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>,
) {
const BFIELDELEMENT_SIZE: usize = 1;
let list_pointer = stack.pop().unwrap();
let list = self
.list_type
.rust_shadowing_load_list_with_copy_element::<BFIELDELEMENT_SIZE>(
list_pointer,
memory,
);

let sum: BFieldElement = list.into_iter().map(|x| x[0]).sum();
stack.push(sum);
}

fn pseudorandom_initial_state(
&self,
seed: [u8; 32],
bench_case: Option<crate::snippet_bencher::BenchmarkCase>,
) -> FunctionInitialState {
let mut rng: StdRng = SeedableRng::from_seed(seed);
let list_pointer = BFieldElement::new(rng.gen());
let list_length = match bench_case {
Some(BenchmarkCase::CommonCase) => 104,
Some(BenchmarkCase::WorstCase) => 1004,
None => rng.gen_range(0..200),
};
self.prepare_state(list_pointer, list_length)
}

fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
(0..13)
.map(|len| self.prepare_state(BFieldElement::zero(), len))
.collect_vec()
}
}

impl SumOfBfes {
fn prepare_state(
&self,
list_pointer: BFieldElement,
list_length: usize,
) -> FunctionInitialState {
let mut memory = HashMap::default();
self.list_type.rust_shadowing_insert_random_list(
&DataType::Bfe,
list_pointer,
list_length,
&mut memory,
);

let mut init_stack = self.init_stack_for_isolated_run();
init_stack.push(list_pointer);
FunctionInitialState {
stack: init_stack,
memory,
}
}
}

#[test]
fn sum_bfes_pbt_unsafe_list() {
ShadowedFunction::new(SumOfBfes {
list_type: ListType::Unsafe,
})
.test()
}

#[test]
fn sum_bfes_pbt_safe_list() {
ShadowedFunction::new(SumOfBfes {
list_type: ListType::Safe,
})
.test()
}
}

#[cfg(test)]
mod benches {
use super::*;
use crate::traits::function::ShadowedFunction;
use crate::traits::rust_shadow::RustShadow;

#[test]
fn sum_bfes_bench_unsafe_lists() {
ShadowedFunction::new(SumOfBfes {
list_type: ListType::Unsafe,
})
.bench();
}
}
23 changes: 23 additions & 0 deletions tasm-lib/src/rust_shadowing_helper_functions/safe_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@ use num::Zero;
use triton_vm::prelude::*;

use crate::data_type::DataType;
use crate::list::ListType;

/// Load a list from memory. Elements must be of `Copy` type.
pub fn load_safe_list_with_copy_elements<const ELEMENT_SIZE: usize>(
list_pointer: BFieldElement,
memory: &HashMap<BFieldElement, BFieldElement>,
) -> Vec<[BFieldElement; ELEMENT_SIZE]> {
let list_length: usize = memory[&list_pointer].value().try_into().unwrap();

let mut element_pointer =
list_pointer + BFieldElement::new(ListType::Safe.metadata_size() as u64);

let mut ret = Vec::with_capacity(list_length);
for i in 0..list_length {
ret.push([BFieldElement::zero(); ELEMENT_SIZE]);
for j in 0..ELEMENT_SIZE {
ret[i][j] = memory[&element_pointer];
element_pointer.increment();
}
}

ret
}

pub fn safe_list_insert<T: BFieldCodec>(
list_pointer: BFieldElement,
Expand Down
Loading

0 comments on commit 9346159

Please sign in to comment.