Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Exp.decomposition jit compatible #6082

Merged
merged 12 commits into from
Aug 9, 2024
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@
* `qml.lie_closure` works with sums of Paulis.
[(#6023)](https://github.com/PennyLaneAI/pennylane/pull/6023)

* Workflows that parameterize the coefficients of `qml.exp` is now jit-compatible.
astralcai marked this conversation as resolved.
Show resolved Hide resolved
[(#6082)](https://github.com/PennyLaneAI/pennylane/pull/6082)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
3 changes: 3 additions & 0 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def adjoint_state_measurements(

def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
if isinstance(op, qml.ops.Exp) and qml.math.is_abstract(op.data[0]):
# Skip validation of Exp in tracing because Exp.has_generator is not traceable.
return True
astralcai marked this conversation as resolved.
Show resolved Hide resolved
return not isinstance(op, MidMeasureMP) and (
op.num_params == 0
or not qml.operation.is_trainable(op)
Expand Down
3 changes: 3 additions & 0 deletions pennylane/gradients/parameter_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,9 @@


def _param_shift_stopping_condition(op) -> bool:
if isinstance(op, qml.ops.Exp) and qml.math.is_abstract(op.data[0]):
# Skip validation of Exp in tracing because Exp.has_decomposition is not traceable.
return True

Check warning on line 762 in pennylane/gradients/parameter_shift.py

View check run for this annotation

Codecov / codecov/patch

pennylane/gradients/parameter_shift.py#L762

Added line #L762 was not covered by tests
if not op.has_decomposition:
# let things without decompositions through without error
# error will happen when calculating parameter shift tapes
Expand Down
51 changes: 33 additions & 18 deletions pennylane/ops/op_math/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def decomposition(self):

return d

# pylint:disable=too-many-branches
def _recursive_decomposition(self, base: Operator, coeff: complex):
"""Decompose the exponential of ``base`` multiplied by ``coeff``.

Expand Down Expand Up @@ -284,6 +283,28 @@ def _recursive_decomposition(self, base: Operator, coeff: complex):
coeffs = [c * coeff for c in coeffs]
return self._trotter_decomposition(ops, coeffs)

if not qml.math.is_abstract(coeff) and not qml.math.isclose(qml.math.real(coeff), 0):
astralcai marked this conversation as resolved.
Show resolved Hide resolved
astralcai marked this conversation as resolved.
Show resolved Hide resolved

error_msg = f"The decomposition of the {self} operator is not defined."

if not self.num_steps: # if num_steps was not set
error_msg += (
" Please set a value to ``num_steps`` when instantiating the ``Exp`` operator "
"if a Suzuki-Trotter decomposition is required."
)

if self.base.is_hermitian:
error_msg += (
" Decomposition is not defined for real coefficients of hermitian operators."
)

raise DecompositionUndefinedError(error_msg)

return self._smart_decomposition(coeff, base)

def _smart_decomposition(self, coeff, base):
"""Decompose to an operator to an operator with a generator or a PauliRot if possible."""

# Store operator classes with generators
has_generator_types = []
has_generator_types_anywires = []
Expand All @@ -304,36 +325,31 @@ def _recursive_decomposition(self, base: Operator, coeff: complex):
# Some generators are not wire-ordered (e.g. OrbitalRotation)
mapped_wires_g = qml.map_wires(g, dict(zip(g.wires, base.wires)))

if qml.equal(mapped_wires_g, base) and math.real(coeff) == 0:
coeff = math.real(
-1j / c * coeff
) # cancel the coefficients added by the generator
if qml.equal(mapped_wires_g, base):
# Cancel the coefficients added by the generator
coeff = math.real(-1j / c * coeff)
return [op_class(coeff, g.wires)]

# could have absorbed the coefficient.
simplified_g = qml.simplify(qml.s_prod(c, mapped_wires_g))

if qml.equal(simplified_g, base) and math.real(coeff) == 0:
coeff = math.real(-1j * coeff) # cancel the coefficients added by the generator
if qml.equal(simplified_g, base):
# Cancel the coefficients added by the generator
coeff = math.real(-1j * coeff)
return [op_class(coeff, g.wires)]

if qml.pauli.is_pauli_word(base) and math.real(coeff) == 0:
if qml.pauli.is_pauli_word(base):
# Check if the exponential can be decomposed into a PauliRot gate
return self._pauli_rot_decomposition(base, coeff)

error_msg = f"The decomposition of the {self} operator is not defined. "
error_msg = f"The decomposition of the {self} operator is not defined."

if not self.num_steps: # if num_steps was not set
error_msg += (
"Please set a value to ``num_steps`` when instantiating the ``Exp`` operator "
" Please set a value to ``num_steps`` when instantiating the ``Exp`` operator "
"if a Suzuki-Trotter decomposition is required. "
)

if math.real(self.coeff) != 0 and self.base.is_hermitian:
error_msg += (
"Decomposition is not defined for real coefficients of hermitian operators."
)

raise DecompositionUndefinedError(error_msg)

@staticmethod
Expand All @@ -347,9 +363,8 @@ def _pauli_rot_decomposition(base: Operator, coeff: complex):
Returns:
List[Operator]: list containing the PauliRot operator
"""
coeff = math.real(
2j * coeff
) # need to cancel the coefficients added by PauliRot and Ising gates
# Cancel the coefficients added by PauliRot and Ising gates
coeff = math.real(2j * coeff)
pauli_word = qml.pauli.pauli_word_to_string(base)
if pauli_word == "I" * base.num_wires:
return []
Expand Down
49 changes: 44 additions & 5 deletions tests/ops/op_math/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
class TestInitialization:
"""Test the initialization process and standard properties."""

@pytest.mark.usefixtures("use_legacy_and_new_opmath")
def test_pauli_base(self, constructor):
"""Test initialization with no coeff and a simple base."""
base = qml.PauliX("a")
Expand Down Expand Up @@ -421,19 +420,19 @@ def test_non_pauli_word_base_no_decomposition(self):
assert not op.has_decomposition
with pytest.raises(
DecompositionUndefinedError,
match=re.escape(f"The decomposition of the {op} operator is not defined. "),
match=re.escape(f"The decomposition of the {op} operator is not defined."),
):
op.decomposition()

op = Exp(2 * qml.S(0) + qml.PauliZ(1), -0.5j, num_steps=100)
assert not op.has_decomposition
with pytest.raises(
DecompositionUndefinedError,
match=re.escape(f"The decomposition of the {op} operator is not defined. "),
match=re.escape(f"The decomposition of the {op} operator is not defined."),
):
op.decomposition()

@pytest.mark.usefixtures("use_legacy_opmath")
@pytest.mark.usefixtures("legacy_opmath_only")
def test_nontensor_tensor_no_decomposition(self):
"""Checks that accessing the decomposition throws an error if the base is a Tensor
object that is not a mathematical tensor"""
Expand Down Expand Up @@ -477,7 +476,6 @@ def test_decomposition_tensor_into_pauli_rot(self, base, base_string):

@pytest.mark.parametrize("op_name", all_qubit_operators)
@pytest.mark.parametrize("str_wires", (True, False))
@pytest.mark.usefixtures("use_legacy_and_new_opmath")
def test_generator_decomposition(self, op_name, str_wires):
"""Check that Exp decomposes into a specific operator if ``base`` corresponds to the
generator of that operator."""
Expand Down Expand Up @@ -796,6 +794,47 @@ def circ(phi):
grad = jax.grad(circ)(phi)
assert qml.math.allclose(grad, -jnp.sin(phi))

@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst_qnode(self):
"""Test with Catalyst interface"""

pytest.importorskip("catalyst")

phi = 0.345

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def func(params):
qml.exp(qml.X(0), -0.5j * params)
return qml.expval(qml.Z(0))

res = func(phi)
assert qml.math.allclose(res, np.cos(phi))
grad = qml.grad(func)(phi)
assert qml.math.allclose(grad, -np.sin(phi))

@pytest.mark.jax
@pytest.mark.external
astralcai marked this conversation as resolved.
Show resolved Hide resolved
def test_jax_jit_qnode(self):
"""Tests with jax.jit"""

import jax
from jax import numpy as jnp

phi = jnp.array(0.345)

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def func(params):
qml.exp(qml.X(0), -0.5j * params)
return qml.expval(qml.Z(0))

res = func(phi)
assert qml.math.allclose(res, jnp.cos(phi))
grad = jax.grad(func)(phi)
assert qml.math.allclose(grad, -jnp.sin(phi))

@pytest.mark.tf
def test_tensorflow_qnode(self):
"""test the execution of a tensorflow qnode."""
Expand Down
Loading