diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index dfbdbd84c2..a48cceea08 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -3,10 +3,7 @@ use super::{ Instruction, }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; -use vm_core::{ - mast::{MastNode, MastNodeId}, - AdviceInjector, AssemblyOp, Operation, -}; +use vm_core::{mast::MastNodeId, AdviceInjector, AssemblyOp, Operation}; // BASIC BLOCK BUILDER // ================================================================================================ @@ -134,10 +131,9 @@ impl BasicBlockBuilder { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); - let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node); + let block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators)); - Some(basic_block_node_id) + Some(block_node_id) } else if !self.decorators.is_empty() { // this is a bug in the assembler. we shouldn't have decorators added without their // associated operations diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 9894a82092..aa14943997 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -6,7 +6,7 @@ use crate::{ }; use smallvec::SmallVec; -use vm_core::mast::{MastForest, MastNode, MastNodeId}; +use vm_core::mast::{MastForest, MastNodeId}; /// Procedure Invocation impl Assembler { @@ -94,8 +94,7 @@ impl Assembler { mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) + mast_forest_builder.ensure_external(mast_root) }) } InvokeKind::Call => { @@ -103,25 +102,20 @@ impl Assembler { mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) + mast_forest_builder.ensure_external(mast_root) }); - let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(call_node) + mast_forest_builder.ensure_call(callee_id) } InvokeKind::SysCall => { let callee_id = mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) + mast_forest_builder.ensure_external(mast_root) }); - let syscall_node = - MastNode::new_syscall(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(syscall_node) + mast_forest_builder.ensure_syscall(callee_id) } } }; @@ -134,7 +128,7 @@ impl Assembler { &self, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest_builder.ensure_dynexec(); Ok(Some(dyn_node_id)) } @@ -145,10 +139,8 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); - let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(dyn_call_node) + let dyn_node_id = mast_forest_builder.ensure_dynexec(); + mast_forest_builder.ensure_call(dyn_node_id) }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index c4b1aa67f9..67045a59de 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,9 +1,11 @@ use core::ops::Index; use alloc::collections::BTreeMap; +use alloc::vec::Vec; use vm_core::{ crypto::hash::RpoDigest, mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + DecoratorList, Operation, }; /// Builder for a [`MastForest`]. @@ -44,7 +46,7 @@ impl MastForestBuilder { /// If a [`MastNode`] which is equal to the current node was previously added, the previously /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal /// [`MastNode`]s have equal [`MastNodeId`]s. - pub fn ensure_node(&mut self, node: impl Into) -> MastNodeId { + fn ensure_node(&mut self, node: impl Into) -> MastNodeId { let node = node.into(); let node_digest = node.digest(); @@ -59,6 +61,55 @@ impl MastForestBuilder { } } + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_block( + &mut self, + operations: Vec, + decorators: Option, + ) -> MastNodeId { + match decorators { + Some(decorators) => { + self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + None => self.ensure_node(MastNode::new_basic_block(operations)), + } + } + + /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_join(&mut self, left_child: MastNodeId, right_child: MastNodeId) -> MastNodeId { + self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest)) + } + + /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_split(&mut self, if_branch: MastNodeId, else_branch: MastNodeId) -> MastNodeId { + self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest)) + } + + /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_loop(&mut self, body: MastNodeId) -> MastNodeId { + self.ensure_node(MastNode::new_loop(body, &self.mast_forest)) + } + + /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_call(&mut self, callee: MastNodeId) -> MastNodeId { + self.ensure_node(MastNode::new_call(callee, &self.mast_forest)) + } + + /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_syscall(&mut self, callee: MastNodeId) -> MastNodeId { + self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest)) + } + + /// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_dynexec(&mut self) -> MastNodeId { + self.ensure_node(MastNode::new_dynexec()) + } + + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_external(&mut self, mast_root: RpoDigest) -> MastNodeId { + self.ensure_node(MastNode::new_external(mast_root)) + } + /// Marks the given [`MastNodeId`] as being the root of a procedure. pub fn make_root(&mut self, new_root_id: MastNodeId) { self.mast_forest.make_root(new_root_id) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 6235b78f46..38834202f7 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -11,7 +11,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; use vm_core::{ - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, }; @@ -790,12 +790,7 @@ impl Assembler { let else_blk = self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; - let split_node_id = { - let split_node = - MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(split_node) - }; + let split_node_id = mast_forest_builder.ensure_split(then_blk, else_blk); mast_node_ids.push(split_node_id); } @@ -824,11 +819,7 @@ impl Assembler { let loop_body_node_id = self.compile_body(body.iter(), context, None, mast_forest_builder)?; - let loop_node_id = { - let loop_node = - MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(loop_node) - }; + let loop_node_id = mast_forest_builder.ensure_loop(loop_body_node_id); mast_node_ids.push(loop_node_id); } } @@ -839,7 +830,7 @@ impl Assembler { } Ok(if mast_node_ids.is_empty() { - mast_forest_builder.ensure_node(vec![Operation::Noop]) + mast_forest_builder.ensure_block(vec![Operation::Noop], None) } else { combine_mast_node_ids(mast_node_ids, mast_forest_builder) }) @@ -899,8 +890,7 @@ fn combine_mast_node_ids( while let (Some(left), Some(right)) = (source_mast_node_iter.next(), source_mast_node_iter.next()) { - let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); - let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node); + let join_mast_node_id = mast_forest_builder.ensure_join(left, right); mast_node_ids.push(join_mast_node_id); } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index a1935841a9..62ea6dfbef 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -81,11 +81,10 @@ fn nested_blocks() { // contains the MAST nodes for the kernel after a call to // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { - let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(vec![Operation::Add]); + let kernel_foo_node_id = + expected_mast_forest_builder.ensure_block(vec![Operation::Add], None); - let syscall_node = - MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(syscall_node) + expected_mast_forest_builder.ensure_syscall(kernel_foo_node_id) }; let program = r#" @@ -129,74 +128,47 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure let basic_block_1_id = - expected_mast_forest_builder.ensure_node(vec![Operation::Push(17_u32.into())]); + expected_mast_forest_builder.ensure_block(vec![Operation::Push(17_u32.into())], None); // Basic block representing the `foo` procedure let basic_block_2_id = - expected_mast_forest_builder.ensure_node(vec![Operation::Push(19_u32.into())]); - - let join_node = MastNode::new_join( - basic_block_1_id, - basic_block_2_id, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(join_node) - }; + expected_mast_forest_builder.ensure_block(vec![Operation::Push(19_u32.into())], None); - let exec_foo_bar_baz_node_id = { - // basic block representing foo::bar.baz procedure - expected_mast_forest_builder.ensure_node(vec![Operation::Push(29_u32.into())]) + expected_mast_forest_builder.ensure_join(basic_block_1_id, basic_block_2_id) }; - let before = expected_mast_forest_builder.ensure_node(vec![Operation::Push(2u32.into())]); + // basic block representing foo::bar.baz procedure + let exec_foo_bar_baz_node_id = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(29_u32.into())], None); - let r#true1 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(3u32.into())]); - let r#false1 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(5u32.into())]); - let r#if1 = { - let r#if_node = - MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node) - }; + let before = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(2u32.into())], None); - let r#true3 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(7u32.into())]); - let r#false3 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(11u32.into())]); - let r#true2 = { - let r#if_node = - MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node) - }; + let r#true1 = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(3u32.into())], None); + let r#false1 = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(5u32.into())], None); + let r#if1 = expected_mast_forest_builder.ensure_split(r#true1, r#false1); + + let r#true3 = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(7u32.into())], None); + let r#false3 = + expected_mast_forest_builder.ensure_block(vec![Operation::Push(11u32.into())], None); + let r#true2 = expected_mast_forest_builder.ensure_split(r#true3, r#false3); let r#while = { let push_basic_block_id = - expected_mast_forest_builder.ensure_node(vec![Operation::Push(23u32.into())]); - let body_node_id = { - let body_node = MastNode::new_join( - exec_bar_node_id, - push_basic_block_id, - expected_mast_forest_builder.forest(), - ); - - expected_mast_forest_builder.ensure_node(body_node) - }; - - let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(loop_node) + expected_mast_forest_builder.ensure_block(vec![Operation::Push(23u32.into())], None); + let body_node_id = + expected_mast_forest_builder.ensure_join(exec_bar_node_id, push_basic_block_id); + + expected_mast_forest_builder.ensure_loop(body_node_id) }; let push_13_basic_block_id = - expected_mast_forest_builder.ensure_node(vec![Operation::Push(13u32.into())]); - - let r#false2 = { - let node = MastNode::new_join( - push_13_basic_block_id, - r#while, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(node) - }; - let nested = { - let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(node) - }; + expected_mast_forest_builder.ensure_block(vec![Operation::Push(13u32.into())], None); + + let r#false2 = expected_mast_forest_builder.ensure_join(push_13_basic_block_id, r#while); + let nested = expected_mast_forest_builder.ensure_split(r#true2, r#false2); let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id],