Skip to content

Commit

Permalink
fix: Check for rewrite composition in badger (#255)
Browse files Browse the repository at this point in the history
This is a short-term quick-but-incomplete fix for #239.

We now check right after the composition whether we have invalidated the
circuit by making a loop.
We leverage that circuit hashing already does a toposort traversal,
adding an error when a loop is found and catching it.

Still, it is theoretically possible for an invalid chain of rewrites to
end up producing a valid (but not equivalent) circuit. There is a
discussion open for discussing more general solutions to this problem
#242.
  • Loading branch information
aborgna-q committed Nov 21, 2023
1 parent b077660 commit 0b793be
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 18 deletions.
Binary file added test_files/nam_6_3.rwr
Binary file not shown.
34 changes: 22 additions & 12 deletions tket2/src/circuit/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::ops::{LeafOp, OpName, OpType};
use hugr::{HugrView, Node};
use petgraph::visit::{self as pg, Walker};
use thiserror::Error;

use super::Circuit;

Expand All @@ -22,29 +23,30 @@ pub trait CircuitHash<'circ>: HugrView {
///
/// Adapted from Quartz (Apache 2.0)
/// <https://github.com/quantum-compiler/quartz/blob/2e13eb7ffb3c5c5fe96cf5b4246f4fd7512e111e/src/quartz/tasograph/tasograph.cpp#L410>
fn circuit_hash(&'circ self) -> u64;
fn circuit_hash(&'circ self) -> Result<u64, HashError>;
}

impl<'circ, T> CircuitHash<'circ> for T
where
T: HugrView,
{
fn circuit_hash(&'circ self) -> u64 {
fn circuit_hash(&'circ self) -> Result<u64, HashError> {
let mut node_hashes = HashState::default();

for node in pg::Topo::new(&self.as_petgraph())
.iter(&self.as_petgraph())
.filter(|&n| n != self.root())
{
let hash = hash_node(self, node, &mut node_hashes);
let hash = hash_node(self, node, &mut node_hashes)?;
if node_hashes.set_hash(node, hash).is_some() {
panic!("Hash already set for node {node}");
}
}

// If the output node has no hash, the topological sort failed due to a cycle.
node_hashes
.node_hash(self.output())
.expect("Output hash has not been set")
.ok_or(HashError::CyclicCircuit)
}
}

Expand Down Expand Up @@ -95,14 +97,14 @@ fn hashable_op(op: &OpType) -> impl Hash {
/// # Panics
/// - If the command is a container node, or if it is a parametric CustomOp.
/// - If the hash of any of its predecessors has not been set.
fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 {
fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> Result<u64, HashError> {
let op = circ.get_optype(node);
let mut hasher = FxHasher64::default();

// Hash the node children
if circ.children(node).count() > 0 {
let container: SiblingGraph = SiblingGraph::try_new(circ, node).unwrap();
container.circuit_hash().hash(&mut hasher);
container.circuit_hash()?.hash(&mut hasher);
}

// Hash the node operation
Expand All @@ -121,7 +123,15 @@ fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 {
.fold(0, |total, hash| hash ^ total);
input_hash.hash(&mut hasher);
}
hasher.finish()
Ok(hasher.finish())
}

/// Errors that can occur while hashing a hugr.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum HashError {
/// The circuit contains a cycle.
#[error("The circuit contains a cycle.")]
CyclicCircuit,
}

#[cfg(test)]
Expand All @@ -144,7 +154,7 @@ mod test {
Ok(())
})
.unwrap();
let hash1 = circ1.circuit_hash();
let hash1 = circ1.circuit_hash().unwrap();

// A circuit built in a different order should have the same hash
let circ2 = build_simple_circuit(2, |circ| {
Expand All @@ -154,7 +164,7 @@ mod test {
Ok(())
})
.unwrap();
let hash2 = circ2.circuit_hash();
let hash2 = circ2.circuit_hash().unwrap();

assert_eq!(hash1, hash2);

Expand All @@ -166,7 +176,7 @@ mod test {
Ok(())
})
.unwrap();
let hash3 = circ3.circuit_hash();
let hash3 = circ3.circuit_hash().unwrap();

assert_ne!(hash1, hash3);
}
Expand All @@ -176,7 +186,7 @@ mod test {
let c_str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#;
let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap();
let circ: Hugr = ser.decode().unwrap();
circ.circuit_hash();
circ.circuit_hash().unwrap();
}

#[test]
Expand All @@ -188,7 +198,7 @@ mod test {
for c_str in [c_str1, c_str2] {
let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap();
let circ: Hugr = ser.decode().unwrap();
all_hashes.push(circ.circuit_hash());
all_hashes.push(circ.circuit_hash().unwrap());
}
assert_ne!(all_hashes[0], all_hashes[1]);
}
Expand Down
69 changes: 65 additions & 4 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,19 @@ where
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let hash = circ.circuit_hash().unwrap();
let mut seen_hashes = FxHashSet::default();
seen_hashes.insert(circ.circuit_hash());
seen_hashes.insert(hash);

// The priority queue of circuits to be processed (this should not get big)
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let cost = (cost_fn)(circ);

let mut pq = HugrPQ::new(cost_fn, queue_size);
pq.push(circ.clone());
pq.push_unchecked(circ.clone(), hash, cost);

let mut circ_cnt = 0;
let mut timeout_flag = false;
Expand All @@ -169,7 +172,13 @@ where
continue;
}

let new_circ_hash = new_circ.circuit_hash();
let Ok(new_circ_hash) = new_circ.circuit_hash() else {
// The composed rewrites produced a loop.
//
// See [https://github.com/CQCL/tket2/discussions/242]
continue;
};

if !seen_hashes.insert(new_circ_hash) {
// Ignore this circuit: we've already seen it
continue;
Expand Down Expand Up @@ -218,7 +227,7 @@ where
};
let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), queue_size);

let initial_circ_hash = circ.circuit_hash();
let initial_circ_hash = circ.circuit_hash().unwrap();
let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(&best_circ);

Expand Down Expand Up @@ -436,6 +445,7 @@ mod tests {
};
use rstest::{fixture, rstest};

use crate::json::load_tk1_json_str;
use crate::{extension::REGISTRY, Circuit, T2Op};

use super::{BadgerOptimiser, DefaultBadgerOptimiser};
Expand Down Expand Up @@ -466,11 +476,45 @@ mod tests {
h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap()
}

/// This hugr corresponds to the qasm circuit:
///
/// ```skip
/// OPENQASM 2.0;
/// include "qelib1.inc";
///
/// qreg q[5];
/// cx q[4],q[1];
/// cx q[3],q[4];
/// cx q[1],q[2];
/// cx q[4],q[0];
/// u3(0.5*pi,0.0*pi,0.5*pi) q[1];
/// cx q[0],q[2];
/// cx q[3],q[1];
/// cx q[0],q[2];
/// ```
const NON_COMPOSABLE: &str = r#"{"phase":"0.0","commands":[{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[1]],["q",[2]]]},{"op":{"type":"U3","params":["0.5","0","0.5"],"signature":["Q"]},"args":[["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[4]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[0]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[1]]]}],"qubits":[["q",[0]],["q",[1]],["q",[2]],["q",[3]],["q",[4]]],"bits":[],"implicit_permutation":[[["q",[0]],["q",[0]]],[["q",[1]],["q",[1]]],[["q",[2]],["q",[2]]],[["q",[3]],["q",[3]]],[["q",[4]],["q",[4]]]]}"#;

/// A Hugr that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches.
#[fixture]
fn non_composable_rw_hugr() -> Hugr {
load_tk1_json_str(NON_COMPOSABLE).unwrap()
}

/// A badger optimiser using a reduced set of rewrite rules.
#[fixture]
fn badger_opt() -> DefaultBadgerOptimiser {
BadgerOptimiser::default_with_eccs_json_file("../test_files/small_eccs.json").unwrap()
}

/// A badger optimiser using the complete nam_6_3 rewrite set.
///
/// NOTE: This takes a few seconds to load.
/// Use [`badger_opt`] if possible.
#[fixture]
fn badger_opt_full() -> DefaultBadgerOptimiser {
BadgerOptimiser::default_with_rewriter_binary("../test_files/nam_6_3.rwr").unwrap()
}

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) {
let opt_rz = badger_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false, 4);
Expand All @@ -483,4 +527,21 @@ mod tests {
let mut opt_rz = badger_opt.optimise(&rz_rz, Some(0), 2.try_into().unwrap(), false, 4);
opt_rz.update_validate(&REGISTRY).unwrap();
}

#[rstest]
#[ignore = "Loading the ECC set is really slow (~5 seconds)"]
fn non_composable_rewrites(
non_composable_rw_hugr: Hugr,
badger_opt_full: DefaultBadgerOptimiser,
) {
let mut opt = badger_opt_full.optimise(
&non_composable_rw_hugr,
Some(0),
1.try_into().unwrap(),
false,
10,
);
// No rewrites applied.
opt.update_validate(&REGISTRY).unwrap();
}
}
2 changes: 1 addition & 1 deletion tket2/src/optimiser/badger/hugr_pqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl<P: Ord, C> HugrPQ<P, C> {
where
C: Fn(&Hugr) -> P,
{
let hash = hugr.circuit_hash();
let hash = hugr.circuit_hash().unwrap();
let cost = (self.cost_fn)(&hugr);
self.push_unchecked(hugr, hash, cost);
}
Expand Down
8 changes: 7 additions & 1 deletion tket2/src/optimiser/badger/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ where
return None;
}

let hash = c.circuit_hash();
let Ok(hash) = c.circuit_hash() else {
// The composed rewrites were not valid.
//
// See [https://github.com/CQCL/tket2/discussions/242]
return None;
};

Some(Work {
cost: new_cost,
hash,
Expand Down

0 comments on commit 0b793be

Please sign in to comment.