Skip to content

Commit

Permalink
refactor: move hugr equality check out for reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 21, 2023
1 parent ff26546 commit 64b9199
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 37 deletions.
44 changes: 42 additions & 2 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,16 @@ pub enum HugrError {
}

#[cfg(test)]
mod test {
pub(crate) mod test {
use itertools::Itertools;
use portgraph::{LinkView, PortView};

use super::{Hugr, HugrView};
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops;
use crate::ops::LeafOp;
use crate::ops::{self, OpType};
use crate::type_row;
use crate::types::{FunctionType, Type};

Expand Down Expand Up @@ -398,4 +402,40 @@ mod test {
assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r);
Ok(())
}

pub(crate) fn assert_hugr_equality(hugr: &Hugr, other: &Hugr) {
assert_eq!(other.root, hugr.root);
assert_eq!(other.hierarchy, hugr.hierarchy);
assert_eq!(other.metadata, hugr.metadata);

// Extension operations may have been downgraded to opaque operations.
for node in other.nodes() {
let new_op = other.get_optype(node);
let old_op = hugr.get_optype(node);
if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op {
if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op {
assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque());
} else {
panic!("Expected old_op to be a custom op");
}
} else {
assert_eq!(new_op, old_op);
}
}

// Check that the graphs are equivalent up to port renumbering.
let new_graph = &other.graph;
let old_graph = &hugr.graph;
assert_eq!(new_graph.node_count(), old_graph.node_count());
assert_eq!(new_graph.port_count(), old_graph.port_count());
assert_eq!(new_graph.link_count(), old_graph.link_count());
for n in old_graph.nodes_iter() {
assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n));
assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n));
assert_eq!(
new_graph.output_neighbours(n).collect_vec(),
old_graph.output_neighbours(n).collect_vec()
);
}
}
}
37 changes: 2 additions & 35 deletions src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,14 @@ pub mod test {
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::test::assert_hugr_equality;
use crate::hugr::NodeType;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG};
use crate::std_extensions::logic::NotOp;
use crate::types::{FunctionType, Type};
use crate::OutgoingPort;
use itertools::Itertools;
use portgraph::LinkView;
use portgraph::{
multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap,
};
Expand Down Expand Up @@ -298,40 +298,7 @@ pub mod test {
// The internal port indices may still be different.
let mut h_canon = hugr.clone();
h_canon.canonicalize_nodes(|_, _| {});

assert_eq!(new_hugr.root, h_canon.root);
assert_eq!(new_hugr.hierarchy, h_canon.hierarchy);
assert_eq!(new_hugr.metadata, h_canon.metadata);

// Extension operations may have been downgraded to opaque operations.
for node in new_hugr.nodes() {
let new_op = new_hugr.get_optype(node);
let old_op = h_canon.get_optype(node);
if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op {
if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op {
assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque());
} else {
panic!("Expected old_op to be a custom op");
}
} else {
assert_eq!(new_op, old_op);
}
}

// Check that the graphs are equivalent up to port renumbering.
let new_graph = &new_hugr.graph;
let old_graph = &h_canon.graph;
assert_eq!(new_graph.node_count(), old_graph.node_count());
assert_eq!(new_graph.port_count(), old_graph.port_count());
assert_eq!(new_graph.link_count(), old_graph.link_count());
for n in old_graph.nodes_iter() {
assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n));
assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n));
assert_eq!(
new_graph.output_neighbours(n).collect_vec(),
old_graph.output_neighbours(n).collect_vec()
);
}
assert_hugr_equality(&h_canon, &new_hugr);

new_hugr
}
Expand Down

0 comments on commit 64b9199

Please sign in to comment.