diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1c4e709ec95..57ed1b27da5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -50,6 +50,9 @@ [(#6041)](https://github.com/PennyLaneAI/pennylane/pull/6041) [(#6064)](https://github.com/PennyLaneAI/pennylane/pull/6064) +* `qml.for_loop` now supports `range`-like syntax with default `step=1`. + [(#6068)](https://github.com/PennyLaneAI/pennylane/pull/6068) + * Removed `semantic_version` from the list of required packages in PennyLane. [(#5836)](https://github.com/PennyLaneAI/pennylane/pull/5836) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 72f28fb6a8c..25b454c6c6a 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -481,8 +481,9 @@ def __call__(self, *init_state): return self._call_capture_disabled(*init_state) -def for_loop(lower_bound, upper_bound, step): - """A :func:`~.qjit` compatible for-loop for PennyLane programs. When +def for_loop(start, stop=None, step=1): + """for_loop([start, ]stop[, step]) + A :func:`~.qjit` compatible for-loop for PennyLane programs. When used without :func:`~.qjit`, this function will fall back to a standard Python for loop. @@ -504,18 +505,29 @@ def for_loop(lower_bound, upper_bound, step): .. code-block:: python - def for_loop(lower_bound, upper_bound, step, loop_fn, *args): - for i in range(lower_bound, upper_bound, step): + def for_loop(start, stop, step, loop_fn, *args): + for i in range(start, stop, step): args = loop_fn(i, *args) return args Unlike ``jax.cond.fori_loop``, the step can be negative if it is known at tracing time (i.e., constant). If a non-constant negative step is used, the loop will produce no iterations. + .. note:: + + This function can be used in the following different ways: + + 1. ``for_loop(stop)``: Values are generated within the interval ``[0, stop)`` + 2. ``for_loop(start, stop)``: Values are generated within the interval ``[start, stop)`` + 3. ``for_loop(start, stop, step)``: Values are generated within the interval ``[start, stop)``, + with spacing between the values given by ``step`` + Args: - lower_bound (int): starting value of the iteration index - upper_bound (int): (exclusive) upper bound of the iteration index - step (int): increment applied to the iteration index at the end of each iteration + start (int, optional): starting value of the iteration index. + The default start value is ``0`` + stop (int): upper bound of the iteration index + step (int, optional): increment applied to the iteration index at the end of + each iteration. The default step size is ``1`` Returns: Callable[[int, ...], ...]: A wrapper around the loop body function. @@ -565,16 +577,18 @@ def loop_rx(i, x): page for an overview of using quantum just-in-time compilation. """ + if stop is None: + start, stop = 0, start if active_jit := active_compiler(): compilers = AvailableCompilers.names_entrypoints ops_loader = compilers[active_jit]["ops"].load() - return ops_loader.for_loop(lower_bound, upper_bound, step) + return ops_loader.for_loop(start, stop, step) # if there is no active compiler, simply interpret the for loop # via the Python interpretor. def _decorator(body_fn): - """Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step. + """Transform that will call the input ``body_fn`` within a for loop defined by the closure variables start, stop, and step. Args: body_fn (Callable): The function called within the for loop. Note that the loop body @@ -584,14 +598,14 @@ def _decorator(body_fn): returned from the function. Closure Variables: - lower_bound (int): starting value of the iteration index - upper_bound (int): (exclusive) upper bound of the iteration index + start (int): starting value of the iteration index + stop (int): (exclusive) upper bound of the iteration index step (int): increment applied to the iteration index at the end of each iteration Returns: Callable: a callable with the same signature as ``body_fn`` """ - return ForLoopCallable(lower_bound, upper_bound, step, body_fn) + return ForLoopCallable(start, stop, step, body_fn) return _decorator diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index f6343588c36..d506d9b6a11 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -64,6 +64,74 @@ def loop(_, a): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize("array", [jax.numpy.zeros(0), jax.numpy.zeros(5)]) + def test_for_loop_defaults(self, array): + """Test simple for-loop primitive using default values.""" + + def fn(arg): + + a = jax.numpy.ones(arg.shape) + + @qml.for_loop(0, 10, 1) + def loop1(_, a): + return a + + @qml.for_loop(10, 1) + def loop2(_, a): + return a + + @qml.for_loop(10) + def loop3(_, a): + return a + + r1, r2, r3 = loop1(a), loop2(a), loop3(a) + return r1, r2, r3 + + expected = jax.numpy.ones(array.shape) + result = fn(array) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(fn)(array) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + @pytest.mark.parametrize( + "array, expected", + [ + (jax.numpy.zeros(5), jax.numpy.array([0, 1, 4, 9, 16])), + (jax.numpy.zeros(10), jax.numpy.array([0, 1, 4, 9, 16, 25, 36, 49, 64, 81])), + ], + ) + def test_for_loop_default(self, array, expected): + """Test simple for-loop primitive using default values.""" + + def fn(arg): + + stop = arg.shape[0] + a = jax.numpy.ones(stop) + + @qml.for_loop(0, stop, 1) + def loop1(i, a): + return a.at[i].set(i**2) + + @qml.for_loop(0, stop) + def loop2(i, a): + return a.at[i].set(i**2) + + @qml.for_loop(stop) + def loop3(i, a): + return a.at[i].set(i**2) + + r1, r2, r3 = loop1(a), loop2(a), loop3(a) + return r1, r2, r3 + + result = fn(array) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(fn)(array) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize("array", [jax.numpy.zeros(0), jax.numpy.zeros(5)]) def test_for_loop_shared_indbidx(self, array): """Test for-loops with shared dynamic input dimensions."""