diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 02b5a25b5..770d56a5f 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -8,6 +8,8 @@ pub mod serialize; pub mod validate; pub mod views; +#[cfg(feature = "extension_inference")] +use std::collections::HashMap; use std::collections::VecDeque; use std::iter; @@ -196,8 +198,12 @@ impl Hugr { extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { resolve_extension_ops(self, extension_registry)?; - self.infer_extensions()?; - self.validate(extension_registry)?; + self.validate_no_extensions(extension_registry)?; + #[cfg(feature = "extension_inference")] + { + self.infer_extensions()?; + self.validate_extensions(HashMap::new())?; + } Ok(()) } diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 26857495a..40de5d805 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -35,9 +35,6 @@ struct ValidationContext<'a, 'b> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. dominators: HashMap>, - /// Context for the extension validation. - #[allow(dead_code)] - extension_validator: ExtensionValidator, /// Registry of available Extensions extension_registry: &'b ExtensionRegistry, } @@ -48,7 +45,51 @@ impl Hugr { /// TODO: Add a version of validation which allows for open extension /// variables (see github issue #457) pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> { - self.validate_with_extension_closure(HashMap::new(), extension_registry) + #[cfg(feature = "extension_inference")] + self.validate_with_extension_closure(HashMap::new(), extension_registry)?; + #[cfg(not(feature = "extension_inference"))] + self.validate_no_extensions(extension_registry)?; + Ok(()) + } + + /// Check the validity of the HUGR, but don't check consistency of extension + /// requirements between connected nodes or between parents and children. + pub fn validate_no_extensions( + &self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), ValidationError> { + let mut validator = ValidationContext::new(self, extension_registry); + validator.validate() + } + + /// Validate extensions on the input and output edges of nodes. Check that + /// the target ends of edges require the extensions from the sources, and + /// check extension deltas from parent nodes are reflected in their children + pub fn validate_extensions(&self, closure: ExtensionSolution) -> Result<(), ValidationError> { + let validator = ExtensionValidator::new(self, closure); + for src_node in self.nodes() { + let node_type = self.get_nodetype(src_node); + + // FuncDefns have no resources since they're static nodes, but the + // functions they define can have any extension delta. + if node_type.tag() != OpTag::FuncDefn { + // If this is a container with I/O nodes, check that the extension they + // define match the extensions of the container. + if let Some([input, output]) = self.get_io(src_node) { + validator.validate_io_extensions(src_node, input, output)?; + } + } + + for src_port in self.node_outputs(src_node) { + for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) { + validator.check_extensions_compatible( + &(src_node, src_port.into()), + &(tgt_node, tgt_port.into()), + )?; + } + } + } + Ok(()) } /// Check the validity of a hugr, taking an argument of a closure for the @@ -58,8 +99,10 @@ impl Hugr { closure: ExtensionSolution, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self, closure, extension_registry); - validator.validate() + let mut validator = ValidationContext::new(self, extension_registry); + validator.validate()?; + self.validate_extensions(closure)?; + Ok(()) } } @@ -68,15 +111,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Allow unused "extension_closure" variable for when // the "extension_inference" feature is disabled. #[allow(unused_variables)] - pub fn new( - hugr: &'a Hugr, - extension_closure: ExtensionSolution, - extension_registry: &'b ExtensionRegistry, - ) -> Self { + pub fn new(hugr: &'a Hugr, extension_registry: &'b ExtensionRegistry) -> Self { Self { hugr, dominators: HashMap::new(), - extension_validator: ExtensionValidator::new(hugr, extension_closure), extension_registry, } } @@ -176,18 +214,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Secondly that the node has correct children self.validate_children(node, node_type)?; - // FuncDefns have no resources since they're static nodes, but the - // functions they define can have any extension delta. - #[cfg(feature = "extension_inference")] - if node_type.tag() != OpTag::FuncDefn { - // If this is a container with I/O nodes, check that the extension they - // define match the extensions of the container. - if let Some([input, output]) = self.hugr.get_io(node) { - self.extension_validator - .validate_io_extensions(node, input, output)?; - } - } - Ok(()) } @@ -247,10 +273,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let other_node: Node = self.hugr.graph.port_node(link).unwrap().into(); let other_offset = self.hugr.graph.port_offset(link).unwrap().into(); - #[cfg(feature = "extension_inference")] - self.extension_validator - .check_extensions_compatible(&(node, port), &(other_node, other_offset))?; - let other_op = self.hugr.get_optype(other_node); let Some(other_kind) = other_op.port_kind(other_offset) else { panic!("The number of ports in {other_node} does not match the operation definition. This should have been caught by `validate_node`.");