diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 0da54c5fe..ec5790a75 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,3 +1,6 @@ +use std::fs::File; +use std::io::BufReader; + use cool_asserts::assert_matches; use rstest::rstest; @@ -19,7 +22,7 @@ use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::{self, NotOp}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; -use crate::{const_extension_ids, type_row, Direction, IncomingPort, Node}; +use crate::{const_extension_ids, test_file, type_row, Direction, IncomingPort, Node}; const NAT: Type = crate::extension::prelude::USIZE_T; @@ -926,6 +929,13 @@ fn cfg_children_restrictions() { b.remove_node(exit2); // Change the types in the BasicBlock node to work on qubits instead of bits + b.replace_op( + cfg, + ops::CFG { + signature: FunctionType::new(type_row![QB_T], type_row![BOOL_T]), + }, + ) + .unwrap(); b.replace_op( block, ops::DataflowBlock { @@ -989,6 +999,23 @@ fn cfg_connections() -> Result<(), Box> { Ok(()) } +#[test] +fn cfg_entry_io_bug() -> Result<(), Box> { + // load test file where input node of entry block has types in reversed + // order compared to parent CFG node. + let mut hugr: Hugr = serde_json::from_reader(BufReader::new( + File::open(test_file!("issue-1189.json")).unwrap(), + )) + .unwrap(); + assert_matches!( + hugr.update_validate(&PRELUDE_REGISTRY), + Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch{..}, .. }) + => assert_eq!(parent, hugr.root()) + ); + + Ok(()) +} + #[cfg(feature = "extension_inference")] mod extension_tests { use self::ops::handle::{BasicBlockID, TailLoopID}; diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index aaee46049..8f2e077fb 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -12,7 +12,7 @@ use thiserror::Error; use crate::types::TypeRow; -use super::dataflow::DataflowParent; +use super::dataflow::{DataflowOpTrait, DataflowParent}; use super::{impl_validate_op, BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp}; /// A set of property flags required for an operation. @@ -121,9 +121,37 @@ impl ValidateOp for super::CFG { fn validate_op_children<'a>( &self, - children: impl Iterator, + mut children: impl Iterator, ) -> Result<(), ChildrenValidationError> { - for (child, optype) in children.dropping(2) { + let (entry, entry_op) = children.next().unwrap(); + let (exit, exit_op) = children.next().unwrap(); + let entry_op = entry_op + .as_dataflow_block() + .expect("Child check should have already checked valid ops."); + let exit_op = exit_op + .as_exit_block() + .expect("Child check should have already checked valid ops."); + + let sig = self.signature(); + if entry_op.inner_signature().input() != sig.input() { + return Err(ChildrenValidationError::IOSignatureMismatch { + child: entry, + actual: entry_op.inner_signature().input().clone(), + expected: sig.input().clone(), + node_desc: "BasicBlock Input", + container_desc: "CFG", + }); + } + if &exit_op.cfg_outputs != sig.output() { + return Err(ChildrenValidationError::IOSignatureMismatch { + child: exit, + actual: exit_op.cfg_outputs.clone(), + expected: sig.output().clone(), + node_desc: "BasicBlockExit Output", + container_desc: "CFG", + }); + } + for (child, optype) in children { if optype.tag() == OpTag::BasicBlockExit { return Err(ChildrenValidationError::InternalExitChildren { child }); } diff --git a/resources/test/issue-1189.json b/resources/test/issue-1189.json new file mode 100644 index 000000000..0b36c92dc --- /dev/null +++ b/resources/test/issue-1189.json @@ -0,0 +1,136 @@ +{ + "version": "v1", + "nodes": [ + { + "parent": 0, + "op": "CFG", + "signature": { + "t": "G", + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 1 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "extension_reqs": [] + } + }, + { + "parent": 0, + "op": "DataflowBlock", + "inputs": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 1 + } + ], + "other_outputs": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "sum_rows": [ + [] + ], + "extension_delta": [] + }, + { + "parent": 1, + "op": "Input", + "types": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 1 + } + ] + }, + { + "parent": 1, + "op": "Output", + "types": [ + { + "t": "Sum", + "s": "Unit", + "size": 1 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + }, + { + "parent": 0, + "op": "ExitBlock", + "cfg_outputs": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ] + } + ], + "edges": [ + [ + [ + 2, + 1 + ], + [ + 3, + 0 + ] + ], + [ + [ + 2, + 0 + ], + [ + 3, + 1 + ] + ], + [ + [ + 1, + 0 + ], + [ + 4, + 0 + ] + ] + ], + "metadata": null, + "encoder": "hugr-py v0.2.1" +}