Skip to content

Commit

Permalink
feat: Utilities for loading compiled guppy circuits (#393)
Browse files Browse the repository at this point in the history
Adds a series of `load_guppy_json*` methods similar to `load_tk1_json`.
Given a hugr json and a function name it returns a `Circuit` for the
function.

Currently this only supports guppy functions with no control flow.

Adds a test adapted from #382, with the following hugr:
```mermaid
graph LR
    subgraph 0 ["(0) Module"]
        direction LR
        subgraph 7 ["(7) FuncDefn"]
            direction LR
            3["(3) Input"]
            3--"0:0<br>qubit"-->8
            3--"1:1<br>qubit"-->8
            6["(6) Output"]
            subgraph 8 ["(8) CFG"]
                direction LR
                subgraph 1 ["(1) DataflowBlock"]
                    direction LR
                    4["(4) Input"]
                    4--"0:0<br>qubit"-->13
                    4--"1:0<br>qubit"-->21
                    5["(5) Output"]
                    9["(9) const:custom:f64(1.5707963267948966)"]
                    9--"0:0<br>float64"-->10
                    10["(10) LoadConstant"]
                    10--"0:1<br>float64"-->13
                    11["(11) const:custom:f64(-1.5707963267948966)"]
                    11--"0:0<br>float64"-->12
                    12["(12) LoadConstant"]
                    12--"0:2<br>float64"-->13
                    13["(13) quantum.tket2.PhasedX"]
                    13--"0:0<br>qubit"-->16
                    14["(14) const:custom:f64(3.141592653589793)"]
                    14--"0:0<br>float64"-->15
                    15["(15) LoadConstant"]
                    15--"0:1<br>float64"-->16
                    16["(16) quantum.tket2.RzF64"]
                    16--"0:0<br>qubit"-->25
                    17["(17) const:custom:f64(1.5707963267948966)"]
                    17--"0:0<br>float64"-->18
                    18["(18) LoadConstant"]
                    18--"0:1<br>float64"-->21
                    19["(19) const:custom:f64(-1.5707963267948966)"]
                    19--"0:0<br>float64"-->20
                    20["(20) LoadConstant"]
                    20--"0:2<br>float64"-->21
                    21["(21) quantum.tket2.PhasedX"]
                    21--"0:0<br>qubit"-->24
                    22["(22) const:custom:f64(3.141592653589793)"]
                    22--"0:0<br>float64"-->23
                    23["(23) LoadConstant"]
                    23--"0:1<br>float64"-->24
                    24["(24) quantum.tket2.RzF64"]
                    24--"0:1<br>qubit"-->25
                    25["(25) quantum.tket2.ZZMax"]
                    25--"0:0<br>qubit"-->26
                    25--"1:1<br>qubit"-->26
                    26["(26) MakeTuple"]
                    26--"0:0<br>[qubit, qubit]"-->27
                    27["(27) UnpackTuple"]
                    27--"0:0<br>qubit"-->28
                    27--"1:0<br>qubit"-->30
                    28["(28) quantum.tket2.Measure"]
                    28--"0:0<br>qubit"-->29
                    29["(29) quantum.tket2.QFree"]
                    30["(30) quantum.tket2.Measure"]
                    30--"0:0<br>qubit"-->31
                    30--"1:0<br>[]+[]"-->32
                    31["(31) quantum.tket2.QFree"]
                    32["(32) MakeTuple"]
                    32--"0:0<br>[[]+[]]"-->33
                    33["(33) UnpackTuple"]
                    33--"0:1<br>[]+[]"-->5
                    34["(34) Tag"]
                    34--"0:0<br>[]"-->5
                end
                1-."0:0".->2
                2["(2) ExitBlock"]
            end
            8--"0:0<br>[]+[]"-->6
        end
    end
```

drive-by: Drop deprecated `stringreader` dependency
drive-by: Bind `Tk2Circuit.num_operations`, used in the python test
  • Loading branch information
aborgna-q committed Jun 11, 2024
1 parent 93e611c commit 028779a
Show file tree
Hide file tree
Showing 12 changed files with 447 additions and 92 deletions.
7 changes: 0 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ serde = "1.0"
serde_json = "1.0"
serde_yaml = "0.9.22"
smol_str = "0.2.0"
stringreader = "0.1.1"
strum = "0.26.1"
strum_macros = "0.26.4"
thiserror = "1.0.28"
Expand Down
298 changes: 224 additions & 74 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ readme = "README.md"
packages = [{ include = "tket2-py" }]

[tool.poetry.dependencies]
python = ">=3.10"
python = "^3.10"
pytket = "1.28.0"

[tool.poetry.group.dev.dependencies]
Expand All @@ -36,6 +36,7 @@ mypy = "^1.9.0"
hypothesis = "^6.103.1"
graphviz = "^0.20"
pre-commit = "^3.7.1"
guppylang = "^0.5.0"

[build-system]
requires = ["maturin~=1.5.1"]
Expand Down
27 changes: 22 additions & 5 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ impl Tk2Circuit {
Ok(Tk2Circuit { circ: hugr.into() })
}

/// Load a function from a compiled guppy module, encoded as a json string.
#[staticmethod]
pub fn from_guppy_json(json: &str, function: &str) -> PyResult<Self> {
let circ = tket2::serialize::load_guppy_json_str(json, function).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
Ok(Tk2Circuit { circ })
}

/// Encode the circuit as a tket1 json string.
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap())
Expand All @@ -106,11 +115,10 @@ impl Tk2Circuit {
/// Decode a tket1 json string to a circuit.
#[staticmethod]
pub fn from_tket1_json(json: &str) -> PyResult<Self> {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
circ: tk1.decode().convert_pyerrs()?,
})
let circ = tket2::serialize::load_tk1_json_str(json).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Could not load pytket circuit: {e}"))
})?;
Ok(Tk2Circuit { circ })
}

/// Compute the cost of the circuit based on a per-operation cost function.
Expand Down Expand Up @@ -138,6 +146,15 @@ impl Tk2Circuit {
Ok(circ_cost.cost.into_bound(py))
}

/// Returns the number of operations in the circuit.
///
/// This includes [`Tk2Op`]s, pytket ops, and any other custom operations.
///
/// Nested circuits are traversed to count their operations.
pub fn num_operations(&self) -> usize {
self.circ.num_operations()
}

/// Returns a hash of the circuit.
pub fn hash(&self) -> u64 {
self.circ.circuit_hash().unwrap()
Expand Down
39 changes: 39 additions & 0 deletions tket2-py/test/test_guppy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import no_type_check
from tket2.circuit import Tk2Circuit

import math

from guppylang import guppy
from guppylang.module import GuppyModule
from guppylang.prelude import quantum
from guppylang.prelude.builtins import py
from guppylang.prelude.quantum import measure, phased_x, qubit, rz, zz_max


def test_load_compiled_module():
module = GuppyModule("test")
module.load(quantum)

@guppy(module)
@no_type_check
def my_func(
q0: qubit,
q1: qubit,
) -> tuple[bool,]:
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2))
q0 = rz(q0, py(math.pi))
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2))
q1 = rz(q1, py(math.pi))
q0, q1 = zz_max(q0, q1)
_ = measure(q0)
return (measure(q1),)

# Compile the module, and convert it to a JSON string
hugr = module.compile()
json = hugr.to_raw().to_json()

# Load the module from the JSON string
circ = Tk2Circuit.from_guppy_json(json, "my_func")

# The 7 operations in the function, plus two implicit QFree
assert circ.num_operations() == 9
12 changes: 12 additions & 0 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class Tk2Circuit:
def circuit_cost(self, cost_fn: Callable[[Tk2Op], Any]) -> int:
"""Compute the cost of the circuit. Return value must implement __add__."""

def num_operations(self) -> int:
"""The number of operations in the circuit.
This includes [`Tk2Op`]s, pytket ops, and any other custom operations.
Nested circuits are traversed to count their operations.
"""

def node_op(self, node: Node) -> CustomOp:
"""If the node corresponds to a custom op, return it. Otherwise, raise an error."""

Expand Down Expand Up @@ -55,6 +63,10 @@ class Tk2Circuit:
def to_tket1_json(self) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_guppy_json(json: str, function: str) -> Tk2Circuit:
"""Load a function from a compiled guppy module, encoded as a json string."""

@staticmethod
def from_tket1_json(json: str) -> Tk2Circuit:
"""Decode a pytket json string to a Tk2Circuit."""
Expand Down
1 change: 0 additions & 1 deletion tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ delegate = { workspace = true }
csv = { workspace = true }
chrono = { workspace = true }
bytemuck = { workspace = true }
stringreader = { workspace = true }
crossbeam-channel = { workspace = true }
tracing = { workspace = true }

Expand Down
2 changes: 1 addition & 1 deletion tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ pub mod portmatching;

mod utils;

pub use circuit::Circuit;
pub use circuit::{Circuit, CircuitError, CircuitMutError};
pub use hugr::Hugr;
pub use ops::{op_matches, symbolic_constant_op, Pauli, Tk2Op};
4 changes: 4 additions & 0 deletions tket2/src/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! Utilities for serializing circuits.
//!
//! See [`crate::serialize::pytket`] for serialization to and from the legacy pytket format.
pub mod guppy;
pub mod pytket;

pub use guppy::{
load_guppy_json_file, load_guppy_json_reader, load_guppy_json_str, CircuitLoadError,
};
pub use pytket::{
load_tk1_json_file, load_tk1_json_reader, load_tk1_json_str, save_tk1_json_file,
save_tk1_json_str, save_tk1_json_writer, TKETDecode,
Expand Down
142 changes: 142 additions & 0 deletions tket2/src/serialize/guppy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
//! Load pre-compiled guppy functions.

use std::path::Path;
use std::{fs, io};

use hugr::ops::{NamedOp, OpTag, OpTrait, OpType};
use hugr::{Hugr, HugrView};
use itertools::Itertools;
use thiserror::Error;

use crate::{Circuit, CircuitError};

/// Loads a pre-compiled guppy file.
pub fn load_guppy_json_file(
path: impl AsRef<Path>,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let file = fs::File::open(path)?;
let reader = io::BufReader::new(file);
load_guppy_json_reader(reader, function)
}

/// Loads a pre-compiled guppy file from a json string.
pub fn load_guppy_json_str(json: &str, function: &str) -> Result<Circuit, CircuitLoadError> {
let reader = json.as_bytes();
load_guppy_json_reader(reader, function)
}

/// Loads a pre-compiled guppy file from a reader.
pub fn load_guppy_json_reader(
reader: impl io::Read,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let hugr: Hugr = serde_json::from_reader(reader)?;
find_function(hugr, function)
}

/// Looks for the required function in a HUGR compiled from a guppy module.
///
/// Guppy functions are compiled into a root module, with each function as a `FuncDecl` child.
/// Each `FuncDecl` contains a `CFG` operation that defines the function.
///
/// Currently we only support functions where the CFG operation has a single `DataflowBlock` child,
/// which we use as the root of the circuit. We (currently) do not support control flow primitives.
///
/// # Errors
///
/// - If the root of the HUGR is not a module operation.
/// - If the function is not found in the module.
/// - If the function has control flow primitives.
fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
// Find the root module.
let module = hugr.root();
if !OpTag::ModuleRoot.is_superset(hugr.get_optype(module).tag()) {
return Err(CircuitLoadError::NonModuleRoot {
root_op: hugr.get_optype(module).clone(),
});
}

// Find the function declaration.
fn func_name(op: &OpType) -> &str {
match op {
OpType::FuncDefn(decl) => &decl.name,
_ => "",
}
}

let Some(function) = hugr
.children(module)
.find(|&n| func_name(hugr.get_optype(n)) == function_name)
else {
let available_functions = hugr
.children(module)
.map(|n| func_name(hugr.get_optype(n)).to_string())
.collect();
return Err(CircuitLoadError::FunctionNotFound {
function: function_name.to_string(),
available_functions,
});
};

// Find the CFG operation.
let invalid_cfg = CircuitLoadError::InvalidControlFlow {
function: function_name.to_string(),
};
let Ok(cfg) = hugr.children(function).skip(2).exactly_one() else {
return Err(invalid_cfg);
};

// Find the single dataflow block to use as the root of the circuit.
// The cfg node should only have the dataflow block and an exit node as children.
let mut cfg_children = hugr.children(cfg);
let Some(dataflow) = cfg_children.next() else {
return Err(invalid_cfg);
};
if cfg_children.nth(1).is_some() {
return Err(invalid_cfg);
}

let circ = Circuit::try_new(hugr, dataflow)?;
Ok(circ)
}

/// Error type for conversion between `Op` and `OpType`.
#[derive(Debug, Error)]
pub enum CircuitLoadError {
/// Cannot load the circuit file.
#[error("Cannot load the circuit file: {0}")]
InvalidFile(#[from] io::Error),
/// Invalid JSON
#[error("Invalid JSON. {0}")]
InvalidJson(#[from] serde_json::Error),
/// The root node is not a module operation.
#[error(
"Expected a HUGR with a module at the root, but found a {} instead.",
root_op.name()
)]
NonModuleRoot {
/// The root operation.
root_op: OpType,
},
/// The function is not found in the module.
#[error(
"Function '{function}' not found in the loaded module. Available functions: [{}]",
available_functions.join(", ")
)]
FunctionNotFound {
/// The function name.
function: String,
/// The available functions.
available_functions: Vec<String>,
},
/// The function has an invalid control flow structure.
#[error("Function '{function}' has an invalid control flow structure. Currently only flat functions with no control flow primitives are supported.")]
InvalidControlFlow {
/// The function name.
function: String,
},
/// Error loading the circuit.
#[error("Error loading the circuit: {0}")]
CircuitLoadError(#[from] CircuitError),
}
3 changes: 1 addition & 2 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::{fs, io};
use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use stringreader::StringReader;
use thiserror::Error;
use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::optype::OpType as JsonOpType;
Expand Down Expand Up @@ -120,7 +119,7 @@ pub fn load_tk1_json_reader(json: impl io::Read) -> Result<Circuit, TK1ConvertEr

/// Load a TKET1 circuit from a JSON string.
pub fn load_tk1_json_str(json: &str) -> Result<Circuit, TK1ConvertError> {
let reader = StringReader::new(json);
let reader = json.as_bytes();
load_tk1_json_reader(reader)
}

Expand Down

0 comments on commit 028779a

Please sign in to comment.