Skip to content

Commit

Permalink
fix: Fix bug in JsonScalar::from_scalar (#74)
Browse files Browse the repository at this point in the history
As @ColbyDeLisle mentioned in issue #63, `pow` refers to powers of two,
whereas pyzx's `power2` refers to powers of sqrt(2). This was handled
incorrectly in the Exact arm of `JsonScalar::from_scalar`. I have
reworked that now so this fixes #63.

I believe that `full_simplify` usually leads to scalars of the form
`sqrt(2)^n * exp(m*i*pi/4)` for Clifford+T graphs. I refactored the
match arm so that these values are always treated exactly. For other
scalars, they are simply split up as `phase` and `floatfactor`.
  • Loading branch information
rafaelha authored Oct 17, 2024
1 parent 91aea91 commit 9a6c6cb
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
39 changes: 37 additions & 2 deletions quizx/src/json/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! This definition is compatible with the `pyzx` JSON format for scalars.
use num::complex::ComplexFloat;
use std::f64::consts::PI;

use num::{One, Zero};
Expand All @@ -20,6 +21,7 @@ impl JsonScalar {
let phase_options = PhaseOptions {
ignore_approx: true,
ignore_pi: true,
limit_denom: Some(256),
..Default::default()
};
match scalar {
Expand All @@ -35,11 +37,39 @@ impl JsonScalar {
..Default::default()
}
}
Scalar::Exact(pow, _) => {
Scalar::Exact(pow, coeffs) => {
// pow is an integer specifying the power of 2 that is applied
// power2 in the JsonScalar representation and in pyzx refers to the power of sqrt(2)

// Extract the phase. scalar.phase() will return exact representations of multiples of pi/4. In
// other cases, we lose precision.
let phase = JsonPhase::from_phase(scalar.phase(), phase_options);

// In the Clifford+T case where we have Scalar4, we can extract factors of sqrt(2) directly from the
// coefficients. Since the coefficients are reduced, sqrt(2) is represented as
// [1, 0, +-1, 0], [0, 1, +-1, 0], where the +- lead to phase contributions already extracted in `phase`
let (power_sqrt2, floatfactor) =
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
[a, 0, b, 0] | [0, a, 0, b]
if a.abs() == 1 && b.abs() == 1 && coeffs.len() == 4 =>
{
(*pow * 2 + 1, Default::default()) // Coefficients represent a factor of sqrt(2)
}
cf => (
// In all other cases, we simply assign the complex value to the pyzx floatfactor
*pow * 2,
Scalar::<Vec<_>>::from_int_coeffs(cf).complex_value().abs(),
),
};

JsonScalar {
power2: *pow,
power2: power_sqrt2,
phase,
floatfactor: if floatfactor == 1.0 {
Default::default()
} else {
floatfactor
},
is_zero: scalar.is_zero(),
..Default::default()
}
Expand Down Expand Up @@ -102,6 +132,11 @@ mod test {
#[case(ScalarN::from_phase((-1,2)))]
#[case(ScalarN::real(2.0))]
#[case(ScalarN::complex(1.0, 1.0))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[0, 7, 0, 7]))]
#[case(ScalarN::from_int_coeffs(&[-2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, 0, 0, 0, 0]))]
fn scalar_roundtrip(#[case] scalar: ScalarN) -> Result<(), JsonError> {
let json_scalar = JsonScalar::from_scalar(&scalar);
let decoded: ScalarN = json_scalar.to_scalar()?;
Expand Down
57 changes: 54 additions & 3 deletions quizx/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

use approx::AbsDiffEq;
use num::complex::Complex;
use num::integer::sqrt;
pub use num::traits::identities::{One, Zero};
use num::{integer, Integer};
use num::{integer, Integer, Rational64};
use std::cmp::min;
use std::f64::consts::PI;
use std::fmt;
Expand Down Expand Up @@ -167,8 +168,39 @@ impl<T: Coeffs> Scalar<T> {

/// Returns the phase of the scalar, expressed as half turns.
///
/// As [`Phase`] is encoded as a rational number, this method may lose precision.
/// We deal with Pi/4 phases of Scalar4 (Clifford+T) exactly. For other cases, [`Phase`] is encoded as a rational
/// number, which may lose precision.
pub fn phase(&self) -> Phase {
if let Exact(_, coeffs) = self {
if coeffs.len() == 4 {
// cases where the phase is a multiple of 1/4 are handled exactly
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
[a, b, 0, c] if -b == *c => {
return Phase::new(((-a - b * sqrt(2)).signum() as i64 + 1) / 2)
}
[0, c, 0, 0] => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 5 }, 4))
}
[0, 0, c, 0] => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 3 }, 2))
}
[0, 0, 0, c] => {
return Phase::new(Rational64::new(if *c > 0 { 3 } else { 7 }, 4))
}
[c, 0, d, 0] if c == d => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 5 }, 4))
}
[0, c, 0, d] if c == d => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 3 }, 2))
}
[d, 0, c, 0] if -c == *d => {
return Phase::new(Rational64::new(if *c > 0 { 3 } else { 7 }, 4))
}
_ => {}
}
}
}
// for other cases, we use the floating point representation
Phase::from_f64(self.complex_value().arg() / PI)
}

Expand Down Expand Up @@ -632,7 +664,7 @@ impl<T: Coeffs> PartialEq for Scalar<T> {

all_eq
}
_ => false,
_ => self.complex_value() == other.complex_value(),
}
}
}
Expand Down Expand Up @@ -705,6 +737,7 @@ mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use num::Rational64;
use rstest::rstest;

#[test]
fn approx_mul() {
Expand Down Expand Up @@ -766,6 +799,24 @@ mod tests {
);
}

#[rstest]
#[case(ScalarN::from_int_coeffs(&[3, 0, 0, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, -2, 0, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 1, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 0, 1]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, 2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[-2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, 1]))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[0, -2, 0, -2]))]
#[case(ScalarN::from_int_coeffs(&[0, 2, 0, -2]))]
#[case(ScalarN::from_int_coeffs(&[-1, 2, 3, -4]))]
fn get_phase(#[case] s: ScalarN) {
assert_abs_diff_eq!(s.phase().to_f64(), s.complex_value().arg() / PI);
}

#[test]
fn additions() {
let s = ScalarN::from_int_coeffs(&[1, 2, 3, 4]);
Expand Down

0 comments on commit 9a6c6cb

Please sign in to comment.