diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 806d55726..bcea6175f 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -261,7 +261,7 @@ impl UnificationContext { where T: HugrView, { - if hugr.root_type().signature().is_none() { + if hugr.root_type().input_extensions().is_none() { let m_input = self.make_or_get_meta(hugr.root(), Direction::Incoming); self.variables.insert(m_input); } @@ -302,7 +302,7 @@ impl UnificationContext { self.add_constraint(m_output, Constraint::Equal(m_exit)); } - match node_type.signature() { + match node_type.io_extensions() { // Input extensions are open None => { let c = if let Some(sig) = node_type.op_signature() { @@ -318,9 +318,9 @@ impl UnificationContext { self.add_constraint(m_output, c); } // We have a solution for everything! - Some(sig) => { - self.add_solution(m_output, sig.output_extensions()); - self.add_solution(m_input, sig.input_extensions); + Some((input_exts, output_exts)) => { + self.add_solution(m_input, input_exts.clone()); + self.add_solution(m_output, output_exts); } } } diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 7a30da09d..0315ae353 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -253,26 +253,10 @@ fn dangling_src() -> Result<(), Box> { let closure = hugr.infer_extensions()?; assert!(closure.is_empty()); + assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs); assert_eq!( - hugr.get_nodetype(src.node()) - .signature() - .unwrap() - .output_extensions(), - rs - ); - assert_eq!( - hugr.get_nodetype(mult.node()) - .signature() - .unwrap() - .input_extensions, - rs - ); - assert_eq!( - hugr.get_nodetype(mult.node()) - .signature() - .unwrap() - .output_extensions(), - rs + hugr.get_nodetype(mult.node()).io_extensions().unwrap(), + (&rs.clone(), rs) ); Ok(()) } @@ -385,18 +369,12 @@ fn test_conditional_inference() -> Result<(), Box> { for node in [case0_node, case1_node, conditional_node] { assert_eq!( - hugr.get_nodetype(node) - .signature() - .unwrap() - .input_extensions, - ExtensionSet::new() + hugr.get_nodetype(node).io_extensions().unwrap().0, + &ExtensionSet::new() ); assert_eq!( - hugr.get_nodetype(node) - .signature() - .unwrap() - .input_extensions, - ExtensionSet::new() + hugr.get_nodetype(node).io_extensions().unwrap().0, + &ExtensionSet::new() ); } Ok(()) diff --git a/src/extension/validate.rs b/src/extension/validate.rs index 2f267edf9..c66f777fa 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -50,13 +50,15 @@ impl ExtensionValidator { /// extension requirements for all of its input and output edges, then put /// those requirements in the extension validation context. fn gather_extensions(&mut self, node: &Node, node_type: &NodeType) { - if let Some(sig) = node_type.signature() { - for dir in Direction::BOTH { - assert!(self - .extensions - .insert((*node, dir), sig.get_extension(&dir)) - .is_none()); - } + if let Some((input_exts, output_exts)) = node_type.io_extensions() { + let prev_i = self + .extensions + .insert((*node, Direction::Incoming), input_exts.clone()); + assert!(prev_i.is_none()); + let prev_o = self + .extensions + .insert((*node, Direction::Outgoing), output_exts); + assert!(prev_o.is_none()); } } diff --git a/src/hugr.rs b/src/hugr.rs index 33f02dc43..474927d5b 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -31,7 +31,7 @@ use crate::extension::{ }; use crate::ops::custom::resolve_extension_ops; use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE}; -use crate::types::{FunctionType, Signature}; +use crate::types::FunctionType; use crate::{Direction, Node}; use delegate::delegate; @@ -109,16 +109,6 @@ impl NodeType { } } - /// Use the input extensions to calculate the concrete signature of the node - pub fn signature(&self) -> Option { - self.input_extensions.as_ref().map(|rs| { - self.op - .dataflow_signature() - .unwrap_or_default() - .with_input_extensions(rs.clone()) - }) - } - /// Get the function type from the embedded op pub fn op_signature(&self) -> Option { self.op.dataflow_signature() @@ -134,6 +124,23 @@ impl NodeType { self.input_extensions.as_ref() } + /// The input and output extensions for this node, if set. + /// + /// `None`` if the [Self::input_extensions] is `None`. + /// Otherwise, will return Some, with the output extensions computed from the node's delta + pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> { + self.input_extensions.as_ref().map(|e| { + ( + e, + self.op + .dataflow_signature() + .map(|ft| ft.extension_reqs) + .unwrap_or_default() + .union(e), + ) + }) + } + /// Gets the underlying [OpType] i.e. without any [input_extensions] /// /// [input_extensions]: NodeType::input_extensions @@ -411,19 +418,10 @@ mod test { hugr.infer_extensions()?; assert_eq!( - hugr.get_nodetype(lift) - .signature() - .unwrap() - .input_extensions, - ExtensionSet::new() - ); - assert_eq!( - hugr.get_nodetype(output) - .signature() - .unwrap() - .input_extensions, - r + hugr.get_nodetype(lift).input_extensions().unwrap(), + &ExtensionSet::new() ); + assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r); Ok(()) } } diff --git a/src/types/signature.rs b/src/types/signature.rs index 6e4c6b60c..fe91ccfde 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -27,7 +27,7 @@ pub struct FunctionType { } #[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -/// A concrete signature, which has been instantiated with a set of input extensions +/// A combination of a FunctionType and a set of input extensions, used for declaring functions pub struct Signature { /// The underlying signature pub signature: FunctionType, @@ -243,15 +243,6 @@ impl FunctionType { } impl Signature { - /// Returns a reference to the extension set for the ports of the - /// signature in a given direction - pub fn get_extension(&self, dir: &Direction) -> ExtensionSet { - match dir { - Direction::Incoming => self.input_extensions.clone(), - Direction::Outgoing => self.output_extensions(), - } - } - delegate! { to self.signature { /// Inputs of the function type