Skip to content

Commit

Permalink
fix: Use correct types in “higher order” snippets
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Sep 18, 2024
1 parent 311bab3 commit 1c556be
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 49 deletions.
10 changes: 3 additions & 7 deletions tasm-lib/src/list/higher_order/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,9 @@ impl All {

impl BasicSnippet for All {
fn inputs(&self) -> Vec<(DataType, String)> {
let input_type = match &self.f {
InnerFunction::BasicSnippet(basic_snippet) => {
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
}
_ => DataType::VoidPointer,
};
vec![(input_type, "*input_list".to_string())]
let element_type = self.f.domain();
let list_type = DataType::List(Box::new(element_type));
vec![(list_type, "*input_list".to_string())]
}

fn outputs(&self) -> Vec<(DataType, String)> {
Expand Down
24 changes: 8 additions & 16 deletions tasm-lib/src/list/higher_order/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,21 @@ use super::inner_function::InnerFunction;

/// Filters a given list for elements that satisfy a predicate. A new
/// list is created, containing only those elements that satisfy the
/// predicate. The predicate must be given as an InnerFunction.
/// predicate. The predicate must be given as an [`InnerFunction`].
pub struct Filter {
pub f: InnerFunction,
}

impl BasicSnippet for Filter {
fn inputs(&self) -> Vec<(DataType, String)> {
let list_type = match &self.f {
InnerFunction::BasicSnippet(basic_snippet) => {
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
}
_ => DataType::VoidPointer,
};
let element_type = self.f.domain();
let list_type = DataType::List(Box::new(element_type));
vec![(list_type, "*input_list".to_string())]
}

fn outputs(&self) -> Vec<(DataType, String)> {
let list_type = match &self.f {
InnerFunction::BasicSnippet(basic_snippet) => {
DataType::List(Box::new(basic_snippet.inputs()[0].0.clone()))
}
_ => DataType::VoidPointer,
};
let element_type = self.f.range();
let list_type = DataType::List(Box::new(element_type));
vec![(list_type, "*output_list".to_string())]
}

Expand Down Expand Up @@ -90,7 +82,7 @@ impl BasicSnippet for Filter {
// If function was supplied as raw instructions, we need to append the inner function to the function
// body. Otherwise, `library` handles the imports.
let maybe_inner_function_body_raw = match &self.f {
InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
InnerFunction::RawCode(rc) => rc.function.iter().join("\n"),
InnerFunction::DeprecatedSnippet(_) => String::default(),
InnerFunction::NoFunctionBody(_) => todo!(),
InnerFunction::BasicSnippet(_) => String::default(),
Expand All @@ -114,7 +106,7 @@ impl BasicSnippet for Filter {
call {main_loop} // _ *input_list *output_list input_len input_len output_len

swap 2 pop 2 // _ *input_list *output_list output_len
call {set_length} // _input_list *output_list
call {set_length} // _ *input_list *output_list

swap 1 // _ *output_list *input_list
pop 1 // _ *output_list
Expand All @@ -123,7 +115,7 @@ impl BasicSnippet for Filter {
// INVARIANT: _ *input_list *output_list input_len input_index output_index
{main_loop}:
// test return condition
dup 1 // _ *input_list *output_list input_len input_index output_index input_index
dup 1 // _ *input_list *output_list input_len input_index output_index input_index
dup 3 eq // _ *input_list *output_list input_len input_index output_index input_index==input_len

skiz return
Expand Down
10 changes: 4 additions & 6 deletions tasm-lib/src/list/higher_order/inner_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,9 @@ impl InnerFunction {
}
InnerFunction::NoFunctionBody(f) => f.input_type.clone(),
InnerFunction::BasicSnippet(bs) => {
let [ref input] = bs.inputs()[..] else {
let [(ref input, _)] = bs.inputs()[..] else {
panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
};
let (input, _) = input;
input.clone()
}
}
Expand All @@ -143,10 +142,9 @@ impl InnerFunction {
}
InnerFunction::NoFunctionBody(lnat) => lnat.output_type.clone(),
InnerFunction::BasicSnippet(bs) => {
let [ref output] = bs.outputs()[..] else {
let [(ref output, _)] = bs.outputs()[..] else {
panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
};
let (output, _) = output;
output.clone()
}
}
Expand All @@ -163,8 +161,8 @@ impl InnerFunction {
}

/// Run the VM for on a given stack and memory to observe how it manipulates the
/// stack. This is a helper function for [`apply`](apply), which in some cases just
/// grabs the inner function's code and then needs a VM to apply it.
/// stack. This is a helper function for [`apply`](Self::apply), which in some cases
/// just grabs the inner function's code and then needs a VM to apply it.
fn run_vm(
instructions: &[LabelledInstruction],
stack: &mut Vec<BFieldElement>,
Expand Down
22 changes: 2 additions & 20 deletions tasm-lib/src/list/higher_order/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ use super::inner_function::InnerFunction;

const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* \
input. Use a tuple as a workaround.";
const INNER_FN_INCORRECT_NUM_OUTPUTS: &str = "Inner function in `map` only works with *one* \
output. Use a tuple as a workaround.";

/// Applies a given function to every element of a list, and collects the new elements
/// into a new list.
Expand All @@ -42,29 +40,13 @@ impl Map {

impl BasicSnippet for Map {
fn inputs(&self) -> Vec<(DataType, String)> {
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
let [(ref element_type, _)] = snippet.inputs()[..] else {
panic!("{INNER_FN_INCORRECT_NUM_INPUTS}");
};
element_type.to_owned()
} else {
DataType::VoidPointer
};

let element_type = self.f.domain();
let list_type = DataType::List(Box::new(element_type));
vec![(list_type, "*input_list".to_string())]
}

fn outputs(&self) -> Vec<(DataType, String)> {
let element_type = if let InnerFunction::BasicSnippet(snippet) = &self.f {
let [(ref element_type, _)] = snippet.outputs()[..] else {
panic!("{INNER_FN_INCORRECT_NUM_OUTPUTS}");
};
element_type.to_owned()
} else {
DataType::VoidPointer
};

let element_type = self.f.range();
let list_type = DataType::List(Box::new(element_type));
vec![(list_type, "*output_list".to_string())]
}
Expand Down

0 comments on commit 1c556be

Please sign in to comment.