Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix qml.center with linear combinations #6049

Merged
merged 20 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@
* Fix `jax.grad` + `jax.jit` not working for `AmplitudeEmbedding`, `StatePrep` and `MottonenStatePreparation`.
[(#5620)](https://github.com/PennyLaneAI/pennylane/pull/5620)

* Fixed a bug in `qml.center` that omitted elements from the center if they were
linear combinations of input elements.
[(#6049)](https://github.com/PennyLaneAI/pennylane/pull/6049)

* Fix a bug where the global phase returned by `one_qubit_decomposition` gained a broadcasting dimension.
[(#5923)](https://github.com/PennyLaneAI/pennylane/pull/5923)

Expand Down Expand Up @@ -282,5 +286,5 @@ Vincent Michaud-Rioux,
Anurav Modak,
Mudit Pandey,
Erik Schultheis,
nate stemen,
David Wierichs,
Nate Stemen,
David Wierichs,
147 changes: 127 additions & 20 deletions pennylane/pauli/dla/center.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,47 @@
from typing import Union

import numpy as np
from scipy.linalg import norm, null_space

from pennylane.operation import Operator
from pennylane.pauli import PauliSentence, PauliWord
from pennylane.pauli.dla import structure_constants


def _intersect_bases(basis_0, basis_1):
r"""Compute the intersection of two vector spaces that are given by a basis each.
This is done by constructing a matrix [basis_0 | -basis_1] and computing its null space
in form of vectors (u, v)^T, which is equivalent to solving the equation
``basis_0 @ u = basis_1 @ v``.
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
Given a basis for this null space, the vectors ``basis_0 @ u`` (or equivalently
``basis_1 @ v``) form a basis for the intersection of the vector spaces.

Also see https://math.stackexchange.com/questions/25371/how-to-find-a-basis-for-the-intersection-of-two-vector-spaces-in-mathbbrn
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
"""
# Compute (orthonormal) basis for the null space of the augmented matrix [basis_0, -basis_1]
augmented_basis = null_space(np.hstack([basis_0, -basis_1]))
# Compute basis_0 @ u for each vector u from the basis (u, v)^T in the augmented basis
intersection_basis = basis_0 @ augmented_basis[: basis_0.shape[1]]
# Normalize the output for cleaner results, because the augmented kernel was normalized
intersection_basis = intersection_basis / norm(intersection_basis, axis=0)
return intersection_basis


def _center_pauli_words(g, pauli):
"""Compute the center of an algebra given in a PauliWord basis."""
d = len(g)
commutators = np.zeros((d, d), dtype=int)
for (j, op1), (k, op2) in combinations(enumerate(g), r=2):
if not op1.commutes_with(op2):
commutators[j, k] = 1 # dummy value to indicate operators dont commute
commutators[k, j] = 1

mask = np.all(commutators == 0, axis=0)
res = list(np.array(g)[mask])

if not pauli:
res = [op.operation() for op in res]
return res


def center(
Expand All @@ -33,9 +71,11 @@ def center(
.. math:: \mathfrak{\xi}(\mathfrak{g}) := \{h \in \mathfrak{g} | [h, h_i]=0 \ \forall h_i \in \mathfrak{g} \}

Args:
g (List[Union[Operator, PauliSentence, PauliWord]]): List of operators for which to find the center.
pauli (bool): Indicates whether it is assumed that :class:`~.PauliSentence` or :class:`~.PauliWord` instances are input and returned.
This can help with performance to avoid unnecessary conversions to :class:`~pennylane.operation.Operator`
g (List[Union[Operator, PauliSentence, PauliWord]]): List of operators that spans
the algebra for which to find the center.
pauli (bool): Indicates whether it is assumed that :class:`~.PauliSentence` or
:class:`~.PauliWord` instances are input and returned. This can help with performance
to avoid unnecessary conversions to :class:`~pennylane.operation.Operator`
and vice versa. Default is ``False``.

Returns:
Expand All @@ -56,23 +96,90 @@ def center(
>>> qml.center(g)
[X(0)]

.. details::
:title: Derivation
:href: derivation

The center :math:`\mathfrak{z}(\mathfrak{k})` of an algebra :math:`\mathfrak{k}`
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
can be computed in the following steps. First, compute the
:func:`~.pennylane.structure_constants`, or adjoint representation, of the algebra
with respect to some basis :math:`\mathbb{B}` of :math:`\mathfrak{k}`.
The center of :math:`\mathfrak{k}` is then given by

.. math::

\mathfrak{z}(\mathfrak{k}) = \operatorname{span}\left\{\bigcap_{x\in\mathbb{B}}
\operatorname{ker}(\operatorname{ad}_x)\right\},

i.e., the intersection of the kernels, or null spaces, of all basis elements in the
adjoint representation.

The kernel can be computed with ``scipy.linalg.null_space``, and vector space
intersections are computed recursively from pairwise intersections. The intersection
between two vectors spaces :math:`V_1` and :math:`V_2` given by (orthonormal) bases
:math:`\mathbb{B}_i` can be computed from
:math:`\operatorname{ker}([\mathbb{B}_1 | -\mathbb{B}_2])`. For an (orthonormal)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
basis :math:`\{(u_1^{(i)}, u_2^{(i)})^T\}_i` of this kernel, a basis of the
intersection space :math:`V_1 \cap V_2` is given by :math:`\{\mathbb{B}_1 u_1^{(i)}\}_i`
(or equivalently by :math:`\{\mathbb{B}_2 u_2^{(i)}\}_i`).
Also see [this post](https://math.stackexchange.com/questions/25371/how-to-find-a-basis-for-the-intersection-of-two-vector-spaces-in-mathbbrn)
for details.
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

If the input consists of :class:`~.pennylane.PauliWord` instances only, we can
instead compute pairwise commutators and know that the center consists solely of
basis elements that commute with all other basis elements. This can be seen in the
following way.
Assume that the center elements identified based on the basis have been removed
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
already and we are left with a basis :math:`\mathbb{B}=\{p_i\}_i` of Pauli
words such that :math:`\forall i\ \exists j:\ [p_i, p_j] \neq 0`. Assume that there is
another center element :math:`x`, which was missed before because it is a linear
combination of Pauli words:

.. math::

\forall j: \ [x, p_j] = [\sum_i x_i p_i, p_j] = 0.

As products of Paulis are unique when fixing one of the factors (:math:`p_j` is fixed
above), we then know that

.. math::

&\forall j: \ 0 = \sum_i x_i [p_i, p_j] = 2 \sum_i x_i \chi_{i,j} p_ip_j\\
\Rightarrow &\forall i,j such that \chi_{i,j}\neq 0: x_i = 0,
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

where denoted by :math:`\chi_{i,j}` an indicator that is :math:`0` if the commutator
:math:`[p_i, p_j]` vanishes and :math:`1` else.
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
However, we know that for each :math:`i` there is a :math:`j` such that
:math:`\chi_{i,j}\neq 0`. This means that :math:`x_i = 0\ \forall i` and therefore
:math:`x`, so that we did not miss a center element.
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(g) < 2:
# A length-zero list has zero center, a length-one list has full center
return g
if all(isinstance(x, PauliWord) for x in g):
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return _center_pauli_words(g, pauli)

adjoint_repr = structure_constants(g, pauli)
# Start kernels intersection with kernel of first DLA element
kernel_intersection = null_space(adjoint_repr[0])
for ad_x in adjoint_repr[1:]:
# Compute the next kernel and intersect it with previous intersection
next_kernel = null_space(ad_x)
kernel_intersection = _intersect_bases(kernel_intersection, next_kernel)

# If the intersection is zero-dimensional, exit early
if kernel_intersection.shape[1] == 0:
return []

# Construct operators from numerical output and convert to desired format
res = [sum(c * x for c, x in zip(c_coeffs, g)) for c_coeffs in kernel_intersection.T]

have_paulis = all(isinstance(x, (PauliWord, PauliSentence)) for x in res)
if pauli or have_paulis:
_ = [el.simplify() for el in res]
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
if not pauli:
res = [el.operation() for el in res]
else:
res = [el.simplify() for el in res]

if not pauli:
g = [o.pauli_rep for o in g]

d = len(g)
commutators = np.zeros((d, d), dtype=int)
for (j, op1), (k, op2) in combinations(enumerate(g), r=2):
res = op1.commutator(op2)
res.simplify()
if res != PauliSentence({}):
commutators[j, k] = 1 # dummy value to indicate operators dont commute
commutators[k, j] = 1

mask = np.all(commutators == 0, axis=0)
res = list(np.array(g)[mask])

if not pauli:
res = [op.operation() for op in res]
return res
54 changes: 37 additions & 17 deletions tests/pauli/dla/test_center.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
# limitations under the License.
"""Tests for pennylane/dla/center.py functionality"""

import numpy as np
import pytest

import pennylane as qml
from pennylane.pauli import PauliSentence, center
from pennylane.pauli import PauliSentence, PauliWord, center


def test_trivial_center():
"""Test that the center of an empty list of generators is an empty list of generators."""
ops = []
res = center(ops)
assert res == []
assert center([]) == []


DLA_CENTERS = (
Expand All @@ -38,8 +37,7 @@ def test_trivial_center():
@pytest.mark.parametrize("ops, true_res", DLA_CENTERS)
def test_center(ops, true_res):
"""Test centers with Identity operators or non-overlapping wires"""
res = center(ops)
assert res == true_res
assert center(ops) == true_res


@pytest.mark.parametrize("ops, true_res", DLA_CENTERS)
Expand All @@ -49,36 +47,58 @@ def test_center_pauli(ops, true_res):
res = center(ops, pauli=True)

assert all(isinstance(op, PauliSentence) for op in res)
true_res = [op.pauli_rep for op in true_res]
assert res == true_res
assert res == [op.pauli_rep for op in true_res]


@pytest.mark.parametrize("pauli", [False, True])
def test_center_pauli_word_pauli_True(pauli):
def test_center_pauli_word(pauli):
"""Test that PauliWord instances can be passed for both pauli=True/False"""
ops = [
qml.pauli.PauliWord({0: "X"}),
qml.pauli.PauliWord({0: "X", 1: "X"}),
qml.pauli.PauliWord({1: "Y"}),
]
words = [{0: "X"}, {0: "X", 1: "X"}, {1: "Y"}, {0: "X", 1: "Z"}]
ops = list(map(PauliWord, words))
if pauli:
assert qml.center(ops, pauli=pauli) == [qml.pauli.PauliWord({0: "X"})]
assert qml.center(ops, pauli=pauli) == [PauliWord({0: "X"})]
else:
assert qml.center(ops, pauli=pauli) == [qml.X(0)]


@pytest.mark.parametrize("pauli", [False, True])
def test_center_pauli_sentence(pauli):
"""Test that PauliSentence instances can be passed for both pauli=True/False"""
words = [{0: "X"}, {0: "X", 1: "X"}, {1: "Y"}, {0: "X", 1: "Z"}]
words = list(map(PauliWord, words))
sentences = [
{words[0]: 0.5, words[1]: 3.2},
{words[0]: -0.2, words[2]: 2.5},
{words[2]: 1.2, words[3]: 0.72, words[1]: 0.6},
{words[1]: 0.9, words[2]: 1.8},
]
sentences = list(map(PauliSentence, sentences))
if pauli:
cent = qml.center(sentences, pauli=pauli)
assert isinstance(cent, list) and len(cent) == 1
assert isinstance(cent[0], PauliSentence)
assert PauliWord({0: "X"}) in cent[0]
else:
cent = qml.center(sentences, pauli=pauli)
assert isinstance(cent, list) and len(cent) == 1
assert isinstance(cent[0], qml.ops.op_math.SProd)
assert cent[0].base == qml.X(0)


c = 1 / np.sqrt(2)

GENERATOR_CENTERS = (
([qml.X(0), qml.X(0) @ qml.X(1), qml.Y(1)], [qml.X(0)]),
([qml.X(0) @ qml.X(1), qml.Y(1), qml.X(0)], [qml.X(0)]),
([qml.X(0) @ qml.X(1), qml.Y(1), qml.X(1)], []),
([qml.X(0) @ qml.X(1), qml.Y(1), qml.Z(0)], []),
([p(0) @ p(1) for p in [qml.X, qml.Y, qml.Z]], [p(0) @ p(1) for p in [qml.X, qml.Y, qml.Z]]),
([qml.X(0), qml.X(1), sum(p(0) @ p(1) for p in [qml.Y, qml.Z])], [c * qml.X(0) + c * qml.X(1)]),
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
)


@pytest.mark.parametrize("generators, true_res", GENERATOR_CENTERS)
def test_center_dla(generators, true_res):
"""Test computing the center for a non-trivial DLA"""
g = qml.pauli.lie_closure(generators)
res = center(g)
assert res == true_res
assert center(g) == true_res
Loading