Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Replace NodeType::signature() with io_extensions() #700

Merged
merged 2 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl UnificationContext {
where
T: HugrView,
{
if hugr.root_type().signature().is_none() {
if hugr.root_type().input_extensions().is_none() {
let m_input = self.make_or_get_meta(hugr.root(), Direction::Incoming);
self.variables.insert(m_input);
}
Expand Down Expand Up @@ -302,7 +302,7 @@ impl UnificationContext {
self.add_constraint(m_output, Constraint::Equal(m_exit));
}

match node_type.signature() {
match node_type.io_extensions() {
// Input extensions are open
None => {
let c = if let Some(sig) = node_type.op_signature() {
Expand All @@ -318,9 +318,9 @@ impl UnificationContext {
self.add_constraint(m_output, c);
}
// We have a solution for everything!
Some(sig) => {
self.add_solution(m_output, sig.output_extensions());
self.add_solution(m_input, sig.input_extensions);
Some((input_exts, output_exts)) => {
self.add_solution(m_input, input_exts.clone());
self.add_solution(m_output, output_exts);
}
}
}
Expand Down
36 changes: 7 additions & 29 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,26 +253,10 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {

let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs);
assert_eq!(
hugr.get_nodetype(src.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
assert_eq!(
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.input_extensions,
rs
);
assert_eq!(
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.output_extensions(),
rs
hugr.get_nodetype(mult.node()).io_extensions().unwrap(),
(&rs.clone(), rs)
);
Ok(())
}
Expand Down Expand Up @@ -385,18 +369,12 @@ fn test_conditional_inference() -> Result<(), Box<dyn Error>> {

for node in [case0_node, case1_node, conditional_node] {
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
hugr.get_nodetype(node).io_extensions().unwrap().0,
&ExtensionSet::new()
);
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
hugr.get_nodetype(node).io_extensions().unwrap().0,
&ExtensionSet::new()
);
}
Ok(())
Expand Down
16 changes: 9 additions & 7 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ impl ExtensionValidator {
/// extension requirements for all of its input and output edges, then put
/// those requirements in the extension validation context.
fn gather_extensions(&mut self, node: &Node, node_type: &NodeType) {
if let Some(sig) = node_type.signature() {
for dir in Direction::BOTH {
assert!(self
.extensions
.insert((*node, dir), sig.get_extension(&dir))
.is_none());
}
if let Some((input_exts, output_exts)) = node_type.io_extensions() {
let prev_i = self
.extensions
.insert((*node, Direction::Incoming), input_exts.clone());
assert!(prev_i.is_none());
let prev_o = self
.extensions
.insert((*node, Direction::Outgoing), output_exts);
assert!(prev_o.is_none());
}
}

Expand Down
44 changes: 21 additions & 23 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::extension::{
};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::{FunctionType, Signature};
use crate::types::FunctionType;
use crate::{Direction, Node};

use delegate::delegate;
Expand Down Expand Up @@ -109,16 +109,6 @@ impl NodeType {
}
}

/// Use the input extensions to calculate the concrete signature of the node
pub fn signature(&self) -> Option<Signature> {
self.input_extensions.as_ref().map(|rs| {
self.op
.dataflow_signature()
.unwrap_or_default()
.with_input_extensions(rs.clone())
})
}

/// Get the function type from the embedded op
pub fn op_signature(&self) -> Option<FunctionType> {
self.op.dataflow_signature()
Expand All @@ -134,6 +124,23 @@ impl NodeType {
self.input_extensions.as_ref()
}

/// The input and output extensions for this node, if set.
///
/// `None`` if the [Self::input_extensions] is `None`.
/// Otherwise, will return Some, with the output extensions computed from the node's delta
pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> {
self.input_extensions.as_ref().map(|e| {
(
e,
self.op
.dataflow_signature()
.map(|ft| ft.extension_reqs)
.unwrap_or_default()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a change to the semantics introduced in #680 but I believe closer to what we had before - where a non-dataflow node would return FunctionType::default() (rather than None in #680) which meant the empty set of extension_reqs

.union(e),
)
})
}

/// Gets the underlying [OpType] i.e. without any [input_extensions]
///
/// [input_extensions]: NodeType::input_extensions
Expand Down Expand Up @@ -411,19 +418,10 @@ mod test {
hugr.infer_extensions()?;

assert_eq!(
hugr.get_nodetype(lift)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.get_nodetype(output)
.signature()
.unwrap()
.input_extensions,
r
hugr.get_nodetype(lift).input_extensions().unwrap(),
&ExtensionSet::new()
);
assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r);
Ok(())
}
}