Skip to content

Commit

Permalink
feat: 'Replace' rewrite returns node map (#929)
Browse files Browse the repository at this point in the history
This is a preliminary step towards using merging basic blocks, where the
Replace rewrite needs to be used as part of a larger transformation -
this makes it possible to use compositionally.

Along the way, refactor the tests a bit (changing `verify_apply` to
`check_same_error` and doing all the bad/`Err`-returning cases after the
success case).
  • Loading branch information
acl-cqc authored and aborgna-q committed Apr 23, 2024
1 parent 6664ced commit 99ed2ba
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 67 deletions.
129 changes: 63 additions & 66 deletions hugr/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ impl Replacement {
impl Rewrite for Replacement {
type Error = ReplaceError;

type ApplyResult = ();
/// Map from Node in replacement to corresponding Node in the result Hugr
type ApplyResult = HashMap<Node, Node>;

const UNCHANGED_ON_FAILURE: bool = false;

Expand Down Expand Up @@ -270,7 +271,7 @@ impl Rewrite for Replacement {
Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.check_parent(h)?;
// Calculate removed nodes here. (Does not include transfers, so enumerates only
// nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty)
Expand Down Expand Up @@ -324,7 +325,7 @@ impl Rewrite for Replacement {

// 7. Remove remaining nodes
to_remove.into_iter().for_each(|n| h.remove_node(n));
Ok(())
Ok(node_map)
}

fn invalidation_set(&self) -> impl Iterator<Item = Node> {
Expand Down Expand Up @@ -680,87 +681,91 @@ mod test {
let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node();
let cond = cond.finish_sub_container()?;
let h = h.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)?;
let verify_apply = |r: Replacement| {
let verify_res = r.verify(&h);
let apply_res = r.apply(&mut h.clone());
assert_eq!(verify_res, apply_res);
verify_res
};

// Note wrong root type here - we'll replace children of the *Conditional*
let mut rep1 = Hugr::new(h.root_type().clone());
let r1 = rep1.add_node_with_parent(
rep1.root(),
let mut r_hugr = Hugr::new(NodeType::new_open(h.get_optype(cond.node()).clone()));
let r1 = r_hugr.add_node_with_parent(
r_hugr.root(),
Case {
signature: utou.clone(),
},
);
let r2 = rep1.add_node_with_parent(
rep1.root(),
let r2 = r_hugr.add_node_with_parent(
r_hugr.root(),
Case {
signature: utou.clone(),
},
);
let mut r = Replacement {
let rep: Replacement = Replacement {
removal: vec![case1, case2],
replacement: rep1,
replacement: r_hugr,
adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
mu_inp: vec![],
mu_out: vec![],
mu_new: vec![],
};
assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
rep.verify(&h).unwrap();
{
let mut target = h.clone();
let node_map = rep.clone().apply(&mut target)?;
let new_case2 = *node_map.get(&r2).unwrap();
assert_eq!(target.get_parent(baz.node()), Some(new_case2));
}

// Test some bad Replacements (using variations of the `replacement` Hugr).
let check_same_errors = |r: Replacement| {
let verify_res = r.verify(&h).unwrap_err();
let apply_res = r.apply(&mut h.clone()).unwrap_err();
assert_eq!(verify_res, apply_res);
apply_res
};
// Root node type needs to be that of common parent of the removed nodes:
let mut rep2 = rep.clone();
rep2.replacement
.replace_op(rep2.replacement.root(), h.root_type().clone())?;
assert_eq!(
verify_apply(r.clone()),
Err(ReplaceError::WrongRootNodeTag {
check_same_errors(rep2),
ReplaceError::WrongRootNodeTag {
removed: OpTag::Conditional,
replacement: OpTag::Dfg
})
}
);
r.replacement.replace_op(
r.replacement.root(),
NodeType::new_open(h.get_optype(cond.node()).clone()),
)?;
assert_eq!(verify_apply(r.clone()), Ok(()));

// And test some bad Replacements (using the same `replacement` Hugr).
// First, removed nodes...
// Removed nodes...
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
removal: vec![h.root()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::CantReplaceRoot)
ReplaceError::CantReplaceRoot
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
removal: vec![case1, baz_dfg.node()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::MultipleParents(vec![cond.node(), case2]))
ReplaceError::MultipleParents(vec![cond.node(), case2])
);
// Adoptions...
assert_eq!(
verify_apply(Replacement {
adoptions: HashMap::from([(r1, case1), (r.replacement.root(), case2)]),
..r.clone()
check_same_errors(Replacement {
adoptions: HashMap::from([(r1, case1), (rep.replacement.root(), case2)]),
..rep.clone()
}),
Err(ReplaceError::InvalidAdoptingParent(r.replacement.root()))
ReplaceError::InvalidAdoptingParent(rep.replacement.root())
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
..r.clone()
..rep.clone()
}),
Err(ReplaceError::AdopteesNotSeparateDescendants(vec![case1]))
ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
..r.clone()
..rep.clone()
}),
Err(ReplaceError::AdopteesNotSeparateDescendants(vec![
baz_dfg.node()
]))
ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
);
// Edges....
let edge_from_removed = NewEdgeSpec {
Expand All @@ -769,43 +774,35 @@ mod test {
kind: NewEdgeKind::Order,
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_inp: vec![edge_from_removed.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeSpec(
Direction::Outgoing,
WhichHugr::Retained,
edge_from_removed
))
ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed)
);
let bad_out_edge = NewEdgeSpec {
src: h.nodes().max().unwrap(), // not valid in replacement
tgt: cond.node(),
kind: NewEdgeKind::Order,
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_out: vec![bad_out_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeSpec(
Direction::Outgoing,
WhichHugr::Replacement,
bad_out_edge
))
ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge)
);
let bad_order_edge = NewEdgeSpec {
src: cond.node(),
tgt: h.get_io(h.root()).unwrap()[1],
kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
};
assert_matches!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_new: vec![bad_order_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeKind(_, e)) => e == bad_order_edge
ReplaceError::BadEdgeKind(_, e) => e == bad_order_edge
);
let op = OutgoingPort::from(0);
let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
Expand All @@ -818,11 +815,11 @@ mod test {
},
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_out: vec![new_out_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge))
ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,6 @@ pub(in crate::hugr::rewrite) mod test {
}

fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
h.apply_rewrite(to_replace(h, rw)).unwrap()
h.apply_rewrite(to_replace(h, rw)).unwrap();
}
}

0 comments on commit 99ed2ba

Please sign in to comment.