Skip to content

Commit

Permalink
feat: Implement rewrite IdentityInsertion (#474)
Browse files Browse the repository at this point in the history
Returns inserted identity node.


Closes #470

---------

Co-authored-by: Alec Edgington <alec.edgington@quantinuum.com>
  • Loading branch information
ss2165 and cqc-alec committed Aug 31, 2023
1 parent 1aaecb0 commit 9d02f94
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod insert_identity;
pub mod outline_cfg;
pub mod simple_replace;
use std::mem;
Expand Down
148 changes: 148 additions & 0 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//! Implementation of the `InsertIdentity` operation.

use crate::hugr::{HugrMut, Node};
use crate::ops::LeafOp;
use crate::types::EdgeKind;
use crate::{Direction, Hugr, HugrView, Port};

use super::Rewrite;

use itertools::Itertools;
use thiserror::Error;

/// Specification of a identity-insertion operation.
#[derive(Debug, Clone)]
pub struct IdentityInsertion {
/// The node following the identity to be inserted.
pub post_node: Node,
/// The port following the identity to be inserted.
pub post_port: Port,
}

impl IdentityInsertion {
/// Create a new [`IdentityInsertion`] specification.
pub fn new(post_node: Node, post_port: Port) -> Self {
Self {
post_node,
post_port,
}
}
}

/// Error from an [`IdentityInsertion`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum IdentityInsertionError {
/// Invalid node.
#[error("Node is invalid.")]
InvalidNode(),
/// Invalid port kind.
#[error("post_port has invalid kind {0:?}. Must be Value.")]
InvalidPortKind(Option<EdgeKind>),

/// Must be input port.
#[error("post_port is an output port, must be input.")]
PortIsOutput,
}

impl Rewrite for IdentityInsertion {
type Error = IdentityInsertionError;
/// The inserted node.
type ApplyResult = Node;
const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, _h: &Hugr) -> Result<(), IdentityInsertionError> {
/*
Assumptions:
1. Value kind inputs can only have one connection.
2. Node exists.
Conditions:
1. post_port is Value kind.
2. post_port is connected to a sibling of post_node.
3. post_port is input.
*/

unimplemented!()
}
fn apply(self, h: &mut Hugr) -> Result<Self::ApplyResult, IdentityInsertionError> {
if self.post_port.direction() != Direction::Incoming {
return Err(IdentityInsertionError::PortIsOutput);
}
let (pre_node, pre_port) = h
.linked_ports(self.post_node, self.post_port)
.exactly_one()
.expect("Value kind input can only have one connection.");

let kind = h.get_optype(self.post_node).port_kind(self.post_port);
let Some(EdgeKind::Value(ty)) = kind else {
return Err(IdentityInsertionError::InvalidPortKind(kind));
};

h.disconnect(self.post_node, self.post_port).unwrap();
let new_node = h.add_op(LeafOp::Noop { ty });
h.connect(pre_node, pre_port.index(), new_node, 0)
.expect("Should only fail if ports don't exist.");

h.connect(new_node, 0, self.post_node, self.post_port.index())
.expect("Should only fail if ports don't exist.");
Ok(new_node)
}
}

#[cfg(test)]
mod tests {
use rstest::rstest;

use super::super::simple_replace::test::dfg_hugr;
use super::*;
use crate::{
algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, extension::prelude::QB_T,
ops::handle::NodeHandle, Hugr,
};

#[rstest]
fn correct_insertion(dfg_hugr: Hugr) {
let mut h = dfg_hugr;

assert_eq!(h.node_count(), 6);

let final_node = h
.input_neighbours(h.get_io(h.root()).unwrap()[1])
.next()
.unwrap();

let final_node_port = h.node_inputs(final_node).next().unwrap();

let rw = IdentityInsertion::new(final_node, final_node_port);

let noop_node = h.apply_rewrite(rw).unwrap();

assert_eq!(h.node_count(), 7);

let noop: LeafOp = h.get_optype(noop_node).clone().try_into().unwrap();

assert_eq!(noop, LeafOp::Noop { ty: QB_T });
}

#[test]
fn incorrect_insertion() {
let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap();

let final_node = tail.node();

let final_node_output = h.node_outputs(final_node).next().unwrap();
let rw = IdentityInsertion::new(final_node, final_node_output);
let apply_result = h.apply_rewrite(rw);
assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput));

let final_node_input = h.node_inputs(final_node).next().unwrap();

let rw = IdentityInsertion::new(final_node, final_node_input);

let apply_result = h.apply_rewrite(rw);
assert_eq!(
apply_result,
Err(IdentityInsertionError::InvalidPortKind(Some(
EdgeKind::ControlFlow
)))
);
}
}
37 changes: 26 additions & 11 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ pub enum SimpleReplacementError {
}

#[cfg(test)]
mod test {
use std::collections::{HashMap, HashSet};

pub(in crate::hugr::rewrite) mod test {
use itertools::Itertools;
use portgraph::Direction;
use rstest::{fixture, rstest};
use std::collections::{HashMap, HashSet};

use crate::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
Expand Down Expand Up @@ -257,6 +257,10 @@ mod test {
Ok(module_builder.finish_hugr()?)
}

#[fixture]
pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr {
make_hugr().unwrap()
}
/// Creates a hugr with a DFG root like the following:
/// ┌───┐
/// ┤ H ├──■──
Expand All @@ -274,6 +278,11 @@ mod test {
dfg_builder.finish_hugr_with_outputs(wire45.outputs())
}

#[fixture]
pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr {
make_dfg_hugr().unwrap()
}

/// Creates a hugr with a DFG root like the following:
/// ─────
/// ┌───┐
Expand All @@ -289,7 +298,12 @@ mod test {
dfg_builder.finish_hugr_with_outputs(wireoutvec)
}

#[test]
#[fixture]
pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr {
make_dfg_hugr2().unwrap()
}

#[rstest]
/// Replace the
/// ┌───┐
/// ──■──┤ H ├
Expand All @@ -308,8 +322,8 @@ mod test {
/// ├───┤┌─┴─┐
/// ┤ H ├┤ X ├
/// └───┘└───┘
fn test_simple_replacement() {
let mut h: Hugr = make_hugr().unwrap();
fn test_simple_replacement(simple_hugr: Hugr, dfg_hugr: Hugr) {
let mut h: Hugr = simple_hugr;
// 1. Find the DFG node for the inner circuit
let p: Node = h
.nodes()
Expand All @@ -323,7 +337,7 @@ mod test {
let (h_node_h0, h_node_h1) = h.output_neighbours(h_node_cx).collect_tuple().unwrap();
let s: HashSet<Node> = vec![h_node_cx, h_node_h0, h_node_h1].into_iter().collect();
// 3. Construct a new DFG-rooted hugr for the replacement
let n: Hugr = make_dfg_hugr().unwrap();
let n: Hugr = dfg_hugr;
// 4. Construct the input and output matchings
// 4.1. Locate the CX and its predecessor H's in n
let n_node_cx = n
Expand Down Expand Up @@ -376,7 +390,7 @@ mod test {
assert_eq!(h.validate(), Ok(()));
}

#[test]
#[rstest]
/// Replace the
///
/// ──■──
Expand All @@ -394,8 +408,9 @@ mod test {
/// ┌───┐
/// ┤ H ├
/// └───┘
fn test_simple_replacement_with_empty_wires() {
let mut h: Hugr = make_hugr().unwrap();
fn test_simple_replacement_with_empty_wires(simple_hugr: Hugr, dfg_hugr2: Hugr) {
let mut h: Hugr = simple_hugr;

// 1. Find the DFG node for the inner circuit
let p: Node = h
.nodes()
Expand All @@ -408,7 +423,7 @@ mod test {
.unwrap();
let s: HashSet<Node> = vec![h_node_cx].into_iter().collect();
// 3. Construct a new DFG-rooted hugr for the replacement
let n: Hugr = make_dfg_hugr2().unwrap();
let n: Hugr = dfg_hugr2;
// 4. Construct the input and output matchings
// 4.1. Locate the Output and its predecessor H in n
let n_node_output = n
Expand Down

0 comments on commit 9d02f94

Please sign in to comment.