Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default values to for_loop #6068

Merged
merged 8 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
[(#6041)](https://github.com/PennyLaneAI/pennylane/pull/6041)
[(#6064)](https://github.com/PennyLaneAI/pennylane/pull/6064)

* `qml.for_loop` now supports `range`-like syntax with default `step=1`.
[(#6068)](https://github.com/PennyLaneAI/pennylane/pull/6068)

* Removed `semantic_version` from the list of required packages in PennyLane.
[(#5836)](https://github.com/PennyLaneAI/pennylane/pull/5836)

Expand Down
38 changes: 26 additions & 12 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,9 @@ def __call__(self, *init_state):
return self._call_capture_disabled(*init_state)


def for_loop(lower_bound, upper_bound, step):
"""A :func:`~.qjit` compatible for-loop for PennyLane programs. When
def for_loop(start, stop=None, step=1):
"""for_loop([start, ]stop[, step])
A :func:`~.qjit` compatible for-loop for PennyLane programs. When
used without :func:`~.qjit`, this function will fall back to a standard
Python for loop.

Expand All @@ -504,18 +505,29 @@ def for_loop(lower_bound, upper_bound, step):

.. code-block:: python

def for_loop(lower_bound, upper_bound, step, loop_fn, *args):
for i in range(lower_bound, upper_bound, step):
def for_loop(start, stop, step, loop_fn, *args):
for i in range(start, stop, step):
args = loop_fn(i, *args)
return args

Unlike ``jax.cond.fori_loop``, the step can be negative if it is known at tracing time
(i.e., constant). If a non-constant negative step is used, the loop will produce no iterations.

.. note::

This function can be used in the following different ways:

1. ``for_loop(stop)``: Values are generated within the interval ``[0, stop)``
2. ``for_loop(start, stop)``: Values are generated within the interval ``[start, stop)``
3. ``for_loop(start, stop, step)``: Values are generated within the interval ``[start, stop)``,
with spacing between the values given by ``step``

Args:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
start (int, optional): starting value of the iteration index.
The default start value is ``0``
stop (int): upper bound of the iteration index
step (int, optional): increment applied to the iteration index at the end of
each iteration. The default step size is ``1``

Returns:
Callable[[int, ...], ...]: A wrapper around the loop body function.
Expand Down Expand Up @@ -565,16 +577,18 @@ def loop_rx(i, x):
page for an overview of using quantum just-in-time compilation.

"""
if stop is None:
start, stop = 0, start

if active_jit := active_compiler():
compilers = AvailableCompilers.names_entrypoints
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.for_loop(lower_bound, upper_bound, step)
return ops_loader.for_loop(start, stop, step)

# if there is no active compiler, simply interpret the for loop
# via the Python interpretor.
def _decorator(body_fn):
"""Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step.
"""Transform that will call the input ``body_fn`` within a for loop defined by the closure variables start, stop, and step.

Args:
body_fn (Callable): The function called within the for loop. Note that the loop body
Expand All @@ -584,14 +598,14 @@ def _decorator(body_fn):
returned from the function.

Closure Variables:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
start (int): starting value of the iteration index
stop (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration

Returns:
Callable: a callable with the same signature as ``body_fn``
"""
return ForLoopCallable(lower_bound, upper_bound, step, body_fn)
return ForLoopCallable(start, stop, step, body_fn)

return _decorator

Expand Down
68 changes: 68 additions & 0 deletions tests/capture/test_capture_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,74 @@ def loop(_, a):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("array", [jax.numpy.zeros(0), jax.numpy.zeros(5)])
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
def test_for_loop_defaults(self, array):
"""Test simple for-loop primitive using default values."""

def fn(arg):

a = jax.numpy.ones(arg.shape)

@qml.for_loop(0, 10, 1)
def loop1(_, a):
return a

@qml.for_loop(10, 1)
def loop2(_, a):
return a

@qml.for_loop(10)
def loop3(_, a):
return a

r1, r2, r3 = loop1(a), loop2(a), loop3(a)
return r1, r2, r3

expected = jax.numpy.ones(array.shape)
result = fn(array)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(fn)(array)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize(
"array, expected",
[
(jax.numpy.zeros(5), jax.numpy.array([0, 1, 4, 9, 16])),
(jax.numpy.zeros(10), jax.numpy.array([0, 1, 4, 9, 16, 25, 36, 49, 64, 81])),
],
)
def test_for_loop_default(self, array, expected):
"""Test simple for-loop primitive using default values."""

def fn(arg):

stop = arg.shape[0]
a = jax.numpy.ones(stop)

@qml.for_loop(0, stop, 1)
def loop1(i, a):
return a.at[i].set(i**2)

@qml.for_loop(0, stop)
def loop2(i, a):
return a.at[i].set(i**2)

@qml.for_loop(stop)
def loop3(i, a):
return a.at[i].set(i**2)

r1, r2, r3 = loop1(a), loop2(a), loop3(a)
return r1, r2, r3

result = fn(array)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(fn)(array)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("array", [jax.numpy.zeros(0), jax.numpy.zeros(5)])
def test_for_loop_shared_indbidx(self, array):
"""Test for-loops with shared dynamic input dimensions."""
Expand Down
Loading