Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: FuncDefns don't require that their extensions match their children #688

Merged
merged 12 commits into from
Nov 16, 2023
14 changes: 13 additions & 1 deletion src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,19 @@ impl UnificationContext {
let m_input_node = self.make_or_get_meta(input, dir);
self.add_constraint(m_input_node, Constraint::Equal(m_input));
let m_output_node = self.make_or_get_meta(output, dir);
self.add_constraint(m_output_node, Constraint::Equal(m_output));
// If the parent node is a FuncDefn, it will have no
// op_signature, so the Incoming and Outgoing ports will
// have equal extension requirements.
// The function that it contains, however, may have an
// extension delta, so its output shouldn't be equal to the
// FuncDefn's output.
//
// TODO: Add a constraint that the extensions of the output
// node of a FuncDefn should be those of the input node plus
// the extension delta specified in the function signature.
if node_type.tag() != OpTag::FuncDefn {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, that is neater than I realized, nice

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hang on. (Hang on!) Don't we need an additional constraint, then, that the (input-extensions to the) Output node are Plus of (the delta of the FuncDefn's declared FunctionType and the Input node)?

Could add a test of a FuncDefn that declares delta is {A} and then has a body that Lift's only by {B}, say - that should be rejected, but is it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test for this "funcdefn_signature_mismatch"

self.add_constraint(m_output_node, Constraint::Equal(m_output));
}
}
}

Expand Down
84 changes: 83 additions & 1 deletion src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::error::Error;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
Expand Down Expand Up @@ -962,3 +964,83 @@ fn sccs() {
Some(&ExtensionSet::from_iter([A, B, C, UNKNOWN_EXTENSION]))
);
}

#[test]
/// Note: This test is relying on the builder's `define_function` doing the
/// right thing: it takes input resources via a [`Signature`], which it passes
/// to `create_with_io`, creating concrete resource sets.
/// Inference can still fail for a valid FuncDefn hugr created without using
/// the builder API.
fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
let lift = func_builder.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: A,
},
[w],
)?;
let [w] = lift.outputs_arr();
func_builder.finish_with_outputs([w])?;
builder.finish_prelude_hugr()?;
Ok(())
}

#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
let lift = func_builder.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: B,
},
[w],
)?;
let [w] = lift.outputs_arr();
func_builder.finish_with_outputs([w])?;
let result = builder.finish_prelude_hugr();
assert_matches!(
result,
Err(ValidationError::CantInfer(
InferExtensionError::MismatchedConcreteWithLocations { .. }
))
);
Ok(())
}

#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
// The FuncDefn here is declared to add resource "A", but its body just wires
// the input to the output.
fn funcdefn_signature_mismatch2() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
func_builder.finish_with_outputs([w])?;
let result = builder.finish_prelude_hugr();
assert_matches!(result, Err(ValidationError::CantInfer(..)));
Ok(())
}
14 changes: 9 additions & 5 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Secondly that the node has correct children
self.validate_children(node, node_type)?;

// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
if node_type.tag() != OpTag::FuncDefn {
croyzor marked this conversation as resolved.
Show resolved Hide resolved
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
}
}

Ok(())
Expand Down
52 changes: 51 additions & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ fn extensions_mismatch() -> Result<(), BuildError> {
assert_matches!(
handle,
Err(ValidationError::ExtensionError(
ExtensionError::ParentIOExtensionMismatch { .. }
ExtensionError::TgtExceedsSrcExtensionsAtPort { .. }
))
);
Ok(())
Expand Down Expand Up @@ -752,3 +752,53 @@ fn invalid_types() {
SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1))
);
}

#[test]
fn parent_io_mismatch() {
// The DFG node declares that it has an empty extension delta,
// but it's child graph adds extension "XB", causing a mismatch.
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is that the FunctionType has zero delta? Whereas if you did .with_extension_delta(...XB_T...) it'd pass? (Might be worth doing both ways via a loop / helper fn just to show the difference between working and not working)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment explaining this

}));

let input = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::Input {
types: type_row![USIZE_T],
}),
)
.unwrap();
let output = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new(
ops::Output {
types: type_row![USIZE_T],
},
ExtensionSet::singleton(&XB),
),
)
.unwrap();

let lift = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::LeafOp::Lift {
type_row: type_row![USIZE_T],
new_extension: XB,
}),
)
.unwrap();

hugr.connect(input, 0, lift, 0).unwrap();
hugr.connect(lift, 0, output, 0).unwrap();

let result = hugr.validate(&PRELUDE_REGISTRY);
assert_matches!(
result,
Err(ValidationError::ExtensionError(
ExtensionError::ParentIOExtensionMismatch { .. }
))
);
}