diff --git a/hugr/src/hugr/rewrite/replace.rs b/hugr/src/hugr/rewrite/replace.rs index a5f891817..d2647367b 100644 --- a/hugr/src/hugr/rewrite/replace.rs +++ b/hugr/src/hugr/rewrite/replace.rs @@ -213,7 +213,8 @@ impl Replacement { impl Rewrite for Replacement { type Error = ReplaceError; - type ApplyResult = (); + /// Map from Node in replacement to corresponding Node in the result Hugr + type ApplyResult = HashMap; const UNCHANGED_ON_FAILURE: bool = false; @@ -270,7 +271,7 @@ impl Rewrite for Replacement { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn apply(self, h: &mut impl HugrMut) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) @@ -324,7 +325,7 @@ impl Rewrite for Replacement { // 7. Remove remaining nodes to_remove.into_iter().for_each(|n| h.remove_node(n)); - Ok(()) + Ok(node_map) } fn invalidation_set(&self) -> impl Iterator { @@ -680,87 +681,91 @@ mod test { let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node(); let cond = cond.finish_sub_container()?; let h = h.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)?; - let verify_apply = |r: Replacement| { - let verify_res = r.verify(&h); - let apply_res = r.apply(&mut h.clone()); - assert_eq!(verify_res, apply_res); - verify_res - }; - // Note wrong root type here - we'll replace children of the *Conditional* - let mut rep1 = Hugr::new(h.root_type().clone()); - let r1 = rep1.add_node_with_parent( - rep1.root(), + let mut r_hugr = Hugr::new(NodeType::new_open(h.get_optype(cond.node()).clone())); + let r1 = r_hugr.add_node_with_parent( + r_hugr.root(), Case { signature: utou.clone(), }, ); - let r2 = rep1.add_node_with_parent( - rep1.root(), + let r2 = r_hugr.add_node_with_parent( + r_hugr.root(), Case { signature: utou.clone(), }, ); - let mut r = Replacement { + let rep: Replacement = Replacement { removal: vec![case1, case2], - replacement: rep1, + replacement: r_hugr, adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]), mu_inp: vec![], mu_out: vec![], mu_new: vec![], }; + assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node())); + rep.verify(&h).unwrap(); + { + let mut target = h.clone(); + let node_map = rep.clone().apply(&mut target)?; + let new_case2 = *node_map.get(&r2).unwrap(); + assert_eq!(target.get_parent(baz.node()), Some(new_case2)); + } + + // Test some bad Replacements (using variations of the `replacement` Hugr). + let check_same_errors = |r: Replacement| { + let verify_res = r.verify(&h).unwrap_err(); + let apply_res = r.apply(&mut h.clone()).unwrap_err(); + assert_eq!(verify_res, apply_res); + apply_res + }; + // Root node type needs to be that of common parent of the removed nodes: + let mut rep2 = rep.clone(); + rep2.replacement + .replace_op(rep2.replacement.root(), h.root_type().clone())?; assert_eq!( - verify_apply(r.clone()), - Err(ReplaceError::WrongRootNodeTag { + check_same_errors(rep2), + ReplaceError::WrongRootNodeTag { removed: OpTag::Conditional, replacement: OpTag::Dfg - }) + } ); - r.replacement.replace_op( - r.replacement.root(), - NodeType::new_open(h.get_optype(cond.node()).clone()), - )?; - assert_eq!(verify_apply(r.clone()), Ok(())); - - // And test some bad Replacements (using the same `replacement` Hugr). - // First, removed nodes... + // Removed nodes... assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { removal: vec![h.root()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::CantReplaceRoot) + ReplaceError::CantReplaceRoot ); assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { removal: vec![case1, baz_dfg.node()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::MultipleParents(vec![cond.node(), case2])) + ReplaceError::MultipleParents(vec![cond.node(), case2]) ); // Adoptions... assert_eq!( - verify_apply(Replacement { - adoptions: HashMap::from([(r1, case1), (r.replacement.root(), case2)]), - ..r.clone() + check_same_errors(Replacement { + adoptions: HashMap::from([(r1, case1), (rep.replacement.root(), case2)]), + ..rep.clone() }), - Err(ReplaceError::InvalidAdoptingParent(r.replacement.root())) + ReplaceError::InvalidAdoptingParent(rep.replacement.root()) ); assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]), - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::AdopteesNotSeparateDescendants(vec![case1])) + ReplaceError::AdopteesNotSeparateDescendants(vec![case1]) ); assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]), - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::AdopteesNotSeparateDescendants(vec![ - baz_dfg.node() - ])) + ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()]) ); // Edges.... let edge_from_removed = NewEdgeSpec { @@ -769,15 +774,11 @@ mod test { kind: NewEdgeKind::Order, }; assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { mu_inp: vec![edge_from_removed.clone()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::BadEdgeSpec( - Direction::Outgoing, - WhichHugr::Retained, - edge_from_removed - )) + ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed) ); let bad_out_edge = NewEdgeSpec { src: h.nodes().max().unwrap(), // not valid in replacement @@ -785,15 +786,11 @@ mod test { kind: NewEdgeKind::Order, }; assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { mu_out: vec![bad_out_edge.clone()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::BadEdgeSpec( - Direction::Outgoing, - WhichHugr::Replacement, - bad_out_edge - )) + ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge) ); let bad_order_edge = NewEdgeSpec { src: cond.node(), @@ -801,11 +798,11 @@ mod test { kind: NewEdgeKind::ControlFlow { src_pos: 0.into() }, }; assert_matches!( - verify_apply(Replacement { + check_same_errors(Replacement { mu_new: vec![bad_order_edge.clone()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::BadEdgeKind(_, e)) => e == bad_order_edge + ReplaceError::BadEdgeKind(_, e) => e == bad_order_edge ); let op = OutgoingPort::from(0); let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap(); @@ -818,11 +815,11 @@ mod test { }, }; assert_eq!( - verify_apply(Replacement { + check_same_errors(Replacement { mu_out: vec![new_out_edge.clone()], - ..r.clone() + ..rep.clone() }), - Err(ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)) + ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge) ); Ok(()) } diff --git a/hugr/src/hugr/rewrite/simple_replace.rs b/hugr/src/hugr/rewrite/simple_replace.rs index 4fd64795b..69dfd218d 100644 --- a/hugr/src/hugr/rewrite/simple_replace.rs +++ b/hugr/src/hugr/rewrite/simple_replace.rs @@ -635,6 +635,6 @@ pub(in crate::hugr::rewrite) mod test { } fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(to_replace(h, rw)).unwrap() + h.apply_rewrite(to_replace(h, rw)).unwrap(); } }