diff --git a/.dep-versions b/.dep-versions index 6b19ccd873..7c8cf2d3e2 100644 --- a/.dep-versions +++ b/.dep-versions @@ -3,3 +3,4 @@ jax=0.4.23 mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9 llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c enzyme=1beb98b51442d50652eaa3ffb9574f4720d611f1 +pennylane=95129a0d6365b48cb4acfa828ceb6a8532e47ef5 diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 94add6de01..b0ab0b9cfa 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -390,4 +390,5 @@ jobs: run: | python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto + python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 576d92ae32..bf0a1b8662 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -400,4 +400,5 @@ jobs: run: | python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest -n auto python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto + python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml index 058157b501..ab830e9c3c 100644 --- a/.github/workflows/build-wheel-macos-x86_64.yaml +++ b/.github/workflows/build-wheel-macos-x86_64.yaml @@ -370,4 +370,5 @@ jobs: python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto # TODO: Uncomment after fixing https://github.com/PennyLaneAI/pennylane-lightning/issues/552 # python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto + python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index fddfe8d32b..523f826689 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -21,6 +21,8 @@ jobs: constants: name: "Set build matrix" uses: ./.github/workflows/constants.yaml + with: + use_release_tag: ${{ inputs.catalyst == 'stable' }} check-config: name: Build Configuration diff --git a/.github/workflows/constants.yaml b/.github/workflows/constants.yaml index 18bc4be5b8..eec25cef37 100644 --- a/.github/workflows/constants.yaml +++ b/.github/workflows/constants.yaml @@ -7,6 +7,10 @@ on: required: false default: false type: boolean + use_release_tag: + required: false + default: false + type: boolean outputs: llvm_version: description: "LLVM version" @@ -45,6 +49,11 @@ jobs: steps: - name: Checkout Catalyst repo uses: actions/checkout@v3 + with: + fetch-depth: 0 + - if: ${{ inputs.use_release_tag }} + run: | + git checkout $(git tag | sort -V | tail -1) - 
name: LLVM version id: llvm_version diff --git a/doc/changelog.md b/doc/changelog.md index c9c958bbf2..61dec29de4 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -40,6 +40,15 @@ f(2, MyClass(5)) # no re-compilation ``` +* Catalyst now supports executing tapes on CUDA-Quantum simulators. + [(#477)](https://github.com/PennyLaneAI/catalyst/pull/477) + [(#536)](https://github.com/PennyLaneAI/catalyst/pull/536) + + The following devices are now available: + * softwareq.qpp + * nvidia.statevec (with multi-GPU support) + * nvidia.tensornet (with matrix product state support) +
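+  A minimal usage sketch (adapted from the tests added in this changeset; the
+  `softwareq.qpp` device name and the `catalyst.cuda.qjit` entry point are taken
+  from `frontend/test/pytest/test_cuda_integration.py`, while the circuit itself
+  is illustrative):
+
+  ```python
+  import jax.numpy as jnp
+  import pennylane as qml
+
+  from catalyst.cuda import qjit as cuda_qjit  # CUDA-Quantum variant of qjit
+
+  @cuda_qjit
+  @qml.qnode(qml.device("softwareq.qpp", wires=1))
+  def circuit(a):
+      qml.RX(a, wires=[0])
+      return qml.state()
+
+  circuit(jnp.pi / 4)
+  ```
+
+  The same circuit can also be compiled via `qml.qjit(compiler="cuda_quantum")`,
+  which resolves to the same entry point registered in `setup.py` by this change.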

Improvements

* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments @@ -304,6 +313,7 @@ * Handle run time exception in async qnodes. [(#447)](https://github.com/PennyLaneAI/catalyst/pull/447) + [(#510)](https://github.com/PennyLaneAI/catalyst/pull/510) This is done by: * changeing `llvm.call` to `llvm.invoke` diff --git a/frontend/catalyst/ag_primitives.py b/frontend/catalyst/ag_primitives.py index 7e2636484f..30bc5a0436 100644 --- a/frontend/catalyst/ag_primitives.py +++ b/frontend/catalyst/ag_primitives.py @@ -44,8 +44,8 @@ import catalyst from catalyst.ag_utils import AutoGraphError +from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray from catalyst.tracing.contexts import EvaluationContext -from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray from catalyst.utils.patching import Patcher __all__ = [ diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py index 4c0aca9ba7..ad1b9ca7e3 100644 --- a/frontend/catalyst/compiled_functions.py +++ b/frontend/catalyst/compiled_functions.py @@ -28,6 +28,7 @@ make_zero_d_memref_descriptor, ) +from catalyst.jax_extras import get_implicit_and_explicit_flat_args from catalyst.tracing.type_signatures import ( TypeCompatibility, filter_static_args, @@ -37,7 +38,6 @@ from catalyst.utils import wrapper # pylint: disable=no-name-in-module from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type from catalyst.utils.filesystem import Directory -from catalyst.utils.jax_extras import get_implicit_and_explicit_flat_args class SharedObjectManager: diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index fecae4a844..66d6067cb5 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -251,7 +251,7 @@ class LinkerDriver: _default_fallback_compilers = ["clang", "gcc", "c99", "c89", "cc"] @staticmethod - def get_default_flags(): + def get_default_flags(options): """Re-compute the path where the libraries exist. The use case for this is if someone is in a python jupyter notebook and @@ -312,6 +312,14 @@ def get_default_flags(): elif platform.system() == "Darwin": # pragma: nocover system_flags += ["-Wl,-arch_errors_fatal"] + # The exception handling mechanism requires linking against + # __gxx_personality_v0, which is provided by either -lstdc++ + # or -lc++. We choose based on the operating system.
+ if options.async_qnodes and platform.system() == "Linux": # pragma: nocover + system_flags += ["-lstdc++"] + elif options.async_qnodes and platform.system() == "Darwin": # pragma: nocover + system_flags += ["-lc++"] + default_flags = [ "-shared", "-rdynamic", @@ -395,12 +403,12 @@ def run(infile, outfile=None, flags=None, fallback_compilers=None, options=None) """ if outfile is None: outfile = LinkerDriver.get_output_filename(infile) + if options is None: + options = CompileOptions() if flags is None: - flags = LinkerDriver.get_default_flags() + flags = LinkerDriver.get_default_flags(options) if fallback_compilers is None: fallback_compilers = LinkerDriver._default_fallback_compilers - if options is None: - options = CompileOptions() for compiler in LinkerDriver._available_compilers(fallback_compilers): success = LinkerDriver._attempt_link(compiler, flags, infile, outfile, options) if success: diff --git a/frontend/catalyst/cuda/__init__.py b/frontend/catalyst/cuda/__init__.py index 9da6115665..866d07c449 100644 --- a/frontend/catalyst/cuda/__init__.py +++ b/frontend/catalyst/cuda/__init__.py @@ -41,8 +41,7 @@ def wrap_fn(fn): class BaseCudaInstructionSet(qml.QubitDevice): """Base instruction set for CUDA-Quantum devices""" - # TODO: Once 0.35 is released, remove -dev suffix. - pennylane_requires = "0.35.0-dev" + pennylane_requires = ">=0.34" version = "0.1.0" author = "Xanadu, Inc." @@ -68,14 +67,12 @@ class BaseCudaInstructionSet(qml.QubitDevice): "RY", "RZ", "SWAP", - # "CSWAP", This is a bug in cuda-quantum. CSWAP is not exposed. + "CSWAP", ] observables = [] config = Path(__file__).parent / "cuda_quantum.toml" - def __init__(self, shots=None, wires=None, mps=False, multi_gpu=False): - self.mps = mps - self.multi_gpu = multi_gpu + def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) def apply(self, operations, **kwargs): @@ -88,25 +85,41 @@ def apply(self, operations, **kwargs): class SoftwareQQPP(BaseCudaInstructionSet): """Concrete device class for qpp-cpu""" - name = "SoftwareQ q++ simulator" short_name = "softwareq.qpp" + @property + def name(self): + """Target name""" + return "qpp-cpu" + class NvidiaCuStateVec(BaseCudaInstructionSet): """Concrete device class for CuStateVec""" - name = "CuStateVec" short_name = "nvidia.custatevec" def __init__(self, shots=None, wires=None, multi_gpu=False): # pragma: no cover - super().__init__(wires=wires, shots=shots, multi_gpu=multi_gpu) + self.multi_gpu = multi_gpu + super().__init__(wires=wires, shots=shots) + + @property + def name(self): # pragma: no cover + """Target name""" + option = "-mgpu" if self.multi_gpu else "" + return f"nvidia{option}" class NvidiaCuTensorNet(BaseCudaInstructionSet): """Concrete device class for CuTensorNet""" - name = "CuTensorNet" short_name = "nvidia.cutensornet" def __init__(self, shots=None, wires=None, mps=False): # pragma: no cover - super().__init__(wires=wires, shots=shots, mps=mps) + self.mps = mps + super().__init__(wires=wires, shots=shots) + + @property + def name(self): # pragma: no cover + """Target name""" + option = "-mps" if self.mps else "" + return f"tensornet{option}" diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py index 81775e9513..5557eb5f21 100644 --- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py @@ -275,18 +275,11 @@ def change_device_to_cuda_device(ctx): device_name = qdevice_eqn.params.get("rtd_name") - 
# TODO(@erick-xanadu) as more devices become available - # map the names here. - # TODO(@erick-xanadu) why does the device instruction lists the whole - # name instead of a short name? - target_map = {"SoftwareQ q++ simulator": "qpp-cpu"} - target = target_map.get(device_name, device_name) - - if not target or not cudaq.has_target(target): - msg = f"Unavailable target {target}." # pragma: no cover + if not cudaq.has_target(device_name): + msg = f"Unavailable target {device_name}." # pragma: no cover raise ValueError(msg) - cudaq_target = cudaq.get_target(target) + cudaq_target = cudaq.get_target(device_name) cudaq.set_target(cudaq_target) # cudaq_make_kernel returns a multiple values depending on the arguments. @@ -411,7 +404,7 @@ def change_instruction(ctx, eqn): "RY": "ry", "RZ": "rz", "SWAP": "swap", - # "CSWAP": "cswap", Bug in CUDA quantum. CSWAP is not exposed. + "CSWAP": "cswap", # Other instructions that are missing: # ch # sdg @@ -847,12 +840,11 @@ def cudaq_backend_info(device): # We could also pass abstract arguments here in *args # the same way we do so in Catalyst. # But I think that is redundant now given make_jaxpr2 - _, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args) + jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {}) # TODO(@erick-xanadu): # What about static_args? - # We could return _out_type2 as well - return jaxpr, out_tree + return jaxpr, out_treedef def interpret(fun): diff --git a/frontend/catalyst/cuda/cuda_quantum.toml b/frontend/catalyst/cuda/cuda_quantum.toml index 23eeb28cef..5625ce21e2 100644 --- a/frontend/catalyst/cuda/cuda_quantum.toml +++ b/frontend/catalyst/cuda/cuda_quantum.toml @@ -32,7 +32,7 @@ native = [ "RY", "RZ", "SWAP", - # "CSWAP", # Not exposed in CUDA quantum. + "CSWAP", ] # Operators that should be decomposed according to the algorithm used diff --git a/frontend/catalyst/cuda/primitives/__init__.py b/frontend/catalyst/cuda/primitives/__init__.py index e4d4627a42..43f6cb0af4 100644 --- a/frontend/catalyst/cuda/primitives/__init__.py +++ b/frontend/catalyst/cuda/primitives/__init__.py @@ -268,7 +268,7 @@ def cudaq_getstate(kernel): @cudaq_getstate_p.def_impl def cudaq_getstate_primitive_impl(kernel): """Concrete implementation.""" - return cudaq.get_state(kernel) + return jax.numpy.array(cudaq.get_state(kernel)) @cudaq_getstate_p.def_abstract_eval @@ -394,9 +394,17 @@ def cudaq_sample_impl(kernel, *args, shots_count=1000): So, let's perform a little conversion here. """ a_dict = cudaq.sample(kernel, *args, shots_count=shots_count) - lls = [[k] * v for k, v in a_dict.items()] - # Weirdly enough Catalyst returns this transposed. - return jax.numpy.atleast_2d(jax.numpy.array([int(l) for ls in lls for l in ls])).T + aggregate = [] + for bitstring, count in a_dict.items(): + # It is technically a bit array + # So we should use int(bit) + # But in Catalyst, these are floats. + # So we use floats. + bitarray = [float(bit) for bit in bitstring] + for _ in range(count): + aggregate.append(bitarray) + + return jax.numpy.array(aggregate) @cudaq_sample_p.def_abstract_eval diff --git a/frontend/catalyst/jax_extras/__init__.py b/frontend/catalyst/jax_extras/__init__.py new file mode 100644 index 0000000000..f3abfe5510 --- /dev/null +++ b/frontend/catalyst/jax_extras/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Catalyst additions to the Jax library """ + +from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir +from catalyst.jax_extras.patches import ( + _gather_shape_rule_dynamic, + _no_clean_up_dead_vars, + get_aval2, +) +from catalyst.jax_extras.tracing import ( + ClosedJaxpr, + DynshapedJaxpr, + DynamicJaxprTrace, + DynamicJaxprTracer, + Jaxpr, + PyTreeDef, + PyTreeRegistry, + ShapedArray, + ShapeDtypeStruct, + _abstractify, + _extract_implicit_args, + _initial_style_jaxpr, + _input_type_to_tracers, + convert_constvars_jaxpr, + convert_element_type, + deduce_avals, + eval_jaxpr, + get_implicit_and_explicit_flat_args, + infer_lambda_input_type, + initial_style_jaxprs_with_common_consts1, + initial_style_jaxprs_with_common_consts2, + make_jaxpr2, + make_jaxpr_effects, + new_dynamic_main2, + new_inner_tracer, + sort_eqns, + transient_jax_config, + tree_flatten, + tree_structure, + tree_unflatten, + treedef_is_leaf, + unzip2, + wrap_init, +) diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py new file mode 100644 index 0000000000..19ba1e3366 --- /dev/null +++ b/frontend/catalyst/jax_extras/lowering.py @@ -0,0 +1,153 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Jax extras module containing functions related to the StableHLO lowering """ + +from __future__ import annotations + +import jax +from jax._src.dispatch import jaxpr_replicas +from jax._src.effects import ordered_effects as jax_ordered_effects +from jax._src.interpreters.mlir import _module_name_regex +from jax._src.lax.lax import xla +from jax._src.sharding_impls import ReplicaAxisContext +from jax._src.source_info_util import new_name_stack +from jax._src.util import wrap_name +from jax.core import ClosedJaxpr +from jax.interpreters.mlir import ( + AxisContext, + LoweringParameters, + ModuleContext, + ir, + lower_jaxpr_to_fun, + lowerable_effects, +) + +from catalyst.utils.patching import Patcher + +# pylint: disable=protected-access + +__all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module") + +from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2 + + +def jaxpr_to_mlir(func_name, jaxpr): + """Lower a Jaxpr into an MLIR module. 
+ + Args: + func_name(str): function name + jaxpr(Jaxpr): Jaxpr code to lower + + Returns: + module: the MLIR module corresponding to ``func`` + context: the MLIR context corresponding + """ + + with Patcher( + (jax._src.interpreters.partial_eval, "get_aval", get_aval2), + (jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars), + ): + nrep = jaxpr_replicas(jaxpr) + effects = jax_ordered_effects.filter_in(jaxpr.effects) + axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ())) + name_stack = new_name_stack(wrap_name("ok", "jit")) + module, context = custom_lower_jaxpr_to_module( + func_name="jit_" + func_name, + module_name=func_name, + jaxpr=jaxpr, + effects=effects, + platform="cpu", + axis_context=axis_context, + name_stack=name_stack, + ) + + return module, context + + +# pylint: disable=too-many-arguments +def custom_lower_jaxpr_to_module( + func_name: str, + module_name: str, + jaxpr: ClosedJaxpr, + effects, + platform: str, + axis_context: AxisContext, + name_stack, + replicated_args=None, + arg_shardings=None, + result_shardings=None, +): + """Lowers a top-level jaxpr to an MHLO module. + + Handles the quirks of the argument/return value passing conventions of the + runtime. + + This function has been modified from its original form in the JAX project at + https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version + released under the Apache License, Version 2.0, with the following copyright notice: + + Copyright 2021 The JAX Authors. + """ + + if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover + raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}") + + assert platform == "cpu" + assert arg_shardings is None + assert result_shardings is None + + # MHLO channels need to start at 1 + channel_iter = 1 + # Create a keepalives list that will be mutated during the lowering. + keepalives = [] + host_callbacks = [] + lowering_params = LoweringParameters() + ctx = ModuleContext( + backend_or_name=None, + platforms=[platform], + axis_context=axis_context, + name_stack=name_stack, + keepalives=keepalives, + channel_iterator=channel_iter, + host_callbacks=host_callbacks, + lowering_parameters=lowering_params, + ) + ctx.context.allow_unregistered_dialects = True + with ctx.context, ir.Location.unknown(ctx.context): + # register_dialect() + # Remove module name characters that XLA would alter. This ensures that + # XLA computation preserves the module name. + module_name = _module_name_regex.sub("_", module_name) + ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) + lower_jaxpr_to_fun( + ctx, + func_name, + jaxpr, + effects, + public=True, + create_tokens=True, + replace_tokens_with_dummy=True, + replicated_args=replicated_args, + arg_shardings=arg_shardings, + result_shardings=result_shardings, + ) + + for op in ctx.module.body.operations: + func_name = str(op.name) + is_entry_point = func_name.startswith('"jit_') + if is_entry_point: + continue + op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") + + return ctx.module, ctx.context diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py new file mode 100644 index 0000000000..266f38b33c --- /dev/null +++ b/frontend/catalyst/jax_extras/patches.py @@ -0,0 +1,182 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Jax extras module containing Jax patches """ + +# pylint: disable=too-many-arguments + +from __future__ import annotations + +import jax +from jax._src.lax.slicing import ( + _gather_shape_computation, + _is_sorted, + _no_duplicate_dims, + _rank, + _sorted_dims_in_range, +) +from jax.core import AbstractValue, Tracer, concrete_aval + +__all__ = ( + "get_aval2", + "_no_clean_up_dead_vars", + "_gather_shape_rule_dynamic", +) + + +def get_aval2(x): + """An extended version of `jax.core.get_aval` which also accepts AbstractValues.""" + # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged + if isinstance(x, AbstractValue): + return x + elif isinstance(x, Tracer): + return x.aval + else: + return concrete_aval(x) + + +def _no_clean_up_dead_vars(_eqn, _env, _last_used): + """A stub to workaround the Jax ``KeyError 'a'`` bug during the lowering of Jaxpr programs to + MLIR with the dynamic API enabled.""" + return None + + +def _gather_shape_rule_dynamic( + operand, + indices, + *, + dimension_numbers, + slice_sizes, + unique_indices, + indices_are_sorted, + mode, + fill_value, +): # pragma: no cover + """Validates the well-formedness of the arguments to Gather. Compared to the original version, + this implementation skips static shape checks if variable dimensions are used. + + This function has been modified from its original form in the JAX project at + https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438 + version released under the Apache License, Version 2.0, with the following copyright notice: + + Copyright 2021 The JAX Authors. + TODO(@grwlf): delete once PR https://github.com/google/jax/pull/19083 has been merged + """ + # pylint: disable=unused-argument + # pylint: disable=too-many-branches + # pylint: disable=consider-using-enumerate + # pylint: disable=chained-comparison + offset_dims = dimension_numbers.offset_dims + collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + start_index_map = dimension_numbers.start_index_map + + # Note: in JAX, index_vector_dim is always computed as below, cf. the + # documentation of the GatherDimensionNumbers class. + index_vector_dim = _rank(indices) - 1 + + # This case should never happen in JAX, due to the implicit construction of + # index_vector_dim, but is included for completeness. + if _rank(indices) < index_vector_dim or index_vector_dim < 0: + raise TypeError( + f"Gather index leaf dimension must be within [0, rank(" + f"indices) + 1). rank(indices) is {_rank(indices)} and " + f"gather index leaf dimension is {index_vector_dim}." + ) + + # Start ValidateGatherDimensions + # In the error messages output by XLA, "offset_dims" is called "Output window + # dimensions" in error messages. For consistency's sake, our error messages + # stick to "offset_dims". 
+ _is_sorted(offset_dims, "gather", "offset_dims") + _no_duplicate_dims(offset_dims, "gather", "offset_dims") + + output_offset_dim_count = len(offset_dims) + output_shape_rank = len(offset_dims) + _rank(indices) - 1 + + for i in range(output_offset_dim_count): + offset_dim = offset_dims[i] + if offset_dim < 0 or offset_dim >= output_shape_rank: + raise TypeError( + f"Offset dimension {i} in gather op is out of bounds; " + f"got {offset_dim}, but should have been in " + f"[0, {output_shape_rank})" + ) + + if len(start_index_map) != indices.shape[index_vector_dim]: + raise TypeError( + f"Gather op has {len(start_index_map)} elements in " + f"start_index_map and the bound of dimension " + f"{index_vector_dim=} of indices is " + f"{indices.shape[index_vector_dim]}. These two " + f"numbers must be equal." + ) + + for i in range(len(start_index_map)): + operand_dim_for_start_index_i = start_index_map[i] + if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand): + raise TypeError( + f"Invalid start_index_map; domain is " + f"[0, {_rank(operand)}), got: " + f"{i}->{operand_dim_for_start_index_i}." + ) + + _no_duplicate_dims(start_index_map, "gather", "start_index_map") + + # _is_sorted and _sorted_dims_in_range are checked in the opposite order + # compared to the XLA implementation. In cases when the input is not sorted + # AND there are problematic collapsed_slice_dims, the error message will thus + # be different. + _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") + _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") + _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + # End ValidateGatherDimensions + + if _rank(operand) != len(slice_sizes): + raise TypeError( + f"Gather op must have one slice size for every input " + f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " + f"input_shape.rank={_rank(operand)}" + ) + + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): + raise TypeError( + f"All components of the offset index in a gather op must " + f"either be a offset dimension or explicitly collapsed; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}." + ) + + # This section contains a patch suggested to the upstream. + for i in range(len(slice_sizes)): + slice_size = slice_sizes[i] + corresponding_input_size = operand.shape[i] + + if jax.core.is_constant_dim(corresponding_input_size): + if not (slice_size >= 0 and corresponding_input_size >= slice_size): + raise TypeError( + f"Slice size at index {i} in gather op is out of range, " + f"must be within [0, {corresponding_input_size} + 1), " + f"got {slice_size}." + ) + + for i in range(len(collapsed_slice_dims)): + bound = slice_sizes[collapsed_slice_dims[i]] + if bound != 1: + raise TypeError( + f"Gather op can only collapse slice dims with bound 1, " + f"but bound is {bound} for index " + f"{collapsed_slice_dims[i]} at position {i}." + ) + + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/jax_extras/tracing.py similarity index 57% rename from frontend/catalyst/utils/jax_extras.py rename to frontend/catalyst/jax_extras/tracing.py index c3ed6fab33..2fb510ad14 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -1,4 +1,4 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. 
+# Copyright 2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This module isolates utility functions that depend on JAX low-level internals -""" +""" Jax extras module containing functions related to the Python program tracing """ + from __future__ import annotations from contextlib import ExitStack, contextmanager @@ -22,8 +22,6 @@ from jax import ShapeDtypeStruct from jax._src import state, util from jax._src.core import _update_thread_local_jit_state -from jax._src.dispatch import jaxpr_replicas -from jax._src.effects import ordered_effects as jax_ordered_effects from jax._src.interpreters.mlir import _module_name_regex, register_lowering from jax._src.interpreters.partial_eval import ( _input_type_to_tracers, @@ -31,44 +29,21 @@ trace_to_jaxpr_dynamic2, ) from jax._src.lax.control_flow import _initial_style_jaxpr, _initial_style_open_jaxpr -from jax._src.lax.lax import _abstractify, xla +from jax._src.lax.lax import _abstractify from jax._src.lax.slicing import ( _argnum_weak_type, _gather_dtype_rule, _gather_lower, - _gather_shape_computation, - _is_sorted, - _no_duplicate_dims, - _rank, - _sorted_dims_in_range, standard_primitive, ) from jax._src.linear_util import annotate from jax._src.pjit import _extract_implicit_args, _flat_axes_specs -from jax._src.sharding_impls import ReplicaAxisContext from jax._src.source_info_util import current as jax_current -from jax._src.source_info_util import new_name_stack -from jax._src.util import partition_list, safe_map, unzip2, unzip3, wrap_name, wraps +from jax._src.util import partition_list, safe_map, unzip2, unzip3, wraps from jax.api_util import flatten_fun -from jax.core import AbstractValue, ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType +from jax.core import ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType from jax.core import Primitive as JaxprPrimitive -from jax.core import ( - ShapedArray, - Trace, - Tracer, - concrete_aval, - eval_jaxpr, - gensym, - thread_local_state, -) -from jax.interpreters.mlir import ( - AxisContext, - LoweringParameters, - ModuleContext, - ir, - lower_jaxpr_to_fun, - lowerable_effects, -) +from jax.core import ShapedArray, Trace, eval_jaxpr, gensym, thread_local_state from jax.interpreters.partial_eval import ( DynamicJaxprTrace, DynamicJaxprTracer, @@ -76,7 +51,7 @@ make_jaxpr_effects, ) from jax.lax import convert_element_type -from jax.linear_util import wrap_init +from jax.extend.linear_util import wrap_init from jax.tree_util import ( PyTreeDef, tree_flatten, @@ -86,6 +61,7 @@ ) from jaxlib.xla_extension import PyTreeRegistry +from catalyst.jax_extras.patches import _gather_shape_rule_dynamic, get_aval2 from catalyst.utils.patching import Patcher # pylint: disable=protected-access @@ -108,7 +84,7 @@ "_abstractify", "_initial_style_jaxpr", "_input_type_to_tracers", - "jaxpr_to_mlir", + "_module_name_regex", "make_jaxpr_effects", "make_jaxpr2", "new_dynamic_main2", @@ -347,6 +323,7 @@ def deduce_avals(f: Callable, args, kwargs): """Wraps the callable ``f`` into a WrappedFun container accepting collapsed flatten arguments and returning expanded flatten results. Calculate input abstract values and output_tree promise. 
The promise must be called after the resulting wrapped function is evaluated.""" + # TODO: deprecate in favor of `deduce_signatures` flat_args, in_tree = tree_flatten((args, kwargs)) abstracted_axes = None axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) @@ -358,134 +335,6 @@ def deduce_avals(f: Callable, args, kwargs): return wffa, in_avals, keep_inputs, out_tree_promise -def get_aval2(x): - """An extended version of `jax.core.get_aval` which also accepts AbstractValues.""" - # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged - if isinstance(x, AbstractValue): - return x - elif isinstance(x, Tracer): - return x.aval - else: - return concrete_aval(x) - - -def _no_clean_up_dead_vars(_eqn, _env, _last_used): - """A stub to workaround the Jax ``KeyError 'a'`` bug during the lowering of Jaxpr programs to - MLIR with the dynamic API enabled.""" - return None - - -def jaxpr_to_mlir(func_name, jaxpr): - """Lower a Jaxpr into an MLIR module. - - Args: - func_name(str): function name - jaxpr(Jaxpr): Jaxpr code to lower - - Returns: - module: the MLIR module corresponding to ``func`` - context: the MLIR context corresponding - """ - - with Patcher( - (jax._src.interpreters.partial_eval, "get_aval", get_aval2), - (jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars), - ): - nrep = jaxpr_replicas(jaxpr) - effects = jax_ordered_effects.filter_in(jaxpr.effects) - axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ())) - name_stack = new_name_stack(wrap_name("ok", "jit")) - module, context = custom_lower_jaxpr_to_module( - func_name="jit_" + func_name, - module_name=func_name, - jaxpr=jaxpr, - effects=effects, - platform="cpu", - axis_context=axis_context, - name_stack=name_stack, - ) - - return module, context - - -# pylint: disable=too-many-arguments -def custom_lower_jaxpr_to_module( - func_name: str, - module_name: str, - jaxpr: ClosedJaxpr, - effects, - platform: str, - axis_context: AxisContext, - name_stack, - replicated_args=None, - arg_shardings=None, - result_shardings=None, -): - """Lowers a top-level jaxpr to an MHLO module. - - Handles the quirks of the argument/return value passing conventions of the - runtime. - - This function has been modified from its original form in the JAX project at - https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version - released under the Apache License, Version 2.0, with the following copyright notice: - - Copyright 2021 The JAX Authors. - """ - - if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover - raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}") - - assert platform == "cpu" - assert arg_shardings is None - assert result_shardings is None - - # MHLO channels need to start at 1 - channel_iter = 1 - # Create a keepalives list that will be mutated during the lowering. - keepalives = [] - host_callbacks = [] - lowering_params = LoweringParameters() - ctx = ModuleContext( - backend_or_name=None, - platforms=[platform], - axis_context=axis_context, - name_stack=name_stack, - keepalives=keepalives, - channel_iterator=channel_iter, - host_callbacks=host_callbacks, - lowering_parameters=lowering_params, - ) - ctx.context.allow_unregistered_dialects = True - with ctx.context, ir.Location.unknown(ctx.context): - # register_dialect() - # Remove module name characters that XLA would alter. This ensures that - # XLA computation preserves the module name. 
- module_name = _module_name_regex.sub("_", module_name) - ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) - lower_jaxpr_to_fun( - ctx, - func_name, - jaxpr, - effects, - public=True, - create_tokens=True, - replace_tokens_with_dummy=True, - replicated_args=replicated_args, - arg_shardings=arg_shardings, - result_shardings=result_shardings, - ) - - for op in ctx.module.body.operations: - func_name = str(op.name) - is_entry_point = func_name.startswith('"jit_') - if is_entry_point: - continue - op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") - - return ctx.module, ctx.context - - def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer: """Create a JAX tracer tracing an abstract value ``aval`, without specifying its source primitive.""" @@ -549,133 +398,3 @@ def make_jaxpr_f(*args, **kwargs): make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})" return make_jaxpr_f - - -def _gather_shape_rule_dynamic( - operand, - indices, - *, - dimension_numbers, - slice_sizes, - unique_indices, - indices_are_sorted, - mode, - fill_value, -): # pragma: no cover - """Validates the well-formedness of the arguments to Gather. Compared to the original version, - this implementation skips static shape checks if variable dimensions are used. - - This function has been modified from its original form in the JAX project at - https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438 - version released under the Apache License, Version 2.0, with the following copyright notice: - - Copyright 2021 The JAX Authors. - """ - # pylint: disable=unused-argument - # pylint: disable=too-many-branches - # pylint: disable=consider-using-enumerate - # pylint: disable=chained-comparison - offset_dims = dimension_numbers.offset_dims - collapsed_slice_dims = dimension_numbers.collapsed_slice_dims - start_index_map = dimension_numbers.start_index_map - - # Note: in JAX, index_vector_dim is always computed as below, cf. the - # documentation of the GatherDimensionNumbers class. - index_vector_dim = _rank(indices) - 1 - - # This case should never happen in JAX, due to the implicit construction of - # index_vector_dim, but is included for completeness. - if _rank(indices) < index_vector_dim or index_vector_dim < 0: - raise TypeError( - f"Gather index leaf dimension must be within [0, rank(" - f"indices) + 1). rank(indices) is {_rank(indices)} and " - f"gather index leaf dimension is {index_vector_dim}." - ) - - # Start ValidateGatherDimensions - # In the error messages output by XLA, "offset_dims" is called "Output window - # dimensions" in error messages. For consistency's sake, our error messages - # stick to "offset_dims". - _is_sorted(offset_dims, "gather", "offset_dims") - _no_duplicate_dims(offset_dims, "gather", "offset_dims") - - output_offset_dim_count = len(offset_dims) - output_shape_rank = len(offset_dims) + _rank(indices) - 1 - - for i in range(output_offset_dim_count): - offset_dim = offset_dims[i] - if offset_dim < 0 or offset_dim >= output_shape_rank: - raise TypeError( - f"Offset dimension {i} in gather op is out of bounds; " - f"got {offset_dim}, but should have been in " - f"[0, {output_shape_rank})" - ) - - if len(start_index_map) != indices.shape[index_vector_dim]: - raise TypeError( - f"Gather op has {len(start_index_map)} elements in " - f"start_index_map and the bound of dimension " - f"{index_vector_dim=} of indices is " - f"{indices.shape[index_vector_dim]}. 
These two " - f"numbers must be equal." - ) - - for i in range(len(start_index_map)): - operand_dim_for_start_index_i = start_index_map[i] - if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand): - raise TypeError( - f"Invalid start_index_map; domain is " - f"[0, {_rank(operand)}), got: " - f"{i}->{operand_dim_for_start_index_i}." - ) - - _no_duplicate_dims(start_index_map, "gather", "start_index_map") - - # _is_sorted and _sorted_dims_in_range are checked in the opposite order - # compared to the XLA implementation. In cases when the input is not sorted - # AND there are problematic collapsed_slice_dims, the error message will thus - # be different. - _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") - _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") - _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") - # End ValidateGatherDimensions - - if _rank(operand) != len(slice_sizes): - raise TypeError( - f"Gather op must have one slice size for every input " - f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " - f"input_shape.rank={_rank(operand)}" - ) - - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): - raise TypeError( - f"All components of the offset index in a gather op must " - f"either be a offset dimension or explicitly collapsed; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}." - ) - - # This section contains a patch suggested to the upstream. - for i in range(len(slice_sizes)): - slice_size = slice_sizes[i] - corresponding_input_size = operand.shape[i] - - if jax.core.is_constant_dim(corresponding_input_size): - if not (slice_size >= 0 and corresponding_input_size >= slice_size): - raise TypeError( - f"Slice size at index {i} in gather op is out of range, " - f"must be within [0, {corresponding_input_size} + 1), " - f"got {slice_size}." - ) - - for i in range(len(collapsed_slice_dims)): - bound = slice_sizes[collapsed_slice_dims[i]] - if bound != 1: - raise TypeError( - f"Gather op can only collapse slice dims with bound 1, " - f"but bound is {bound} for index " - f"{collapsed_slice_dims[i]} at position {i}." 
- ) - - return _gather_shape_computation(indices, dimension_numbers, slice_sizes) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 4748a46a15..cd469a38f7 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -27,6 +27,28 @@ from pennylane.tape import QuantumTape import catalyst +from catalyst.jax_extras import ( + ClosedJaxpr, + DynshapedJaxpr, + DynamicJaxprTrace, + DynamicJaxprTracer, + PyTreeDef, + PyTreeRegistry, + ShapedArray, + _abstractify, + _input_type_to_tracers, + convert_element_type, + deduce_avals, + eval_jaxpr, + jaxpr_to_mlir, + make_jaxpr2, + sort_eqns, + transient_jax_config, + tree_flatten, + tree_structure, + tree_unflatten, + wrap_init, +) from catalyst.jax_primitives import ( AbstractQreg, compbasis_p, @@ -58,28 +80,6 @@ JaxTracingContext, ) from catalyst.utils.exceptions import CompileError -from catalyst.utils.jax_extras import ( - ClosedJaxpr, - DynamicJaxprTrace, - DynamicJaxprTracer, - DynshapedJaxpr, - PyTreeDef, - PyTreeRegistry, - ShapedArray, - _abstractify, - _input_type_to_tracers, - convert_element_type, - deduce_avals, - eval_jaxpr, - jaxpr_to_mlir, - make_jaxpr2, - sort_eqns, - transient_jax_config, - tree_flatten, - tree_structure, - tree_unflatten, - wrap_init, -) class Function: diff --git a/frontend/catalyst/pennylane_extensions.py b/frontend/catalyst/pennylane_extensions.py index 176c747775..e141439359 100644 --- a/frontend/catalyst/pennylane_extensions.py +++ b/frontend/catalyst/pennylane_extensions.py @@ -44,6 +44,21 @@ from pennylane.tape import QuantumTape import catalyst +from catalyst.jax_extras import ( # infer_output_type3, + ClosedJaxpr, + DynamicJaxprTracer, + Jaxpr, + ShapedArray, + _initial_style_jaxpr, + _input_type_to_tracers, + convert_constvars_jaxpr, + deduce_avals, + get_implicit_and_explicit_flat_args, + initial_style_jaxprs_with_common_consts1, + initial_style_jaxprs_with_common_consts2, + new_inner_tracer, + unzip2, +) from catalyst.jax_primitives import ( AbstractQreg, GradParams, @@ -65,7 +80,6 @@ HybridOp, HybridOpRegion, QRegPromise, - deduce_avals, has_nested_tapes, trace_quantum_function, trace_quantum_tape, @@ -78,20 +92,6 @@ JaxTracingContext, ) from catalyst.utils.exceptions import DifferentiableCompileError -from catalyst.utils.jax_extras import ( - ClosedJaxpr, - DynamicJaxprTracer, - Jaxpr, - ShapedArray, - _initial_style_jaxpr, - _input_type_to_tracers, - convert_constvars_jaxpr, - get_implicit_and_explicit_flat_args, - initial_style_jaxprs_with_common_consts1, - initial_style_jaxprs_with_common_consts2, - new_inner_tracer, - unzip2, -) from catalyst.utils.runtime import extract_backend_info, get_lib_path diff --git a/frontend/catalyst/tracing/contexts.py b/frontend/catalyst/tracing/contexts.py index 7f388a82fe..749c8b0f10 100644 --- a/frontend/catalyst/tracing/contexts.py +++ b/frontend/catalyst/tracing/contexts.py @@ -32,8 +32,8 @@ from jax.core import find_top_trace from pennylane.queuing import QueuingManager +from catalyst.jax_extras import new_dynamic_main2 from catalyst.utils.exceptions import CompileError -from catalyst.utils.jax_extras import new_dynamic_main2 class EvaluationMode(Enum): diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 3d730fe78e..0396a95592 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -27,7 +27,7 @@ from jax.api_util import shaped_abstractify from jax.tree_util import tree_flatten, tree_unflatten 
-from catalyst.utils.jax_extras import get_aval2 +from catalyst.jax_extras import get_aval2 from catalyst.utils.patching import Patcher diff --git a/frontend/test/pytest/test_cuda_integration.py b/frontend/test/pytest/test_cuda_integration.py index 2d12a6b89e..910eb21ab6 100644 --- a/frontend/test/pytest/test_cuda_integration.py +++ b/frontend/test/pytest/test_cuda_integration.py @@ -27,6 +27,7 @@ # when we are running kokkos. Importing CUDA before running any kokkos # kernel polutes the environment and will create a segfault. # pylint: disable=import-outside-toplevel +# pylint: disable=too-many-public-methods @pytest.mark.cuda @@ -49,13 +50,12 @@ def circuit_foo(): def test_qjit_cuda_remove_host_context(self): """Test removing the host context.""" - from catalyst.cuda import SoftwareQQPP from catalyst.cuda.catalyst_to_cuda_interpreter import ( QJIT_CUDAQ, remove_host_context, ) - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit_foo(): return qml.state() @@ -65,10 +65,9 @@ def circuit_foo(): def test_qjit_catalyst_to_cuda_jaxpr(self): """Assert that catalyst_to_cuda returns something.""" - from catalyst.cuda import SoftwareQQPP from catalyst.cuda.catalyst_to_cuda_interpreter import interpret - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit_foo(): return qml.state() @@ -78,12 +77,11 @@ def circuit_foo(): def test_measurement_return(self): """Test the measurement code is added.""" - from catalyst.cuda import SoftwareQQPP from catalyst.cuda.catalyst_to_cuda_interpreter import interpret with pytest.raises(NotImplementedError, match="cannot return measurements directly"): - @qml.qnode(SoftwareQQPP(wires=1, shots=30)) + @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=30)) def circuit(): qml.RX(jnp.pi / 4, wires=[0]) return measure(0) @@ -93,10 +91,9 @@ def circuit(): def test_measurement_side_effect(self): """Test the measurement code is added.""" - from catalyst.cuda import SoftwareQQPP from catalyst.cuda.catalyst_to_cuda_interpreter import interpret - @qml.qnode(SoftwareQQPP(wires=1, shots=30)) + @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=30)) def circuit(): qml.RX(jnp.pi / 4, wires=[0]) measure(0) @@ -106,9 +103,8 @@ def circuit(): def test_pytrees(self): """Test that we can return a dictionary.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit_a(a): qml.RX(a, wires=[0]) return {"a": qml.state()} @@ -126,9 +122,8 @@ def circuit_b(a): def test_cuda_device(self): """Test SoftwareQQPP.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a): qml.RX(a, wires=[0]) return qml.state() @@ -146,9 +141,8 @@ def circuit_lightning(a): def test_samples(self): """Test SoftwareQQPP.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1, shots=100)) + @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=100)) def circuit(a): qml.RX(a, wires=[0]) return qml.sample() @@ -166,9 +160,8 @@ def circuit_lightning(a): def test_counts(self): """Test SoftwareQQPP.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1, shots=100)) + @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=100)) def circuit(a): qml.RX(a, wires=[0]) return qml.counts() @@ -186,9 +179,8 @@ def circuit_lightning(a): def test_qjit_cuda_device(self): """Test SoftwareQQPP.""" - from catalyst.cuda import 
SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a): qml.RX(a, wires=[0]) return qml.state() @@ -206,9 +198,8 @@ def circuit_lightning(a): def test_abstract_variable(self): """Test abstract variable.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a: float): qml.RX(a, wires=[0]) return qml.state() @@ -226,9 +217,8 @@ def circuit_lightning(a): def test_arithmetic(self): """Test arithmetic.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a): qml.RX(a / 2, wires=[0]) return qml.state() @@ -246,9 +236,8 @@ def circuit_lightning(a): def test_multiple_values(self): """Test multiple_values.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(params): x, y = jax.numpy.array_split(params, 2) qml.RX(x[0], wires=[0]) @@ -272,7 +261,7 @@ def circuit_lightning(params): def test_cuda_device_entry_point(self): """Test the entry point for SoftwareQQPP""" - @qml.qnode(qml.device("software.qpp", wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a): qml.RX(a, wires=[0]) return {"a": qml.state()} @@ -293,7 +282,7 @@ def test_cuda_device_entry_point_compiler(self): """Test the entry point for cudaq.qjit""" @qml.qjit(compiler="cuda_quantum") - @qml.qnode(qml.device("cudaq", wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(a): qml.RX(a, wires=[0]) return {"a": qml.state()} @@ -302,9 +291,8 @@ def circuit(a): def test_expval(self): """Test multiple_values.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=1)) + @qml.qnode(qml.device("softwareq.qpp", wires=1)) def circuit(): qml.RX(jnp.pi / 2, wires=[0]) return qml.expval(qml.PauliZ(0)) @@ -322,9 +310,8 @@ def circuit_catalyst(): def test_expval_2(self): """Test multiple_values.""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(): qml.RY(jnp.pi / 4, wires=[1]) return qml.expval(qml.PauliZ(1) + qml.PauliX(1)) @@ -343,9 +330,7 @@ def circuit_catalyst(): def test_adjoint(self): """Test adjoint.""" - from catalyst.cuda import SoftwareQQPP - - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(): def f(theta): qml.RX(theta / 23, wires=[0]) @@ -380,9 +365,7 @@ def f(theta): def test_control_ry(self): """Test control ry.""" - from catalyst.cuda import SoftwareQQPP - - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(): qml.Hadamard(wires=[0]) qml.CRY(jnp.pi / 2, wires=[0, 1]) @@ -403,9 +386,7 @@ def circuit_catalyst(): def test_swap(self): """Test swap.""" - from catalyst.cuda import SoftwareQQPP - - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(): qml.RX(jnp.pi / 3, wires=[0]) qml.SWAP(wires=[0, 1]) @@ -426,9 +407,7 @@ def circuit_catalyst(): def test_entanglement(self): """Test swap.""" - from catalyst.cuda import SoftwareQQPP - - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(): qml.Hadamard(wires=[0]) qml.CNOT(wires=[0, 1]) @@ -446,11 +425,47 @@ def circuit_catalyst(): expected = catalyst_compiled() assert_allclose(expected, observed) + def test_cswap(self): + """Test 
cswap.""" + + @qml.qnode(qml.device("softwareq.qpp", wires=3)) + def circuit(): + qml.Hadamard(wires=[0]) + qml.RX(jnp.pi / 7, wires=[1]) + qml.CSWAP(wires=[0, 1, 2]) + return qml.state() + + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit_catalyst(): + qml.Hadamard(wires=[0]) + qml.RX(jnp.pi / 7, wires=[1]) + qml.CSWAP(wires=[0, 1, 2]) + return qml.state() + + cuda_compiled = catalyst.cuda.qjit(circuit) + observed = cuda_compiled() + catalyst_compiled = qjit(circuit_catalyst) + expected = catalyst_compiled() + assert_allclose(expected, observed) + + def test_state_is_jax_array(self): + """Test return type for state.""" + + @qml.qnode(qml.device("softwareq.qpp", wires=3)) + def circuit(): + qml.Hadamard(wires=[0]) + qml.RX(jnp.pi / 7, wires=[1]) + qml.CSWAP(wires=[0, 1, 2]) + return qml.state() + + cuda_compiled = catalyst.cuda.qjit(circuit) + observed = cuda_compiled() + assert isinstance(observed, jax.Array) + def test_error_message_using_host_context(self): """Test error message""" - from catalyst.cuda import SoftwareQQPP - @qml.qnode(SoftwareQQPP(wires=2)) + @qml.qnode(qml.device("softwareq.qpp", wires=2)) def circuit(x): qml.Hadamard(wires=[0]) qml.CNOT(wires=[0, 1]) @@ -464,6 +479,28 @@ def wrapper(y): with pytest.raises(CompileError, match="Cannot translate tapes with context"): catalyst.cuda.qjit(wrapper)(1.0) + def test_samples(self): + """Samples with more than one wire.""" + + from catalyst.cuda import qjit as cjit + + @qjit + @qml.qnode(qml.device("lightning.qubit", wires=2, shots=10)) + def circuit1(a): + qml.RX(a, wires=0) + return qml.sample() + + expected = circuit1(3.14) + + @cjit + @qml.qnode(qml.device("softwareq.qpp", wires=2, shots=10)) + def circuit2(a): + qml.RX(a, wires=0) + return qml.sample() + + observed = circuit2(3.14) + assert_allclose(expected, observed) + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_jax_config.py b/frontend/test/pytest/test_jax_config.py index 0ae72c5a46..498367521c 100644 --- a/frontend/test/pytest/test_jax_config.py +++ b/frontend/test/pytest/test_jax_config.py @@ -16,7 +16,7 @@ import jax -from catalyst.utils.jax_extras import transient_jax_config +from catalyst.jax_extras import transient_jax_config def test_transient_jax_config(): diff --git a/setup.py b/setup.py index c93d92198e..4f9f35162d 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ import platform import subprocess from distutils import sysconfig -from os import environ, path +from os import path import numpy as np from pybind11.setup_helpers import intree_extensions @@ -33,43 +33,35 @@ version = f.readlines()[-1].split()[-1].strip("\"'") with open(".dep-versions") as f: - jax_version = [line[4:].strip() for line in f.readlines() if "jax=" in line][0] + lines = f.readlines() + jax_version = [line[4:].strip() for line in lines if "jax=" in line][0] + pl_str = "pennylane=" + pl_str_length = len(pl_str) + pl_version = [line[pl_str_length:].strip() for line in lines if pl_str in line][0] -pl_version = environ.get("PL_VERSION", ">=0.32,<=0.34") requirements = [ - f"pennylane{pl_version}", + f"pennylane @ git+https://github.com/pennylaneai/pennylane@{pl_version}", f"jax=={jax_version}", f"jaxlib=={jax_version}", "tomlkit;python_version<'3.11'", "scipy", ] -# TODO: Once PL version 0.35 is released: -# * remove this special handling -# * make pennylane>=0.35 a requirement -# * Close this ticket https://github.com/PennyLaneAI/catalyst/issues/494 -one_compiler_per_distribution = pl_version == ">=0.32,<=0.34" -if 
one_compiler_per_distribution: - entry_points = { - "pennylane.plugins": "softwareq.qpp = catalyst.cuda:SoftwareQQPP", - "pennylane.compilers": [ - "context = catalyst.tracing.contexts:EvaluationContext", - "ops = catalyst:pennylane_extensions", - "qjit = catalyst:qjit", - ], - } -else: - entry_points = { - "pennylane.plugins": "softwareq.qpp = catalyst.cuda:SoftwareQQPP", - "pennylane.compilers": [ - "catalyst.context = catalyst.tracing.contexts:EvaluationContext", - "catalyst.ops = catalyst:pennylane_extensions", - "catalyst.qjit = catalyst:qjit", - "cuda_quantum.context = catalyst.cuda:EvaluationContext", - "cuda_quantum.ops = catalyst.cuda:pennylane_extensions", - "cuda_quantum.qjit = catalyst.cuda:qjit", - ], - } +entry_points = { + "pennylane.plugins": [ + "softwareq.qpp = catalyst.cuda:SoftwareQQPP", + "nvidia.statevec = catalyst.cuda:NvidiaCuStateVec", + "nvidia.tensornet = catalyst.cuda:NvidiaCuTensorNet", + ], + "pennylane.compilers": [ + "catalyst.context = catalyst.tracing.contexts:EvaluationContext", + "catalyst.ops = catalyst:pennylane_extensions", + "catalyst.qjit = catalyst:qjit", + "cuda_quantum.context = catalyst.tracing.contexts:EvaluationContext", + "cuda_quantum.ops = catalyst:pennylane_extensions", + "cuda_quantum.qjit = catalyst.cuda:qjit", + ], +} classifiers = [ "Environment :: Console",