diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index deecf1cbb..0dfd9a13b 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -405,6 +405,13 @@ pub enum ExtensionBuildError { #[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct ExtensionSet(BTreeSet); +/// A special ExtensionId which indicates that the delta of a non-Function +/// container node should be computed by extension inference. See [`infer_extensions`] +/// which lists the container nodes to which this can be applied. +/// +/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions +pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); + impl ExtensionSet { /// Creates a new empty extension set. pub const fn new() -> Self { diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 25aa28214..978b575c5 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -13,7 +13,6 @@ use std::collections::VecDeque; use std::iter; pub(crate) use self::hugrmut::HugrMut; -use self::validate::ExtensionError; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; @@ -25,9 +24,9 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; -use crate::extension::ExtensionRegistry; +use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; use crate::ops::custom::resolve_extension_ops; -use crate::ops::OpTag; +use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -92,15 +91,80 @@ impl Hugr { self.validate_no_extensions(extension_registry)?; #[cfg(feature = "extension_inference")] { - self.infer_extensions()?; + self.infer_extensions(false)?; self.validate_extensions()?; } Ok(()) } - /// Leaving this here as in the future we plan for it to infer deltas - /// of container nodes e.g. [OpType::DFG]. For the moment it does nothing. - pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> { + /// Infers an extension-delta for any non-function container node + /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta + /// will be the smallest delta compatible with its children and that includes any + /// other [ExtensionId]s in the current delta. + /// + /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], + /// ExtensionIds are removed from the delta if they are *not* used by any child node. + /// + /// The non-function container nodes are: + /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] + /// + /// [Case]: crate::ops::Case + /// [CFG]: crate::ops::CFG + /// [Conditional]: crate::ops::Conditional + /// [DataflowBlock]: crate::ops::DataflowBlock + /// [DFG]: crate::ops::DFG + /// [TailLoop]: crate::ops::TailLoop + /// [extension_delta]: crate::ops::OpType::extension_delta + /// [ExtensionId]: crate::extension::ExtensionId + pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { + fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { + match optype { + OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs), + OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), + OpType::TailLoop(tl) => Some(&mut tl.extension_delta), + OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs), + OpType::Conditional(c) => Some(&mut c.extension_delta), + OpType::Case(c) => Some(&mut c.signature.extension_reqs), + //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed + //OpType::FuncDefn(_) // Not at present due to the possibility of recursion + _ => None, + } + } + fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { + let mut child_sets = h + .children(node) + .collect::>() // Avoid borrowing h over recursive call + .into_iter() + .map(|ch| Ok((ch, infer(h, ch, remove)?))) + .collect::, _>>()?; + + let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { + return Ok(h.get_optype(node).extension_delta()); + }; + if es.contains(&TO_BE_INFERRED) { + // Do not remove anything from current delta - any other elements are a lower bound + child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst + } else if remove { + child_sets.iter().try_for_each(|(ch, ch_exts)| { + if !es.is_superset(ch_exts) { + return Err(ExtensionError { + parent: node, + parent_extensions: es.clone(), + child: *ch, + child_extensions: ch_exts.clone(), + }); + } + Ok(()) + })?; + } else { + return Ok(es.clone()); // Can't neither add nor remove, so nothing to do + } + let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); + *es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged); + + Ok(es.clone()) + } + infer(self, self.root(), remove)?; Ok(()) } } @@ -203,6 +267,16 @@ impl Hugr { } } +#[derive(Debug, Clone, PartialEq, Error)] +#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] +/// An error in the extension deltas. +pub struct ExtensionError { + parent: Node, + parent_extensions: ExtensionSet, + child: Node, + child_extensions: ExtensionSet, +} + /// Errors that can occur while manipulating a Hugr. /// /// TODO: Better descriptions, not just re-exporting portgraph errors. @@ -221,7 +295,14 @@ pub enum HugrError { #[cfg(test)] mod test { - use super::{Hugr, HugrView}; + use super::internal::HugrMutInternals; + #[cfg(feature = "extension_inference")] + use super::ValidationError; + use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; + use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED}; + use crate::types::{FunctionType, Type}; + use crate::{const_extension_ids, ops, type_row}; + use rstest::rstest; #[test] fn impls_send_and_sync() { @@ -240,4 +321,167 @@ mod test { let hugr = simple_dfg_hugr(); assert_matches!(hugr.get_io(hugr.root()), Some(_)); } + + const_extension_ids! { + const XA: ExtensionId = "EXT_A"; + const XB: ExtensionId = "EXT_B"; + } + + #[rstest] + #[case([], XA.into())] + #[case([XA], XA.into())] + #[case([XB], ExtensionSet::from_iter([XA, XB]))] + + fn infer_single_delta( + #[case] parent: impl IntoIterator, + #[values(true, false)] remove: bool, // makes no difference when inferring + #[case] result: ExtensionSet, + ) { + let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); + let (mut h, _) = build_ext_dfg(parent); + h.infer_extensions(remove).unwrap(); + assert_eq!(h, build_ext_dfg(result).0); + } + + #[test] + fn infer_removes_from_delta() { + let parent = ExtensionSet::from_iter([XA, XB]); + let mut h = build_ext_dfg(parent.clone()).0; + let backup = h.clone(); + h.infer_extensions(false).unwrap(); + assert_eq!(h, backup); // did nothing + h.infer_extensions(true).unwrap(); + assert_eq!(h, build_ext_dfg(XA.into()).0); + } + + #[test] + fn infer_bad_remove() { + let (mut h, mid) = build_ext_dfg(XB.into()); + let backup = h.clone(); + h.infer_extensions(false).unwrap(); + assert_eq!(h, backup); // did nothing + let val_res = h.validate(&EMPTY_REG); + let expected_err = ExtensionError { + parent: h.root(), + parent_extensions: XB.into(), + child: mid, + child_extensions: XA.into(), + }; + #[cfg(feature = "extension_inference")] + assert_eq!( + val_res, + Err(ValidationError::ExtensionError(expected_err.clone())) + ); + #[cfg(not(feature = "extension_inference"))] + assert!(val_res.is_ok()); + + let inf_res = h.infer_extensions(true); + assert_eq!(inf_res, Err(expected_err)); + } + + fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { + let ty = Type::new_function(FunctionType::new_endo(type_row![])); + let mut h = Hugr::new(ops::DFG { + signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()), + }); + let root = h.root(); + let mid = add_inliftout(&mut h, root, ty); + (h, mid) + } + + fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { + let inp = h.add_node_with_parent( + p, + ops::Input { + types: ty.clone().into(), + }, + ); + let out = h.add_node_with_parent( + p, + ops::Output { + types: ty.clone().into(), + }, + ); + let mid = h.add_node_with_parent( + p, + ops::Lift { + type_row: ty.into(), + new_extension: XA, + }, + ); + h.connect(inp, 0, mid, 0); + h.connect(mid, 0, out, 0); + mid + } + + #[rstest] + // Base case success: delta inferred for parent equals grandparent. + #[case([XA], [TO_BE_INFERRED], true, [XA])] + // Success: delta inferred for parent is subset of grandparent + #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] + // Base case failure: infers [XA] for parent but grandparent has disjoint set + #[case([XB], [TO_BE_INFERRED], false, [XA])] + // Failure: as previous, but extra "lower bound" on parent that has no effect + #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] + // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB + #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] + // Success: grandparent includes extra XB required for parent's "lower bound" + #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] + // Success: grandparent is also inferred so can include 'extra' XB from parent + #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] + // No inference: extraneous XB in parent is removed so all become [XA]. + #[case([XA], [XA, XB], true, [XA])] + fn infer_three_generations( + #[case] grandparent: impl IntoIterator, + #[case] parent: impl IntoIterator, + #[case] success: bool, + #[case] result: impl IntoIterator, + ) { + let ty = Type::new_function(FunctionType::new_endo(type_row![])); + let grandparent = ExtensionSet::from_iter(grandparent); + let result = ExtensionSet::from_iter(result); + let root_ty = ops::Conditional { + sum_rows: vec![type_row![]], + other_inputs: ty.clone().into(), + outputs: ty.clone().into(), + extension_delta: grandparent.clone(), + }; + let mut h = Hugr::new(root_ty.clone()); + let p = h.add_node_with_parent( + h.root(), + ops::Case { + signature: FunctionType::new_endo(ty.clone()) + .with_extension_delta(ExtensionSet::from_iter(parent)), + }, + ); + add_inliftout(&mut h, p, ty.clone()); + assert!(h.validate_extensions().is_err()); + let backup = h.clone(); + let inf_res = h.infer_extensions(true); + if success { + assert!(inf_res.is_ok()); + let expected_p = ops::Case { + signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()), + }; + let mut expected = backup; + expected.replace_op(p, expected_p).unwrap(); + let expected_gp = ops::Conditional { + extension_delta: result, + ..root_ty + }; + expected.replace_op(h.root(), expected_gp).unwrap(); + + assert_eq!(h, expected); + } else { + assert_eq!( + inf_res, + Err(ExtensionError { + parent: h.root(), + parent_extensions: grandparent, + child: p, + child_extensions: result + }) + ); + } + } } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index be2e06003..c36169134 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; @@ -19,6 +19,7 @@ use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Hugr, Node, Port}; use super::views::{HierarchyView, HugrView, SiblingGraph}; +use super::ExtensionError; /// Structure keeping track of pre-computed information used in the validation /// process. @@ -59,6 +60,9 @@ impl Hugr { pub fn validate_extensions(&self) -> Result<(), ValidationError> { for parent in self.nodes() { let parent_op = self.get_optype(parent); + if parent_op.extension_delta().contains(&TO_BE_INFERRED) { + return Err(ValidationError::ExtensionsNotInferred { node: parent }); + } let parent_extensions = match parent_op.inner_function_type() { Some(FunctionType { extension_reqs, .. }) => extension_reqs, None => match parent_op.tag() { @@ -743,6 +747,9 @@ pub enum ValidationError { /// There are errors in the extension deltas. #[error(transparent)] ExtensionError(#[from] ExtensionError), + /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference. + #[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")] + ExtensionsNotInferred { node: Node }, /// Error in a node signature #[error("Error in signature of node {node:?}: {cause}")] SignatureError { node: Node, cause: SignatureError }, @@ -815,15 +822,5 @@ pub enum InterGraphEdgeError { }, } -#[derive(Debug, Clone, PartialEq, Error)] -#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] -/// An error in the extension deltas. -pub struct ExtensionError { - parent: Node, - parent_extensions: ExtensionSet, - child: Node, - child_extensions: ExtensionSet, -} - #[cfg(test)] pub(crate) mod test;