From caf24e737f169d00171ce264e81ba120b5ece0f4 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Wed, 7 Feb 2024 18:02:27 -0500 Subject: [PATCH] Fix the default value of `qubits_len` for the `qinst` primitive (#496) Two issues are being fixed: - The default value of `-1`, intended to represent all arguments are qubits, does not actually work. Instead, it introduces a off by one error, cutting off one of the qubit values. - The abstract evaluation did not work with the default value, and instead assumed the value was always set manually. Credit to @positr0nium for unearthing this :) --------- Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com> --- doc/changelog.md | 11 +++++-- frontend/catalyst/jax_primitives.py | 10 +++--- frontend/test/pytest/test_jax_primitives.py | 36 +++++++++++++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 frontend/test/pytest/test_jax_primitives.py diff --git a/doc/changelog.md b/doc/changelog.md index 094edb0a94..6996d04083 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -162,13 +162,18 @@ implemented in Catalyst's runtime: * `int64_t __catalyst__rt__array_get_size_1d(QirArray *)` * `int8_t *__catalyst__rt__array_get_element_ptr_1d(QirArray *, int64_t)` - + and the following functions were removed since the frontend does not generate them * `QirString *__catalyst__rt__qubit_to_string(QUBIT *)` * `QirString *__catalyst__rt__result_to_string(RESULT *)`

Bug fixes

+* Fix an issue when no qubit number was specified for the `qinst` primitive. The primitive now + correctly deduces the number of qubits when no gate parameters are present. This change is not + user facing. + [(#496)](https://github.com/PennyLaneAI/catalyst/pull/496) + * Fix the scatter operation lowering when `updatedWindowsDim` is empty. [(#475)](https://github.com/PennyLaneAI/catalyst/pull/475) @@ -249,9 +254,9 @@ Haochen Paul Wang. def f(x): def cnot_loop(j): qml.CNOT(wires=[j, jnp.mod((j + 1), 4)]) - + for_loop(0, 4, 1)(cnot_loop)() - + return qml.expval(qml.PauliZ(0)) ``` diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 6d973126e8..4e21959fbc 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -682,11 +682,11 @@ def _qinsert_lowering( # qinst # @qinst_p.def_abstract_eval -def _qinst_abstract_eval(*qubits_or_params, op=None, qubits_len=-1): - for idx in range(qubits_len): - qubit = qubits_or_params[idx] +def _qinst_abstract_eval(*qubits_or_params, op=None, qubits_len=None): + qubits = qubits_or_params[:qubits_len] + for qubit in qubits: assert isinstance(qubit, AbstractQbit) - return (AbstractQbit(),) * qubits_len + return (AbstractQbit(),) * len(qubits) @qinst_p.def_impl @@ -695,7 +695,7 @@ def _qinst_def_impl(ctx, *qubits_or_params, op, qubits_len): # pragma: no cover def _qinst_lowering( - jax_ctx: mlir.LoweringRuleContext, *qubits_or_params: tuple, op=None, qubits_len=-1 + jax_ctx: mlir.LoweringRuleContext, *qubits_or_params: tuple, op=None, qubits_len=None ): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True diff --git a/frontend/test/pytest/test_jax_primitives.py b/frontend/test/pytest/test_jax_primitives.py new file mode 100644 index 0000000000..fea9c6786b --- /dev/null +++ b/frontend/test/pytest/test_jax_primitives.py @@ -0,0 +1,36 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the JAX primitives module.""" + +import pytest + +from catalyst.jax_primitives import AbstractQbit, qinst_p + + +class TestQinstPrim: + """Test the quantum instruction primitive.""" + + def test_abstract_eval_no_len(self): + """Test that the number of qubits is properly deduced when not set automatically.""" + + qb0, qb1 = (AbstractQbit(),) * 2 + result = qinst_p.abstract_eval(qb0, qb1, op="GarbageOp")[0] + + assert len(result) == 2 + assert all(isinstance(r, AbstractQbit) for r in result) + + +if __name__ == "__main__": + pytest.main(["-x", __file__])