Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: serialization round-trips #948

Merged
merged 7 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ pub enum HUGRSerializationError {
#[error("Failed to build edge when deserializing: {0:?}.")]
LinkError(#[from] LinkError),
/// Edges without port offsets cannot be present in operations without non-dataflow ports.
#[error("Cannot connect an edge without port offset to node {node:?} with operation type {op_type:?}.")]
#[error("Cannot connect an {dir:?} edge without port offset to node {node:?} with operation type {op_type:?}.")]
MissingPortOffset {
/// The node that has the port without offset.
node: Node,
/// The direction of the port without an offset
dir: Direction,
/// The operation type of the node.
op_type: OpType,
},
Expand Down Expand Up @@ -231,6 +233,7 @@ impl TryFrom<SerHugrV1> for Hugr {
.other_port(dir)
.ok_or(HUGRSerializationError::MissingPortOffset {
node,
dir,
op_type: op_type.clone(),
})?
.index()
Expand Down Expand Up @@ -328,10 +331,20 @@ pub mod test {
}

/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
doug-q marked this conversation as resolved.
Show resolved Hide resolved
/// Checks the serialized json against the in-tree schema.
///
/// Returns the deserialized HUGR.
pub fn check_hugr_roundtrip(hugr: &Hugr) -> Hugr {
let new_hugr: Hugr = ser_roundtrip_validate(hugr, Some(&SCHEMA));
pub fn check_hugr_schema_roundtrip(hugr: &Hugr) -> Hugr {
check_hugr_roundtrip(hugr, true)
}

/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
///
/// If `check_schema` is true, checks the serialized json against the in-tree schema.
///
/// Returns the deserialized HUGR.
pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
doug-q marked this conversation as resolved.
Show resolved Hide resolved
let new_hugr: Hugr = ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA));

// Original HUGR, with canonicalized node indices
//
Expand Down Expand Up @@ -417,7 +430,7 @@ pub mod test {
metadata: Default::default(),
};

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
}

#[test]
Expand Down Expand Up @@ -451,7 +464,7 @@ pub mod test {
module_builder.finish_prelude_hugr().unwrap()
};

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
}

#[test]
Expand All @@ -467,7 +480,7 @@ pub mod test {
}
let hugr = dfg.finish_hugr_with_outputs(params, &EMPTY_REG)?;

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
Ok(())
}

Expand All @@ -490,7 +503,7 @@ pub mod test {

let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?;

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
Ok(())
}

Expand All @@ -501,7 +514,7 @@ pub mod test {
let op = bldr.add_dataflow_op(Noop { ty: fn_ty }, bldr.input_wires())?;
let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?;

check_hugr_roundtrip(&h);
check_hugr_schema_roundtrip(&h);
Ok(())
}

Expand All @@ -519,7 +532,7 @@ pub mod test {
hugr.remove_node(old_in);
hugr.update_validate(&PRELUDE_REGISTRY)?;

let new_hugr: Hugr = check_hugr_roundtrip(&hugr);
let new_hugr: Hugr = check_hugr_schema_roundtrip(&hugr);
new_hugr.validate(&EMPTY_REG).unwrap_err();
new_hugr.validate(&PRELUDE_REGISTRY)?;
Ok(())
Expand Down
9 changes: 9 additions & 0 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Hierarchy and children. No type variables declared outside the root.
self.validate_subtree(self.hugr.root(), &[])?;

// In tests we take the opportunity to verify that the hugr
// serialization round-trips.
//
// TODO: We should also verify that the serialized hugr matches the
// in-tree schema. For now, our serialized hugr does not match the
// schema. When this is fixed we should pass true below.
#[cfg(test)]
crate::hugr::serialize::test::check_hugr_roundtrip(self.hugr, false);

Ok(())
}

Expand Down
12 changes: 6 additions & 6 deletions hugr/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ impl OpType {
///
/// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports.
pub fn other_port(&self, dir: Direction) -> Option<Port> {
let df_count = self.value_port_count(dir);
let non_df_count = self.non_df_port_count(dir);
if self.other_port_kind(dir).is_some() && non_df_count == 1 {
// if there is a static input it comes before the non_df_ports
let static_input =
(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize;

Some(Port::new(dir, self.value_port_count(dir) + static_input))
// if there is a static input it comes before the non_df_ports
let static_input =
(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize;
if self.other_port_kind(dir).is_some() && non_df_count >= 1 {
Some(Port::new(dir, df_count + static_input))
} else {
None
}
Expand Down
6 changes: 5 additions & 1 deletion hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ mod test {

use super::*;

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A custom constant value used in testing
pub(crate) struct CustomTestValue(pub CustomType);

Expand All @@ -320,6 +320,10 @@ mod test {
fn get_type(&self) -> Type {
self.0.clone().into()
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}
}

/// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing.
Expand Down