diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 415c39a18..e6b7a2031 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod insert_identity; pub mod outline_cfg; pub mod simple_replace; use std::mem; diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs new file mode 100644 index 000000000..c2d177b29 --- /dev/null +++ b/src/hugr/rewrite/insert_identity.rs @@ -0,0 +1,148 @@ +//! Implementation of the `InsertIdentity` operation. + +use crate::hugr::{HugrMut, Node}; +use crate::ops::LeafOp; +use crate::types::EdgeKind; +use crate::{Direction, Hugr, HugrView, Port}; + +use super::Rewrite; + +use itertools::Itertools; +use thiserror::Error; + +/// Specification of a identity-insertion operation. +#[derive(Debug, Clone)] +pub struct IdentityInsertion { + /// The node following the identity to be inserted. + pub post_node: Node, + /// The port following the identity to be inserted. + pub post_port: Port, +} + +impl IdentityInsertion { + /// Create a new [`IdentityInsertion`] specification. + pub fn new(post_node: Node, post_port: Port) -> Self { + Self { + post_node, + post_port, + } + } +} + +/// Error from an [`IdentityInsertion`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum IdentityInsertionError { + /// Invalid node. + #[error("Node is invalid.")] + InvalidNode(), + /// Invalid port kind. + #[error("post_port has invalid kind {0:?}. Must be Value.")] + InvalidPortKind(Option), + + /// Must be input port. + #[error("post_port is an output port, must be input.")] + PortIsOutput, +} + +impl Rewrite for IdentityInsertion { + type Error = IdentityInsertionError; + /// The inserted node. + type ApplyResult = Node; + const UNCHANGED_ON_FAILURE: bool = true; + fn verify(&self, _h: &Hugr) -> Result<(), IdentityInsertionError> { + /* + Assumptions: + 1. Value kind inputs can only have one connection. + 2. Node exists. + Conditions: + 1. post_port is Value kind. + 2. post_port is connected to a sibling of post_node. + 3. post_port is input. + */ + + unimplemented!() + } + fn apply(self, h: &mut Hugr) -> Result { + if self.post_port.direction() != Direction::Incoming { + return Err(IdentityInsertionError::PortIsOutput); + } + let (pre_node, pre_port) = h + .linked_ports(self.post_node, self.post_port) + .exactly_one() + .expect("Value kind input can only have one connection."); + + let kind = h.get_optype(self.post_node).port_kind(self.post_port); + let Some(EdgeKind::Value(ty)) = kind else { + return Err(IdentityInsertionError::InvalidPortKind(kind)); + }; + + h.disconnect(self.post_node, self.post_port).unwrap(); + let new_node = h.add_op(LeafOp::Noop { ty }); + h.connect(pre_node, pre_port.index(), new_node, 0) + .expect("Should only fail if ports don't exist."); + + h.connect(new_node, 0, self.post_node, self.post_port.index()) + .expect("Should only fail if ports don't exist."); + Ok(new_node) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::super::simple_replace::test::dfg_hugr; + use super::*; + use crate::{ + algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, extension::prelude::QB_T, + ops::handle::NodeHandle, Hugr, + }; + + #[rstest] + fn correct_insertion(dfg_hugr: Hugr) { + let mut h = dfg_hugr; + + assert_eq!(h.node_count(), 6); + + let final_node = h + .input_neighbours(h.get_io(h.root()).unwrap()[1]) + .next() + .unwrap(); + + let final_node_port = h.node_inputs(final_node).next().unwrap(); + + let rw = IdentityInsertion::new(final_node, final_node_port); + + let noop_node = h.apply_rewrite(rw).unwrap(); + + assert_eq!(h.node_count(), 7); + + let noop: LeafOp = h.get_optype(noop_node).clone().try_into().unwrap(); + + assert_eq!(noop, LeafOp::Noop { ty: QB_T }); + } + + #[test] + fn incorrect_insertion() { + let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap(); + + let final_node = tail.node(); + + let final_node_output = h.node_outputs(final_node).next().unwrap(); + let rw = IdentityInsertion::new(final_node, final_node_output); + let apply_result = h.apply_rewrite(rw); + assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput)); + + let final_node_input = h.node_inputs(final_node).next().unwrap(); + + let rw = IdentityInsertion::new(final_node, final_node_input); + + let apply_result = h.apply_rewrite(rw); + assert_eq!( + apply_result, + Err(IdentityInsertionError::InvalidPortKind(Some( + EdgeKind::ControlFlow + ))) + ); + } +} diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 966aa765a..6fe75c291 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -190,11 +190,11 @@ pub enum SimpleReplacementError { } #[cfg(test)] -mod test { - use std::collections::{HashMap, HashSet}; - +pub(in crate::hugr::rewrite) mod test { use itertools::Itertools; use portgraph::Direction; + use rstest::{fixture, rstest}; + use std::collections::{HashMap, HashSet}; use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, @@ -257,6 +257,10 @@ mod test { Ok(module_builder.finish_hugr()?) } + #[fixture] + pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr { + make_hugr().unwrap() + } /// Creates a hugr with a DFG root like the following: /// ┌───┐ /// ┤ H ├──■── @@ -274,6 +278,11 @@ mod test { dfg_builder.finish_hugr_with_outputs(wire45.outputs()) } + #[fixture] + pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr { + make_dfg_hugr().unwrap() + } + /// Creates a hugr with a DFG root like the following: /// ───── /// ┌───┐ @@ -289,7 +298,12 @@ mod test { dfg_builder.finish_hugr_with_outputs(wireoutvec) } - #[test] + #[fixture] + pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr { + make_dfg_hugr2().unwrap() + } + + #[rstest] /// Replace the /// ┌───┐ /// ──■──┤ H ├ @@ -308,8 +322,8 @@ mod test { /// ├───┤┌─┴─┐ /// ┤ H ├┤ X ├ /// └───┘└───┘ - fn test_simple_replacement() { - let mut h: Hugr = make_hugr().unwrap(); + fn test_simple_replacement(simple_hugr: Hugr, dfg_hugr: Hugr) { + let mut h: Hugr = simple_hugr; // 1. Find the DFG node for the inner circuit let p: Node = h .nodes() @@ -323,7 +337,7 @@ mod test { let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap(); let s: HashSet = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect(); // 3. Construct a new DFG-rooted hugr for the replacement - let n: Hugr = make_dfg_hugr().unwrap(); + let n: Hugr = dfg_hugr; // 4. Construct the input and output matchings // 4.1. Locate the CX and its predecessor H's in n let n_node_cx = n @@ -376,7 +390,7 @@ mod test { assert_eq!(h.validate(), Ok(())); } - #[test] + #[rstest] /// Replace the /// /// ──■── @@ -394,8 +408,9 @@ mod test { /// ┌───┐ /// ┤ H ├ /// └───┘ - fn test_simple_replacement_with_empty_wires() { - let mut h: Hugr = make_hugr().unwrap(); + fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) { + let mut h: Hugr = simple_hugr; + // 1. Find the DFG node for the inner circuit let p: Node = h .nodes() @@ -408,7 +423,7 @@ mod test { .unwrap(); let s: HashSet = vec![h_node_cx].into_iter().collect(); // 3. Construct a new DFG-rooted hugr for the replacement - let n: Hugr = make_dfg_hugr2().unwrap(); + let n: Hugr = dfg_hugr2; // 4. Construct the input and output matchings // 4.1. Locate the Output and its predecessor H in n let n_node_output = n