Skip to content

Commit

Permalink
Add support for while_loop capture (#6064)
Browse files Browse the repository at this point in the history
**Context:** This add support for capturing `while_loop` into plxpr.

**Description of the Change:** The function's behaviour resembles the
one for `for_loop`, and works on capturing both `cond_fn` and `body_fn`
and executing both of them at runtime.

**Benefits:** `while_loop` will enjoy `plxpr` when capture is enabled

**Possible Drawbacks:** N/A

**Related GitHub Issues:** [sc-66773]
  • Loading branch information
obliviateandsurrender committed Aug 8, 2024
1 parent 7d7d51b commit bb51b79
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 20 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

* `qml.for_loop` can now be captured into plxpr.
[(#6041)](https://github.com/PennyLaneAI/pennylane/pull/6041)
[(#6064)](https://github.com/PennyLaneAI/pennylane/pull/6064)

* Removed `semantic_version` from the list of required packages in PennyLane.
[(#5836)](https://github.com/PennyLaneAI/pennylane/pull/5836)
Expand Down
95 changes: 75 additions & 20 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,19 +307,12 @@ def sum_abstracted(arr):


def while_loop(cond_fn):
"""A :func:`~.qjit` compatible while-loop for PennyLane programs.
.. note::
This function only supports the Catalyst compiler. See
:func:`catalyst.while_loop` for more details.
Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
page for an overview of the differences between Catalyst and PennyLane.
"""A :func:`~.qjit` compatible for-loop for PennyLane programs. When
used without :func:`~.qjit`, this function will fall back to a standard
Python for loop.
This decorator provides a functional version of the traditional while
loop, similar to `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`__.
This decorator provides a functional version of the traditional while loop,
similar to `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`__.
That is, any variables that are modified across iterations need to be provided as
inputs and outputs to the loop body function:
Expand All @@ -329,10 +322,9 @@ def while_loop(cond_fn):
- Output arguments contain the value at the end of the iteration. The
outputs are then fed back as inputs to the next iteration.
The final iteration values are also returned from the
transformed function.
The final iteration values are also returned from the transformed function.
The semantics of ``while_loop`` are given by the following Python pseudo-code:
The semantics of ``while_loop`` are given by the following Python pseudocode:
.. code-block:: python
Expand All @@ -358,7 +350,6 @@ def while_loop(cond_fn, body_fn, *args):
dev = qml.device("lightning.qubit", wires=1)
@qml.qjit
@qml.qnode(dev)
def circuit(x: float):
Expand All @@ -369,12 +360,19 @@ def loop_rx(x):
return x ** 2
# apply the while loop
final_x = loop_rx(x)
loop_rx(x)
return qml.expval(qml.Z(0)), final_x
return qml.expval(qml.Z(0))
>>> circuit(1.6)
(array(-0.02919952), array(2.56))
-0.02919952
``while_loop`` is also :func:`~.qjit` compatible; when used with the
:func:`~.qjit` decorator, the while loop will not be unrolled, and instead
will be captured as-is during compilation and executed during runtime:
>>> qml.qjit(circuit)(1.6)
Array(-0.02919952, dtype=float64)
"""

if active_jit := active_compiler():
Expand All @@ -401,6 +399,37 @@ def _decorator(body_fn: Callable) -> Callable:
return _decorator


@functools.lru_cache
def _get_while_loop_qfunc_prim():
"""Get the while_loop primitive for quantum functions."""

import jax # pylint: disable=import-outside-toplevel

while_loop_prim = jax.core.Primitive("while_loop")
while_loop_prim.multiple_results = True

@while_loop_prim.def_impl
def _(*jaxpr_args, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond):

jaxpr_consts_body = jaxpr_args[:n_consts_body]
jaxpr_consts_cond = jaxpr_args[n_consts_body : n_consts_body + n_consts_cond]
init_state = jaxpr_args[n_consts_body + n_consts_cond :]

# If cond_fn(*init_state) is False, return the initial state
fn_res = init_state
while jax.core.eval_jaxpr(jaxpr_cond_fn.jaxpr, jaxpr_consts_cond, *fn_res)[0]:
fn_res = jax.core.eval_jaxpr(jaxpr_body_fn.jaxpr, jaxpr_consts_body, *fn_res)

return fn_res

@while_loop_prim.def_abstract_eval
def _(*_, jaxpr_body_fn, **__):

return jaxpr_body_fn.out_avals

return while_loop_prim


class WhileLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a while loop. This class
when called with an initial state will execute the while
Expand All @@ -415,7 +444,7 @@ def __init__(self, cond_fn, body_fn):
self.cond_fn = cond_fn
self.body_fn = body_fn

def __call__(self, *init_state):
def _call_capture_disabled(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

Expand All @@ -425,6 +454,32 @@ def __call__(self, *init_state):

return fn_res

def _call_capture_enabled(self, *init_state):

import jax # pylint: disable=import-outside-toplevel

while_loop_prim = _get_while_loop_qfunc_prim()

jaxpr_body_fn = jax.make_jaxpr(self.body_fn)(*init_state)
jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn)(*init_state)

return while_loop_prim.bind(
*jaxpr_body_fn.consts,
*jaxpr_cond_fn.consts,
*init_state,
jaxpr_body_fn=jaxpr_body_fn,
jaxpr_cond_fn=jaxpr_cond_fn,
n_consts_body=len(jaxpr_body_fn.consts),
n_consts_cond=len(jaxpr_cond_fn.consts),
)

def __call__(self, *init_state):

if qml.capture.enabled():
return self._call_capture_enabled(*init_state)

return self._call_capture_disabled(*init_state)


def for_loop(lower_bound, upper_bound, step):
"""A :func:`~.qjit` compatible for-loop for PennyLane programs. When
Expand Down
217 changes: 217 additions & 0 deletions tests/capture/test_capture_while_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for capturing for while loops into jaxpr.
"""

import numpy as np
import pytest

import pennylane as qml

pytestmark = pytest.mark.jax

jax = pytest.importorskip("jax")


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
"""Enable and disable the PennyLane JAX capture context manager."""
qml.capture.enable()
yield
qml.capture.disable()


class TestCaptureWhileLoop:
"""Tests for capturing for while loops into jaxpr."""

@pytest.mark.parametrize("x", [1.6, 2.4])
def test_while_loop_simple(self, x):
"""Test simple while-loop primitive"""

def fn(x):

@qml.while_loop(lambda x: x < 2)
def loop(x):
return x**2

x2 = loop(x)
return x2

expected = x**2 if x < 2 else x
result = fn(x)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(fn)(x)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("array", [jax.numpy.zeros(0), jax.numpy.zeros(5)])
def test_while_loop_dynamic_array(self, array):
"""Test while loops with dynamic array inputs."""

def fn(arg):

a, b = jax.numpy.ones(arg.shape, dtype=float), jax.numpy.ones(arg.shape, dtype=float)

# Note: lambda *_, idx: idx < 10 doesn't work - necessary keyword argument not provided
@qml.while_loop(lambda *args: args[-1] < 10)
def loop(a, b, idx):
return a + b, b + a, idx + 2

return loop(a, b, 0)

res_arr1, res_arr2, res_idx = fn(array)
expected = 2**5 * jax.numpy.ones(*array.shape)
assert jax.numpy.allclose(res_arr1, res_arr2)
assert jax.numpy.allclose(res_arr1, expected), f"Expected {expected}, but got {res_arr1}"

jaxpr = jax.make_jaxpr(fn)(array)
res_arr1_jxpr, res_arr2_jxpr, res_idx_jxpr = jax.core.eval_jaxpr(
jaxpr.jaxpr, jaxpr.consts, array
)

assert np.allclose(res_arr1_jxpr, res_arr2_jxpr)
assert np.allclose(res_arr1_jxpr, expected), f"Expected {expected}, but got {res_arr1_jxpr}"
assert np.allclose(res_idx, res_idx_jxpr) and res_idx_jxpr == 10


class TestCaptureCircuitsWhileLoop:
"""Tests for capturing for while loops into jaxpr in the context of quantum circuits."""

def test_while_loop_capture(self):
"""Test that a while loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit():

@qml.while_loop(lambda i: i < 3)
def loop_fn(i):
qml.RX(i, wires=0)
return i + 1

_ = loop_fn(0)

return qml.expval(qml.Z(0))

result = circuit()
expected = np.cos(0 + 1 + 2)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)()
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("arg, expected", [(1.2, -0.16852022), (1.6, 0.598211352)])
def test_circuit_args(self, arg, expected):
"""Test that a while loop with arguments is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(arg):

qml.Hadamard(wires=0)
arg1, arg2 = arg + 0.1, arg + 0.2

@qml.while_loop(lambda x: x < 2.0)
def loop_body(x):
qml.RZ(arg1, wires=0)
qml.RZ(arg2, wires=0)
qml.RX(x, wires=0)
qml.RY(jax.numpy.sin(x), wires=0)
return x**2

loop_body(arg)

return qml.expval(qml.Z(0))

result = circuit(arg)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)(arg)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("arg, expected", [(3, 5), (11, 21)])
def test_circuit_closure_vars(self, arg, expected):
"""Test that closure variables within a while loop are correctly captured via jaxpr."""

def circuit(x):
y = x + 1

def while_f(i):
return i < y

@qml.while_loop(while_f)
def f(i):
return 4 * i + 1

return f(0)

result = circuit(arg)
assert qml.math.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)(arg)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize(
"upper_bound, arg, expected", [(3, 0.5, 0.00223126), (2, 12, 0.2653001)]
)
def test_while_loop_nested(self, upper_bound, arg, expected):
"""Test that a nested while loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(upper_bound, arg):

# while loop with dynamic bounds
@qml.while_loop(lambda i: i < upper_bound)
def loop_fn(i):
qml.Hadamard(wires=i)
return i + 1

# nested while loops.
# outer while loop updates x
@qml.while_loop(lambda _, i: i < upper_bound)
def loop_fn_returns(x, i):
qml.RX(x, wires=i)

# inner while loop
@qml.while_loop(lambda j: j < upper_bound)
def inner(j):
qml.RZ(j, wires=0)
qml.RY(x**2, wires=0)
return j + 1

inner(i + 1)

return x + 0.1, i + 1

loop_fn(0)
loop_fn_returns(arg, 0)

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

0 comments on commit bb51b79

Please sign in to comment.