Skip to content

Commit

Permalink
fmt + clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
prozacchiwawa committed Apr 25, 2024
1 parent c6a423e commit 88f72e8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 58 deletions.
37 changes: 27 additions & 10 deletions src/run_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,33 @@ const OP_COST: Cost = 1;
// exceeded
const STACK_SIZE_LIMIT: usize = 20000000;

/// Tell whether to call the post eval function or not, giving a reference id
/// for the computation to pick up.
pub enum PreEvalResult {
CallPostEval(usize),
Done
Done,
}

#[cfg(feature = "pre-eval")]
/// Implementing this trait allows an object to be notified of clvm operations
/// being performed as they happen.
pub trait PreEval {
fn pre_eval(&mut self, allocator: &mut Allocator, sexp: NodePtr, args: NodePtr) -> Result<PreEvalResult, EvalErr>;
fn post_eval(&mut self, _allocator: &mut Allocator, _pass: usize, _result: Option<NodePtr>) -> Result<(), EvalErr> {
/// pre_eval is called before the operator is run, giving sexp (the operation
/// to run) and args (the environment).
fn pre_eval(
&mut self,
allocator: &mut Allocator,
sexp: NodePtr,
args: NodePtr,
) -> Result<PreEvalResult, EvalErr>;
/// post_eval is called after the operation was performed. When the clvm
/// operation resulted in an error, result is None.
fn post_eval(
&mut self,
_allocator: &mut Allocator,
_pass: usize,
_result: Option<NodePtr>,
) -> Result<(), EvalErr> {
Ok(())
}
}
Expand Down Expand Up @@ -122,7 +140,7 @@ fn augment_cost_errors(r: Result<Cost, EvalErr>, max_cost: NodePtr) -> Result<Co
})
}

impl<'a, 'inner, D: Dialect> RunProgramContext<'a, D> {
impl<'a, D: Dialect> RunProgramContext<'a, D> {
#[cfg(feature = "counters")]
#[inline(always)]
fn account_val_push(&mut self) {
Expand Down Expand Up @@ -276,7 +294,9 @@ impl<'a, 'inner, D: Dialect> RunProgramContext<'a, D> {
fn eval_pair(&mut self, program: NodePtr, env: NodePtr) -> Result<Cost, EvalErr> {
#[cfg(feature = "pre-eval")]
if let Some(pre_eval) = &mut self.pre_eval {
if let PreEvalResult::CallPostEval(pass) = pre_eval.pre_eval(&mut self.allocator, program, env)? {
if let PreEvalResult::CallPostEval(pass) =
pre_eval.pre_eval(self.allocator, program, env)?
{
self.posteval_stack.push(pass);
self.op_stack.push(Operation::PostEval);
}
Expand Down Expand Up @@ -492,10 +512,7 @@ impl<'a, 'inner, D: Dialect> RunProgramContext<'a, D> {
cost += match op {
Operation::Apply => {
let apply_op_res = self.apply_op(cost, effective_max_cost - cost);
augment_cost_errors(
apply_op_res,
max_cost_ptr,
)?
augment_cost_errors(apply_op_res, max_cost_ptr)?
}
Operation::ExitGuard => self.exit_guard(cost)?,
Operation::Cons => self.cons_op()?,
Expand All @@ -505,7 +522,7 @@ impl<'a, 'inner, D: Dialect> RunProgramContext<'a, D> {
if let Some(pre_eval) = &mut self.pre_eval {
let f = self.posteval_stack.pop().unwrap();
let peek: Option<NodePtr> = self.val_stack.last().copied();
pre_eval.post_eval(&mut self.allocator, f, peek)?;
pre_eval.post_eval(self.allocator, f, peek)?;
}
0
}
Expand Down
90 changes: 42 additions & 48 deletions src/test_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,47 @@ struct EvalFTracker {
#[cfg(feature = "pre-eval")]
use crate::chia_dialect::{ChiaDialect, NO_UNKNOWN_OPS};
#[cfg(feature = "pre-eval")]
use crate::run_program::run_program_with_pre_eval;
#[cfg(feature = "pre-eval")]
use std::cell::RefCell;
use crate::run_program::{run_program_with_pre_eval, PreEval, PreEvalResult};
#[cfg(feature = "pre-eval")]
use std::collections::HashSet;

// Allows move closures to tear off a reference and move it. // Allows interior
// mutability inside Fn traits.
#[cfg(feature = "pre-eval")]
use std::rc::Rc;
#[derive(Default)]
struct PreEvalTracking {
table: HashMap<usize, EvalFTracker>,
}

#[cfg(feature = "pre-eval")]
impl PreEval for PreEvalTracking {
fn pre_eval(
&mut self,
_allocator: &mut Allocator,
prog: NodePtr,
args: NodePtr,
) -> Result<PreEvalResult, EvalErr> {
let tracking_key = self.table.len();
self.table.insert(
tracking_key,
EvalFTracker {
prog,
args,
outcome: None,
},
);
Ok(PreEvalResult::CallPostEval(tracking_key))
}
fn post_eval(
&mut self,
_allocator: &mut Allocator,
pass: usize,
outcome: Option<NodePtr>,
) -> Result<(), EvalErr> {
if let Some(entry) = self.table.get_mut(&pass) {
entry.outcome = outcome;
}
Ok(())
}
}

// Ensure pre_eval_f and post_eval_f are working as expected.
#[cfg(feature = "pre-eval")]
Expand Down Expand Up @@ -406,51 +437,15 @@ fn test_pre_eval_and_post_eval() {
let a_args = allocator.new_pair(f_quoted, a_tail).unwrap();
let program = allocator.new_pair(a2, a_args).unwrap();

let tracking = Rc::new(RefCell::new(HashMap::new()));
let pre_eval_tracking = tracking.clone();
let pre_eval_f: Box<
dyn Fn(
&mut Allocator,
NodePtr,
NodePtr,
) -> Result<Option<Box<(dyn Fn(Option<NodePtr>))>>, EvalErr>,
> = Box::new(move |_allocator, prog, args| {
let tracking_key = pre_eval_tracking.borrow().len();
// Ensure lifetime of mutable borrow is contained.
// It must end before the lifetime of the following closure.
{
let mut tracking_mutable = pre_eval_tracking.borrow_mut();
tracking_mutable.insert(
tracking_key,
EvalFTracker {
prog,
args,
outcome: None,
},
);
}
let post_eval_tracking = pre_eval_tracking.clone();
let post_eval_f: Box<dyn Fn(Option<NodePtr>)> = Box::new(move |outcome| {
let mut tracking_mutable = post_eval_tracking.borrow_mut();
tracking_mutable.insert(
tracking_key,
EvalFTracker {
prog,
args,
outcome,
},
);
});
Ok(Some(post_eval_f))
});
let mut tracking = PreEvalTracking::default();

let result = run_program_with_pre_eval(
&mut allocator,
&ChiaDialect::new(NO_UNKNOWN_OPS),
program,
NodePtr::NIL,
COST_LIMIT,
Some(pre_eval_f),
Some(&mut tracking),
)
.unwrap();

Expand Down Expand Up @@ -478,8 +473,7 @@ fn test_pre_eval_and_post_eval() {
desired_outcomes.push((program, NodePtr::NIL, a99));

let mut found_outcomes = HashSet::new();
let tracking_examine = tracking.borrow();
for (_, v) in tracking_examine.iter() {
for (_, v) in tracking.table.iter() {
let found = desired_outcomes.iter().position(|(p, a, o)| {
node_eq(&allocator, *p, v.prog)
&& node_eq(&allocator, *a, v.args)
Expand All @@ -489,6 +483,6 @@ fn test_pre_eval_and_post_eval() {
assert!(found.is_some());
}

assert_eq!(tracking_examine.len(), desired_outcomes.len());
assert_eq!(tracking_examine.len(), found_outcomes.len());
assert_eq!(tracking.table.len(), desired_outcomes.len());
assert_eq!(tracking.table.len(), found_outcomes.len());
}

0 comments on commit 88f72e8

Please sign in to comment.