From 6dbdbb0bf4269ca96a9c1b9dd907ca275e6d522c Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 12 Jun 2024 16:57:47 +0100 Subject: [PATCH 1/4] test: Failing test for the simple_replace bug --- hugr-core/src/hugr/rewrite/simple_replace.rs | 99 ++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index a70c275aa..383951ac5 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -209,9 +209,11 @@ pub(in crate::hugr::rewrite) mod test { use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; + use crate::ops::handle::NodeHandle; use crate::ops::OpTag; use crate::ops::OpTrait; use crate::std_extensions::logic::test::and_op; + use crate::std_extensions::logic::NotOp; use crate::type_row; use crate::types::{FunctionType, Type}; use crate::utils::test_quantum_extension::{cx_gate, h_gate}; @@ -309,6 +311,43 @@ pub(in crate::hugr::rewrite) mod test { make_dfg_hugr2().unwrap() } + /// A hugr with a DFG root mapping BOOL_T to (BOOL_T, BOOL_T) + /// ┌─────────┐ + /// ┌────┤ (1) NOT ├── + /// ┌─────────┐ │ └─────────┘ + /// ─┤ (0) NOT ├───┤ + /// └─────────┘ │ ┌─────────┐ + /// └────┤ (2) NOT ├── + /// └─────────┘ + /// This can be replaced with an empty hugr coping the input to both outputs. + /// + /// Returns the hugr and the nodes of the NOT gates, in order. + #[fixture] + pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { + fn build() -> Result<(Hugr, Vec), BuildError> { + let mut dfg_builder = DFGBuilder::new(FunctionType::new( + type_row![BOOL_T], + type_row![BOOL_T, BOOL_T], + ))?; + let [b] = dfg_builder.input_wires_arr(); + + let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b])?; + let [b] = not_inp.outputs_arr(); + + let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b])?; + let [b0] = not_0.outputs_arr(); + let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b])?; + let [b1] = not_1.outputs_arr(); + + Ok(( + dfg_builder.finish_prelude_hugr_with_outputs([b0, b1])?, + vec![not_inp.node(), not_0.node(), not_1.node()], + )) + } + + build().unwrap() + } + #[rstest] /// Replace the /// ┌───┐ @@ -572,6 +611,66 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), orig.node_count()); } + #[rstest] + fn test_copy_inputs( + dfg_hugr_copy_bools: (Hugr, Vec), + ) -> Result<(), Box> { + let (mut hugr, nodes) = dfg_hugr_copy_bools; + let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap(); + + println!("{}", hugr.mermaid_string()); + + let [_input, output] = hugr.get_io(hugr.root()).unwrap(); + + let replacement = { + let b = DFGBuilder::new(FunctionType::new( + type_row![BOOL_T], + type_row![BOOL_T, BOOL_T], + ))?; + let [w] = b.input_wires_arr(); + b.finish_prelude_hugr_with_outputs([w, w])? + }; + let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap(); + + let subgraph = + SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)?; + // A map from (target ports of edges from the Input node of `replacement`) to (target ports of + // edges from nodes not in `removal` to nodes in `removal`). + let nu_inp = [ + ( + (repl_output, IncomingPort::from(0)), + (input_not, IncomingPort::from(0)), + ), + ( + (repl_output, IncomingPort::from(1)), + (input_not, IncomingPort::from(1)), + ), + ] + .into_iter() + .collect(); + // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to + // (input ports of the Output node of `replacement`). + let nu_out = [ + ((output, IncomingPort::from(0)), IncomingPort::from(0)), + ((output, IncomingPort::from(1)), IncomingPort::from(1)), + ] + .into_iter() + .collect(); + + let rewrite = SimpleReplacement { + subgraph, + replacement, + nu_inp, + nu_out, + }; + rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); + + assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(hugr.node_count(), 3); + + Ok(()) + } + use crate::hugr::rewrite::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; From dffff01d7d76cd959822bfe67ed8b81cf105bef4 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 12 Jun 2024 17:40:43 +0100 Subject: [PATCH 2/4] fix: Panic in SimpleReplace --- hugr-core/src/hugr/rewrite/simple_replace.rs | 27 ++++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 383951ac5..c758f7a71 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; -use crate::{Hugr, IncomingPort, Node}; +use crate::{Hugr, IncomingPort, Node, OutgoingPort}; use thiserror::Error; /// Specification of a simple replacement operation. @@ -147,6 +147,8 @@ impl Rewrite for SimpleReplacement { } // 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 // to p1. + let mut disconnects: Vec<(Node, IncomingPort)> = Vec::new(); + let mut connect: Vec<(Node, OutgoingPort, Node, IncomingPort)> = Vec::new(); for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out { let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { @@ -154,16 +156,27 @@ impl Rewrite for SimpleReplacement { let (rem_inp_pred_node, rem_inp_pred_port) = h .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); - h.disconnect(*rem_inp_node, *rem_inp_port); - h.disconnect(*rem_out_node, *rem_out_port); - h.connect( + // Delay connecting/disconnecting the nodes until after + // processing all nu_out entries. + disconnects.push((*rem_out_node, *rem_out_port)); + disconnects.push((*rem_out_node, *rem_out_port)); + connect.push(( rem_inp_pred_node, rem_inp_pred_port, *rem_out_node, *rem_out_port, - ); + )); } } + disconnects.into_iter().for_each(|(node, port)| { + h.disconnect(node, port); + }); + connect + .into_iter() + .for_each(|(src_node, src_port, tgt_node, tgt_port)| { + h.connect(src_node, src_port, tgt_node, tgt_port); + }); + // 3.5. Remove all nodes in self.removal and edges between them. for &node in self.subgraph.nodes() { h.remove_node(node); @@ -618,8 +631,6 @@ pub(in crate::hugr::rewrite) mod test { let (mut hugr, nodes) = dfg_hugr_copy_bools; let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap(); - println!("{}", hugr.mermaid_string()); - let [_input, output] = hugr.get_io(hugr.root()).unwrap(); let replacement = { @@ -643,7 +654,7 @@ pub(in crate::hugr::rewrite) mod test { ), ( (repl_output, IncomingPort::from(1)), - (input_not, IncomingPort::from(1)), + (input_not, IncomingPort::from(0)), ), ] .into_iter() From 818e067afa90514a69d30f7a587267cc7c949a36 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 13 Jun 2024 10:01:20 +0100 Subject: [PATCH 3/4] Clarify the docs --- hugr-core/src/hugr/rewrite/simple_replace.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index c758f7a71..6f29c45d6 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -1,6 +1,6 @@ //! Implementation of the `SimpleReplace` operation. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; @@ -147,8 +147,10 @@ impl Rewrite for SimpleReplacement { } // 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 // to p1. - let mut disconnects: Vec<(Node, IncomingPort)> = Vec::new(); - let mut connect: Vec<(Node, OutgoingPort, Node, IncomingPort)> = Vec::new(); + // + // i.e. the replacement graph has direct edges between the input and output nodes. + let mut disconnect: HashSet<(Node, IncomingPort)> = HashSet::new(); + let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new(); for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out { let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { @@ -158,9 +160,9 @@ impl Rewrite for SimpleReplacement { .unwrap(); // Delay connecting/disconnecting the nodes until after // processing all nu_out entries. - disconnects.push((*rem_out_node, *rem_out_port)); - disconnects.push((*rem_out_node, *rem_out_port)); - connect.push(( + disconnect.insert((*rem_out_node, *rem_out_port)); + disconnect.insert((*rem_out_node, *rem_out_port)); + connect.insert(( rem_inp_pred_node, rem_inp_pred_port, *rem_out_node, @@ -168,7 +170,7 @@ impl Rewrite for SimpleReplacement { )); } } - disconnects.into_iter().for_each(|(node, port)| { + disconnect.into_iter().for_each(|(node, port)| { h.disconnect(node, port); }); connect From 9071f9c81b01dabac05865d3208453acef721377 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 17 Jun 2024 10:19:24 +0100 Subject: [PATCH 4/4] Add comment about why the fix is needed --- hugr-core/src/hugr/rewrite/simple_replace.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 6f29c45d6..40cad9be4 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -160,7 +160,10 @@ impl Rewrite for SimpleReplacement { .unwrap(); // Delay connecting/disconnecting the nodes until after // processing all nu_out entries. - disconnect.insert((*rem_out_node, *rem_out_port)); + // + // Otherwise, we might disconnect other wires in `rem_inp_node` + // that are needed for the following iterations. + disconnect.insert((*rem_inp_node, *rem_inp_port)); disconnect.insert((*rem_out_node, *rem_out_port)); connect.insert(( rem_inp_pred_node,