Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 'Replace' rewrite returns node map #929

Merged
merged 4 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 63 additions & 66 deletions hugr/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ impl Replacement {
impl Rewrite for Replacement {
type Error = ReplaceError;

type ApplyResult = ();
/// Map from Node in replacement to corresponding Node in the result Hugr
type ApplyResult = HashMap<Node, Node>;

const UNCHANGED_ON_FAILURE: bool = false;

Expand Down Expand Up @@ -270,7 +271,7 @@ impl Rewrite for Replacement {
Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.check_parent(h)?;
// Calculate removed nodes here. (Does not include transfers, so enumerates only
// nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty)
Expand Down Expand Up @@ -324,7 +325,7 @@ impl Rewrite for Replacement {

// 7. Remove remaining nodes
to_remove.into_iter().for_each(|n| h.remove_node(n));
Ok(())
Ok(node_map)
}

fn invalidation_set(&self) -> impl Iterator<Item = Node> {
Expand Down Expand Up @@ -680,87 +681,91 @@ mod test {
let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node();
let cond = cond.finish_sub_container()?;
let h = h.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)?;
let verify_apply = |r: Replacement| {
let verify_res = r.verify(&h);
let apply_res = r.apply(&mut h.clone());
assert_eq!(verify_res, apply_res);
verify_res
};

// Note wrong root type here - we'll replace children of the *Conditional*
let mut rep1 = Hugr::new(h.root_type().clone());
let r1 = rep1.add_node_with_parent(
rep1.root(),
let mut r_hugr = Hugr::new(NodeType::new_open(h.get_optype(cond.node()).clone()));
let r1 = r_hugr.add_node_with_parent(
r_hugr.root(),
Case {
signature: utou.clone(),
},
);
let r2 = rep1.add_node_with_parent(
rep1.root(),
let r2 = r_hugr.add_node_with_parent(
r_hugr.root(),
Case {
signature: utou.clone(),
},
);
let mut r = Replacement {
let rep: Replacement = Replacement {
removal: vec![case1, case2],
replacement: rep1,
replacement: r_hugr,
adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
mu_inp: vec![],
mu_out: vec![],
mu_new: vec![],
};
assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
rep.verify(&h).unwrap();
{
let mut target = h.clone();
let node_map = rep.clone().apply(&mut target)?;
let new_case2 = *node_map.get(&r2).unwrap();
assert_eq!(target.get_parent(baz.node()), Some(new_case2));
}

// Test some bad Replacements (using variations of the `replacement` Hugr).
let check_same_errors = |r: Replacement| {
let verify_res = r.verify(&h).unwrap_err();
let apply_res = r.apply(&mut h.clone()).unwrap_err();
assert_eq!(verify_res, apply_res);
apply_res
};
// Root node type needs to be that of common parent of the removed nodes:
let mut rep2 = rep.clone();
rep2.replacement
.replace_op(rep2.replacement.root(), h.root_type().clone())?;
assert_eq!(
verify_apply(r.clone()),
Err(ReplaceError::WrongRootNodeTag {
check_same_errors(rep2),
ReplaceError::WrongRootNodeTag {
removed: OpTag::Conditional,
replacement: OpTag::Dfg
})
}
);
r.replacement.replace_op(
r.replacement.root(),
NodeType::new_open(h.get_optype(cond.node()).clone()),
)?;
assert_eq!(verify_apply(r.clone()), Ok(()));

// And test some bad Replacements (using the same `replacement` Hugr).
// First, removed nodes...
// Removed nodes...
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
removal: vec![h.root()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::CantReplaceRoot)
ReplaceError::CantReplaceRoot
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
removal: vec![case1, baz_dfg.node()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::MultipleParents(vec![cond.node(), case2]))
ReplaceError::MultipleParents(vec![cond.node(), case2])
);
// Adoptions...
assert_eq!(
verify_apply(Replacement {
adoptions: HashMap::from([(r1, case1), (r.replacement.root(), case2)]),
..r.clone()
check_same_errors(Replacement {
adoptions: HashMap::from([(r1, case1), (rep.replacement.root(), case2)]),
..rep.clone()
}),
Err(ReplaceError::InvalidAdoptingParent(r.replacement.root()))
ReplaceError::InvalidAdoptingParent(rep.replacement.root())
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
..r.clone()
..rep.clone()
}),
Err(ReplaceError::AdopteesNotSeparateDescendants(vec![case1]))
ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
);
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
..r.clone()
..rep.clone()
}),
Err(ReplaceError::AdopteesNotSeparateDescendants(vec![
baz_dfg.node()
]))
ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
);
// Edges....
let edge_from_removed = NewEdgeSpec {
Expand All @@ -769,43 +774,35 @@ mod test {
kind: NewEdgeKind::Order,
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_inp: vec![edge_from_removed.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeSpec(
Direction::Outgoing,
WhichHugr::Retained,
edge_from_removed
))
ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed)
);
let bad_out_edge = NewEdgeSpec {
src: h.nodes().max().unwrap(), // not valid in replacement
tgt: cond.node(),
kind: NewEdgeKind::Order,
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_out: vec![bad_out_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeSpec(
Direction::Outgoing,
WhichHugr::Replacement,
bad_out_edge
))
ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge)
);
let bad_order_edge = NewEdgeSpec {
src: cond.node(),
tgt: h.get_io(h.root()).unwrap()[1],
kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
};
assert_matches!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_new: vec![bad_order_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeKind(_, e)) => e == bad_order_edge
ReplaceError::BadEdgeKind(_, e) => e == bad_order_edge
);
let op = OutgoingPort::from(0);
let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
Expand All @@ -818,11 +815,11 @@ mod test {
},
};
assert_eq!(
verify_apply(Replacement {
check_same_errors(Replacement {
mu_out: vec![new_out_edge.clone()],
..r.clone()
..rep.clone()
}),
Err(ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge))
ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
);
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,6 @@ pub(in crate::hugr::rewrite) mod test {
}

fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
h.apply_rewrite(to_replace(h, rw)).unwrap()
h.apply_rewrite(to_replace(h, rw)).unwrap();
}
}
Loading