Skip to content

Commit

Permalink
feat(Map): Check correctness of dyn length types
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Sep 26, 2024
1 parent c963bb5 commit 4c4edb7
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions tasm-lib/src/list/higher_order/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ use crate::traits::function::FunctionInitialState;

const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* input. \
Use a tuple as a workaround.";
const INNER_FN_INCORRECT_INPUT_DYN_LEN: &str = "An input type of dynamic length to `map`s inner \
function must be a tuple of form `(bfe, _)`.";

/// Applies a given function to every element of a list, and collects the new
/// elements into a new list.
Expand Down Expand Up @@ -94,14 +96,24 @@ impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
///
/// - if the input type has [static length] _and_ takes up
/// [`OpStackElement::COUNT`] or more words
/// - if the input type has dynamic length and is _anything but_ a tuple
/// `(_, `[`BFieldElement`][bfe]`)`
/// - if the output type takes up [`OpStackElement::COUNT`]` - 1` or more words
/// - if the output type does not have a [static length][len]
///
/// [len]: BFieldCodec::static_length
/// [bfe]: DataType::Bfe
pub fn new(f: InnerFunction) -> Self {
if let Some(input_len) = f.domain().static_length() {
// need instruction `place {input_type.stack_size()}`
assert!(input_len < OpStackElement::COUNT);
} else {
let DataType::Tuple(tuple) = f.domain() else {
panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
};
let [_, DataType::Bfe] = tuple[..] else {
panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
};
}

// need instruction `pick {output_type.stack_size() + 1}`
Expand Down Expand Up @@ -522,7 +534,6 @@ impl<const NUM_INPUT_LISTS: usize> Function for ChainMap<NUM_INPUT_LISTS> {

#[cfg(test)]
mod tests {

use itertools::Itertools;
use num_traits::Zero;
use proptest_arbitrary_interop::arb;
Expand Down Expand Up @@ -928,9 +939,10 @@ mod tests {
#[test]
fn mapping_over_dynamic_length_items_works() {
let f = || {
let list_type = DataType::List(Box::new(DataType::Bfe));
InnerFunction::RawCode(RawCode::new(
triton_asm!(just_forty_twos: pop 2 push 42 return),
DataType::List(Box::new(DataType::Bfe)),
DataType::Tuple(vec![list_type, DataType::Bfe]),
DataType::Bfe,
))
};
Expand All @@ -942,9 +954,10 @@ mod tests {
#[test]
fn mapping_over_list_of_lists_writing_their_lengths_works() {
let f = || {
let list_type = DataType::List(Box::new(DataType::Bfe));
InnerFunction::RawCode(RawCode::new(
triton_asm!(write_list_length: pop 1 read_mem 1 pop 1 return),
DataType::List(Box::new(DataType::Bfe)),
DataType::Tuple(vec![list_type, DataType::Bfe]),
DataType::Bfe,
))
};
Expand All @@ -971,9 +984,10 @@ mod benches {

#[test]
fn map_with_dyn_items_benchmark() {
let list_type = DataType::List(Box::new(DataType::Bfe));
let f = InnerFunction::RawCode(RawCode::new(
triton_asm!(dyn_length_elements: pop 2 push 42 return),
DataType::List(Box::new(DataType::Bfe)),
DataType::Tuple(vec![list_type, DataType::Bfe]),
DataType::Bfe,
));
ShadowedFunction::new(Map::new(f)).bench();
Expand Down

0 comments on commit 4c4edb7

Please sign in to comment.