Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Scalar serialization and pyzx interop #33

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ repos:
language: system
files: \.py$
pass_filenames: false
- id: cargo-doc
name: cargo doc
description: Generate documentation with `cargo doc`.
entry: cargo doc --no-deps --all-features --workspace
language: system
files: \.rs$
pass_filenames: false
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pybindings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ crate-type = ["cdylib"]
[dependencies]
quizx = { workspace = true }
num = { workspace = true }
pyo3 = { workspace = true, features = ["extension-module"] }
pyo3 = { workspace = true, features = ["extension-module", "num-complex"] }
derive_more = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
6 changes: 3 additions & 3 deletions pybindings/quizx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from . import _quizx
from . import _quizx, simplify
from .graph import VecGraph
from .circuit import Circuit
from . import simplify
from .decompose import Decomposer
from ._quizx import Scalar

__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer"]
__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar"]


def extract_circuit(g):
Expand Down
31 changes: 31 additions & 0 deletions pybindings/quizx/_quizx.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
# Type stubs for the Rust bindings

from typing import final
from builtins import complex as complex_

@final
class Scalar:
@staticmethod
def complex(complex: complex_) -> Scalar: ...
@staticmethod
def real(real: float) -> Scalar: ...
@staticmethod
def from_phase(phase: float) -> Scalar: ...
@staticmethod
def from_int_coeffs(coeffs: list[int]) -> Scalar: ...
@staticmethod
def one_plus_phase(phase: float) -> Scalar: ...
@staticmethod
def sqrt2_pow(n: int) -> Scalar: ...
def complex_value(self) -> complex_: ...
def mul_sqrt2_pow(self, n: int) -> Scalar: ...
def mul_phase(self, phase: float) -> Scalar: ...
@staticmethod
def zero() -> Scalar: ...
@staticmethod
def one() -> Scalar: ...
def is_zero(self) -> bool: ...
def is_one(self) -> bool: ...
def conjugate(self) -> Scalar: ...
def to_json(self) -> str: ...
@staticmethod
def from_json(json: str) -> Scalar: ...

@final
class VecGraph:
scalar: Scalar
def vindex(self) -> int: ...
def neighbor_at(self, v: int, n: int) -> int: ...
def num_vertices(self) -> int: ...
Expand Down Expand Up @@ -60,6 +90,7 @@ class CircuitStats:

@final
class Decomposer:
scalar: Scalar
@staticmethod
def empty() -> Decomposer: ...
def __init__(self, g: VecGraph) -> None: ...
Expand Down
10 changes: 10 additions & 0 deletions pybindings/quizx/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from . import _quizx
from .graph import VecGraph
from .scalar import from_pyzx_scalar, to_pyzx_scalar
from pyzx.graph.scalar import Scalar


class Decomposer(object):
Expand Down Expand Up @@ -30,3 +32,11 @@ def decomp_all(self):

def decomp_until_depth(self, depth: int):
self._d.decomp_until_depth(depth)

@property
def scalar(self) -> Scalar:
return to_pyzx_scalar(self._d.scalar)

@scalar.setter
def scalar(self, s: Scalar):
self._d.scalar = from_pyzx_scalar(s)
10 changes: 10 additions & 0 deletions pybindings/quizx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .scalar import from_pyzx_scalar, to_pyzx_scalar
from fractions import Fraction
from typing import Tuple, Dict, Any, Optional
from pyzx.graph.base import BaseGraph # type: ignore
from pyzx.utils import VertexType, EdgeType # type: ignore
from pyzx.graph.scalar import Scalar

from . import _quizx

Expand Down Expand Up @@ -315,3 +317,11 @@ def num_outputs(self):

def set_outputs(self, outputs):
self._g.set_outputs(list(outputs))

@property
def scalar(self) -> Scalar:
return to_pyzx_scalar(self._g.scalar)

@scalar.setter
def scalar(self, s: Scalar):
self._g.scalar = from_pyzx_scalar(s)
14 changes: 14 additions & 0 deletions pybindings/quizx/scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Utilities for interfacing `quizx::Scalar` with `pyzx.Scalar`."""

from ._quizx import Scalar as QuizxScalar
from pyzx.graph.scalar import Scalar as PyzxScalar


def from_pyzx_scalar(s: PyzxScalar) -> QuizxScalar:
"""Convert a `pyzx.Scalar` to a `quizx::Scalar`."""
return QuizxScalar.from_json(s.to_json())


def to_pyzx_scalar(s: QuizxScalar) -> PyzxScalar:
"""Convert a `pyzx.Scalar` to a `quizx::Scalar`."""
return PyzxScalar.from_json(s.to_json())
27 changes: 27 additions & 0 deletions pybindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
pub mod scalar;

use crate::scalar::Scalar;

use num::Rational64;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
Expand All @@ -16,6 +20,7 @@ fn _quizx(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Circuit>()?;
m.add_class::<CircuitStats>()?;
m.add_class::<Decomposer>()?;
m.add_class::<Scalar>()?;
Ok(())
}

Expand Down Expand Up @@ -290,6 +295,18 @@ impl VecGraph {
fn set_outputs(&mut self, outputs: Vec<V>) {
self.g.set_outputs(outputs)
}

/// Returns the graph scalar.
#[getter]
fn get_scalar(&self) -> Scalar {
self.g.scalar().clone().into()
}

/// Sets the graph scalar.
#[setter]
fn set_scalar(&mut self, scalar: Scalar) {
*self.g.scalar_mut() = scalar.into();
}
}

#[pyclass]
Expand All @@ -313,6 +330,16 @@ impl Decomposer {
}
}

#[getter]
fn get_scalar(&self) -> Scalar {
self.d.scalar.clone().into()
}

#[setter]
fn set_scalar(&mut self, scalar: Scalar) {
self.d.scalar = scalar.into();
}

fn graphs(&self) -> PyResult<Vec<VecGraph>> {
let mut gs = vec![];
for (_a, g) in &self.d.stack {
Expand Down
Loading
Loading