diff --git a/.dep-versions b/.dep-versions
index 6b19ccd873..7c8cf2d3e2 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -3,3 +3,4 @@ jax=0.4.23
mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9
llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c
enzyme=1beb98b51442d50652eaa3ffb9574f4720d611f1
+pennylane=95129a0d6365b48cb4acfa828ceb6a8532e47ef5
diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml
index 94add6de01..b0ab0b9cfa 100644
--- a/.github/workflows/build-wheel-linux-x86_64.yaml
+++ b/.github/workflows/build-wheel-linux-x86_64.yaml
@@ -390,4 +390,5 @@ jobs:
run: |
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
+ python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml
index 576d92ae32..bf0a1b8662 100644
--- a/.github/workflows/build-wheel-macos-arm64.yaml
+++ b/.github/workflows/build-wheel-macos-arm64.yaml
@@ -400,4 +400,5 @@ jobs:
run: |
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest -n auto
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
+ python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest -v $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml
index 058157b501..ab830e9c3c 100644
--- a/.github/workflows/build-wheel-macos-x86_64.yaml
+++ b/.github/workflows/build-wheel-macos-x86_64.yaml
@@ -370,4 +370,5 @@ jobs:
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest -n auto
# TODO: Uncomment after fixing https://github.com/PennyLaneAI/pennylane-lightning/issues/552
# python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --backend="lightning.kokkos" -n auto
+ python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/async_tests
python${{ matrix.python_version }} -m pytest $GITHUB_WORKSPACE/frontend/test/pytest --runbraket=LOCAL -n auto
diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml
index fddfe8d32b..523f826689 100644
--- a/.github/workflows/check-pl-compat.yaml
+++ b/.github/workflows/check-pl-compat.yaml
@@ -21,6 +21,8 @@ jobs:
constants:
name: "Set build matrix"
uses: ./.github/workflows/constants.yaml
+ with:
+ use_release_tag: ${{ inputs.catalyst == 'stable' }}
check-config:
name: Build Configuration
diff --git a/.github/workflows/constants.yaml b/.github/workflows/constants.yaml
index 18bc4be5b8..eec25cef37 100644
--- a/.github/workflows/constants.yaml
+++ b/.github/workflows/constants.yaml
@@ -7,6 +7,10 @@ on:
required: false
default: false
type: boolean
+ use_release_tag:
+ required: false
+ default: false
+ type: boolean
outputs:
llvm_version:
description: "LLVM version"
@@ -45,6 +49,11 @@ jobs:
steps:
- name: Checkout Catalyst repo
uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
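+    # If requested, check out the most recent release tag (highest version by `sort -V`)
+    # so that the constants reflect the stable release rather than the default branch.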
+ - if: ${{ inputs.use_release_tag }}
+ run: |
+ git checkout $(git tag | sort -V | tail -1)
- name: LLVM version
id: llvm_version
diff --git a/doc/changelog.md b/doc/changelog.md
index c9c958bbf2..61dec29de4 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -40,6 +40,15 @@
f(2, MyClass(5)) # no re-compilation
```
+* Catalyst now supports executing tapes in CUDA-Quantum simulators.
+ [(#477)](https://github.com/PennyLaneAI/catalyst/pull/477)
+ [(#536)](https://github.com/PennyLaneAI/catalyst/pull/536)
+
+  The following devices are now available:
+  * softwareq.qpp
+  * nvidia.statevec (with multi-GPU support)
+  * nvidia.tensornet (with matrix product state (MPS) support)
+
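+  A minimal usage sketch (assuming CUDA-Quantum and the qpp simulator are installed):
+
+  ```python
+  import pennylane as qml
+  from catalyst.cuda import qjit as cuda_qjit
+
+  @cuda_qjit
+  @qml.qnode(qml.device("softwareq.qpp", wires=2))
+  def circuit(a):
+      qml.RX(a, wires=0)
+      qml.CNOT(wires=[0, 1])
+      return qml.state()
+
+  circuit(0.5)
+  ```
+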
Improvements
* Catalyst will now remember previously compiled functions when the PyTree metadata of arguments
@@ -304,6 +313,7 @@
* Handle run time exception in async qnodes.
[(#447)](https://github.com/PennyLaneAI/catalyst/pull/447)
+ [(#510)](https://github.com/PennyLaneAI/catalyst/pull/510)
This is done by:
  * changing `llvm.call` to `llvm.invoke`
diff --git a/frontend/catalyst/ag_primitives.py b/frontend/catalyst/ag_primitives.py
index 7e2636484f..30bc5a0436 100644
--- a/frontend/catalyst/ag_primitives.py
+++ b/frontend/catalyst/ag_primitives.py
@@ -44,8 +44,8 @@
import catalyst
from catalyst.ag_utils import AutoGraphError
+from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.tracing.contexts import EvaluationContext
-from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.patching import Patcher
__all__ = [
diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py
index 4c0aca9ba7..ad1b9ca7e3 100644
--- a/frontend/catalyst/compiled_functions.py
+++ b/frontend/catalyst/compiled_functions.py
@@ -28,6 +28,7 @@
make_zero_d_memref_descriptor,
)
+from catalyst.jax_extras import get_implicit_and_explicit_flat_args
from catalyst.tracing.type_signatures import (
TypeCompatibility,
filter_static_args,
@@ -37,7 +38,6 @@
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
from catalyst.utils.filesystem import Directory
-from catalyst.utils.jax_extras import get_implicit_and_explicit_flat_args
class SharedObjectManager:
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index fecae4a844..66d6067cb5 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -251,7 +251,7 @@ class LinkerDriver:
_default_fallback_compilers = ["clang", "gcc", "c99", "c89", "cc"]
@staticmethod
- def get_default_flags():
+ def get_default_flags(options):
"""Re-compute the path where the libraries exist.
The use case for this is if someone is in a python jupyter notebook and
@@ -312,6 +312,14 @@ def get_default_flags():
elif platform.system() == "Darwin": # pragma: nocover
system_flags += ["-Wl,-arch_errors_fatal"]
+ # The exception handling mechanism requires linking against
+        # __gxx_personality_v0, which is provided by either -lstdc++ or -lc++.
+        # We choose the library based on the operating system.
+ if options.async_qnodes and platform.system() == "Linux": # pragma: nocover
+ system_flags += ["-lstdc++"]
+ elif options.async_qnodes and platform.system() == "Darwin": # pragma: nocover
+ system_flags += ["-lc++"]
+
default_flags = [
"-shared",
"-rdynamic",
@@ -395,12 +403,12 @@ def run(infile, outfile=None, flags=None, fallback_compilers=None, options=None)
"""
if outfile is None:
outfile = LinkerDriver.get_output_filename(infile)
+ if options is None:
+ options = CompileOptions()
if flags is None:
- flags = LinkerDriver.get_default_flags()
+ flags = LinkerDriver.get_default_flags(options)
if fallback_compilers is None:
fallback_compilers = LinkerDriver._default_fallback_compilers
- if options is None:
- options = CompileOptions()
for compiler in LinkerDriver._available_compilers(fallback_compilers):
success = LinkerDriver._attempt_link(compiler, flags, infile, outfile, options)
if success:
diff --git a/frontend/catalyst/cuda/__init__.py b/frontend/catalyst/cuda/__init__.py
index 9da6115665..866d07c449 100644
--- a/frontend/catalyst/cuda/__init__.py
+++ b/frontend/catalyst/cuda/__init__.py
@@ -41,8 +41,7 @@ def wrap_fn(fn):
class BaseCudaInstructionSet(qml.QubitDevice):
"""Base instruction set for CUDA-Quantum devices"""
- # TODO: Once 0.35 is released, remove -dev suffix.
- pennylane_requires = "0.35.0-dev"
+ pennylane_requires = ">=0.34"
version = "0.1.0"
author = "Xanadu, Inc."
@@ -68,14 +67,12 @@ class BaseCudaInstructionSet(qml.QubitDevice):
"RY",
"RZ",
"SWAP",
- # "CSWAP", This is a bug in cuda-quantum. CSWAP is not exposed.
+ "CSWAP",
]
observables = []
config = Path(__file__).parent / "cuda_quantum.toml"
- def __init__(self, shots=None, wires=None, mps=False, multi_gpu=False):
- self.mps = mps
- self.multi_gpu = multi_gpu
+ def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
def apply(self, operations, **kwargs):
@@ -88,25 +85,41 @@ def apply(self, operations, **kwargs):
class SoftwareQQPP(BaseCudaInstructionSet):
"""Concrete device class for qpp-cpu"""
- name = "SoftwareQ q++ simulator"
short_name = "softwareq.qpp"
+ @property
+ def name(self):
+ """Target name"""
+ return "qpp-cpu"
+
class NvidiaCuStateVec(BaseCudaInstructionSet):
"""Concrete device class for CuStateVec"""
- name = "CuStateVec"
short_name = "nvidia.custatevec"
def __init__(self, shots=None, wires=None, multi_gpu=False): # pragma: no cover
- super().__init__(wires=wires, shots=shots, multi_gpu=multi_gpu)
+ self.multi_gpu = multi_gpu
+ super().__init__(wires=wires, shots=shots)
+
+ @property
+ def name(self): # pragma: no cover
+ """Target name"""
+ option = "-mgpu" if self.multi_gpu else ""
+ return f"nvidia{option}"
class NvidiaCuTensorNet(BaseCudaInstructionSet):
"""Concrete device class for CuTensorNet"""
- name = "CuTensorNet"
short_name = "nvidia.cutensornet"
def __init__(self, shots=None, wires=None, mps=False): # pragma: no cover
- super().__init__(wires=wires, shots=shots, mps=mps)
+ self.mps = mps
+ super().__init__(wires=wires, shots=shots)
+
+ @property
+ def name(self): # pragma: no cover
+ """Target name"""
+ option = "-mps" if self.mps else ""
+ return f"tensornet{option}"
diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
index 81775e9513..5557eb5f21 100644
--- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
+++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
@@ -275,18 +275,11 @@ def change_device_to_cuda_device(ctx):
device_name = qdevice_eqn.params.get("rtd_name")
- # TODO(@erick-xanadu) as more devices become available
- # map the names here.
- # TODO(@erick-xanadu) why does the device instruction lists the whole
- # name instead of a short name?
- target_map = {"SoftwareQ q++ simulator": "qpp-cpu"}
- target = target_map.get(device_name, device_name)
-
- if not target or not cudaq.has_target(target):
- msg = f"Unavailable target {target}." # pragma: no cover
+ if not cudaq.has_target(device_name):
+ msg = f"Unavailable target {device_name}." # pragma: no cover
raise ValueError(msg)
- cudaq_target = cudaq.get_target(target)
+ cudaq_target = cudaq.get_target(device_name)
cudaq.set_target(cudaq_target)
    # cudaq_make_kernel returns multiple values depending on the arguments.
@@ -411,7 +404,7 @@ def change_instruction(ctx, eqn):
"RY": "ry",
"RZ": "rz",
"SWAP": "swap",
- # "CSWAP": "cswap", Bug in CUDA quantum. CSWAP is not exposed.
+ "CSWAP": "cswap",
# Other instructions that are missing:
# ch
# sdg
@@ -847,12 +840,11 @@ def cudaq_backend_info(device):
# We could also pass abstract arguments here in *args
# the same way we do so in Catalyst.
# But I think that is redundant now given make_jaxpr2
- _, jaxpr, _, out_tree = trace_to_jaxpr(func, static_args, abs_axes, *args)
+ jaxpr, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {})
# TODO(@erick-xanadu):
# What about static_args?
- # We could return _out_type2 as well
- return jaxpr, out_tree
+ return jaxpr, out_treedef
def interpret(fun):
diff --git a/frontend/catalyst/cuda/cuda_quantum.toml b/frontend/catalyst/cuda/cuda_quantum.toml
index 23eeb28cef..5625ce21e2 100644
--- a/frontend/catalyst/cuda/cuda_quantum.toml
+++ b/frontend/catalyst/cuda/cuda_quantum.toml
@@ -32,7 +32,7 @@ native = [
"RY",
"RZ",
"SWAP",
- # "CSWAP", # Not exposed in CUDA quantum.
+ "CSWAP",
]
# Operators that should be decomposed according to the algorithm used
diff --git a/frontend/catalyst/cuda/primitives/__init__.py b/frontend/catalyst/cuda/primitives/__init__.py
index e4d4627a42..43f6cb0af4 100644
--- a/frontend/catalyst/cuda/primitives/__init__.py
+++ b/frontend/catalyst/cuda/primitives/__init__.py
@@ -268,7 +268,7 @@ def cudaq_getstate(kernel):
@cudaq_getstate_p.def_impl
def cudaq_getstate_primitive_impl(kernel):
"""Concrete implementation."""
- return cudaq.get_state(kernel)
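+    # Wrap the raw cudaq state in a JAX array so callers receive a jax.Array
+    # (exercised by test_state_is_jax_array in the frontend tests).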
+ return jax.numpy.array(cudaq.get_state(kernel))
@cudaq_getstate_p.def_abstract_eval
@@ -394,9 +394,17 @@ def cudaq_sample_impl(kernel, *args, shots_count=1000):
So, let's perform a little conversion here.
"""
a_dict = cudaq.sample(kernel, *args, shots_count=shots_count)
- lls = [[k] * v for k, v in a_dict.items()]
- # Weirdly enough Catalyst returns this transposed.
- return jax.numpy.atleast_2d(jax.numpy.array([int(l) for ls in lls for l in ls])).T
+ aggregate = []
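+    # Example (hypothetical counts): a_dict == {"00": 7, "10": 3} expands to a
+    # (10, 2) array with one row of float bits per shot.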
+ for bitstring, count in a_dict.items():
+        # Each sample is technically a bit array, so int(bit) would be the natural
+        # choice, but Catalyst returns samples as floats, so we match that here.
+ bitarray = [float(bit) for bit in bitstring]
+ for _ in range(count):
+ aggregate.append(bitarray)
+
+ return jax.numpy.array(aggregate)
@cudaq_sample_p.def_abstract_eval
diff --git a/frontend/catalyst/jax_extras/__init__.py b/frontend/catalyst/jax_extras/__init__.py
new file mode 100644
index 0000000000..f3abfe5510
--- /dev/null
+++ b/frontend/catalyst/jax_extras/__init__.py
@@ -0,0 +1,56 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Catalyst additions to the Jax library """
+
+from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir
+from catalyst.jax_extras.patches import (
+ _gather_shape_rule_dynamic,
+ _no_clean_up_dead_vars,
+ get_aval2,
+)
+from catalyst.jax_extras.tracing import (
+ ClosedJaxpr,
+ DynshapedJaxpr,
+ DynamicJaxprTrace,
+ DynamicJaxprTracer,
+ Jaxpr,
+ PyTreeDef,
+ PyTreeRegistry,
+ ShapedArray,
+ ShapeDtypeStruct,
+ _abstractify,
+ _extract_implicit_args,
+ _initial_style_jaxpr,
+ _input_type_to_tracers,
+ convert_constvars_jaxpr,
+ convert_element_type,
+ deduce_avals,
+ eval_jaxpr,
+ get_implicit_and_explicit_flat_args,
+ infer_lambda_input_type,
+ initial_style_jaxprs_with_common_consts1,
+ initial_style_jaxprs_with_common_consts2,
+ make_jaxpr2,
+ make_jaxpr_effects,
+ new_dynamic_main2,
+ new_inner_tracer,
+ sort_eqns,
+ transient_jax_config,
+ tree_flatten,
+ tree_structure,
+ tree_unflatten,
+ treedef_is_leaf,
+ unzip2,
+ wrap_init,
+)
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
new file mode 100644
index 0000000000..19ba1e3366
--- /dev/null
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -0,0 +1,153 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Jax extras module containing functions related to the StableHLO lowering """
+
+from __future__ import annotations
+
+import jax
+from jax._src.dispatch import jaxpr_replicas
+from jax._src.effects import ordered_effects as jax_ordered_effects
+from jax._src.interpreters.mlir import _module_name_regex
+from jax._src.lax.lax import xla
+from jax._src.sharding_impls import ReplicaAxisContext
+from jax._src.source_info_util import new_name_stack
+from jax._src.util import wrap_name
+from jax.core import ClosedJaxpr
+from jax.interpreters.mlir import (
+ AxisContext,
+ LoweringParameters,
+ ModuleContext,
+ ir,
+ lower_jaxpr_to_fun,
+ lowerable_effects,
+)
+
+from catalyst.utils.patching import Patcher
+
+# pylint: disable=protected-access
+
+__all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module")
+
+from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2
+
+
+def jaxpr_to_mlir(func_name, jaxpr):
+ """Lower a Jaxpr into an MLIR module.
+
+ Args:
+ func_name(str): function name
+ jaxpr(Jaxpr): Jaxpr code to lower
+
+ Returns:
+ module: the MLIR module corresponding to ``func``
+ context: the MLIR context corresponding
+        context: the MLIR context corresponding to the module
+
+ with Patcher(
+ (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
+ (jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars),
+ ):
+ nrep = jaxpr_replicas(jaxpr)
+ effects = jax_ordered_effects.filter_in(jaxpr.effects)
+ axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ()))
+ name_stack = new_name_stack(wrap_name("ok", "jit"))
+ module, context = custom_lower_jaxpr_to_module(
+ func_name="jit_" + func_name,
+ module_name=func_name,
+ jaxpr=jaxpr,
+ effects=effects,
+ platform="cpu",
+ axis_context=axis_context,
+ name_stack=name_stack,
+ )
+
+ return module, context
+
+
+# pylint: disable=too-many-arguments
+def custom_lower_jaxpr_to_module(
+ func_name: str,
+ module_name: str,
+ jaxpr: ClosedJaxpr,
+ effects,
+ platform: str,
+ axis_context: AxisContext,
+ name_stack,
+ replicated_args=None,
+ arg_shardings=None,
+ result_shardings=None,
+):
+ """Lowers a top-level jaxpr to an MHLO module.
+
+ Handles the quirks of the argument/return value passing conventions of the
+ runtime.
+
+ This function has been modified from its original form in the JAX project at
+    https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625
+ released under the Apache License, Version 2.0, with the following copyright notice:
+
+ Copyright 2021 The JAX Authors.
+ """
+
+ if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover
+ raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}")
+
+ assert platform == "cpu"
+ assert arg_shardings is None
+ assert result_shardings is None
+
+ # MHLO channels need to start at 1
+ channel_iter = 1
+ # Create a keepalives list that will be mutated during the lowering.
+ keepalives = []
+ host_callbacks = []
+ lowering_params = LoweringParameters()
+ ctx = ModuleContext(
+ backend_or_name=None,
+ platforms=[platform],
+ axis_context=axis_context,
+ name_stack=name_stack,
+ keepalives=keepalives,
+ channel_iterator=channel_iter,
+ host_callbacks=host_callbacks,
+ lowering_parameters=lowering_params,
+ )
+ ctx.context.allow_unregistered_dialects = True
+ with ctx.context, ir.Location.unknown(ctx.context):
+ # register_dialect()
+ # Remove module name characters that XLA would alter. This ensures that
+ # XLA computation preserves the module name.
+ module_name = _module_name_regex.sub("_", module_name)
+ ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
+ lower_jaxpr_to_fun(
+ ctx,
+ func_name,
+ jaxpr,
+ effects,
+ public=True,
+ create_tokens=True,
+ replace_tokens_with_dummy=True,
+ replicated_args=replicated_args,
+ arg_shardings=arg_shardings,
+ result_shardings=result_shardings,
+ )
+
+ for op in ctx.module.body.operations:
+ func_name = str(op.name)
+ is_entry_point = func_name.startswith('"jit_')
+ if is_entry_point:
+ continue
+ op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage")
+
+ return ctx.module, ctx.context
diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py
new file mode 100644
index 0000000000..266f38b33c
--- /dev/null
+++ b/frontend/catalyst/jax_extras/patches.py
@@ -0,0 +1,182 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Jax extras module containing Jax patches """
+
+# pylint: disable=too-many-arguments
+
+from __future__ import annotations
+
+import jax
+from jax._src.lax.slicing import (
+ _gather_shape_computation,
+ _is_sorted,
+ _no_duplicate_dims,
+ _rank,
+ _sorted_dims_in_range,
+)
+from jax.core import AbstractValue, Tracer, concrete_aval
+
+__all__ = (
+ "get_aval2",
+ "_no_clean_up_dead_vars",
+ "_gather_shape_rule_dynamic",
+)
+
+
+def get_aval2(x):
+ """An extended version of `jax.core.get_aval` which also accepts AbstractValues."""
+ # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged
+ if isinstance(x, AbstractValue):
+ return x
+ elif isinstance(x, Tracer):
+ return x.aval
+ else:
+ return concrete_aval(x)
+
+
+def _no_clean_up_dead_vars(_eqn, _env, _last_used):
+ """A stub to workaround the Jax ``KeyError 'a'`` bug during the lowering of Jaxpr programs to
+ MLIR with the dynamic API enabled."""
+ return None
+
+
+def _gather_shape_rule_dynamic(
+ operand,
+ indices,
+ *,
+ dimension_numbers,
+ slice_sizes,
+ unique_indices,
+ indices_are_sorted,
+ mode,
+ fill_value,
+): # pragma: no cover
+ """Validates the well-formedness of the arguments to Gather. Compared to the original version,
+ this implementation skips static shape checks if variable dimensions are used.
+
+ This function has been modified from its original form in the JAX project at
+ https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438
+ version released under the Apache License, Version 2.0, with the following copyright notice:
+
+ Copyright 2021 The JAX Authors.
+ TODO(@grwlf): delete once PR https://github.com/google/jax/pull/19083 has been merged
+ """
+ # pylint: disable=unused-argument
+ # pylint: disable=too-many-branches
+ # pylint: disable=consider-using-enumerate
+ # pylint: disable=chained-comparison
+ offset_dims = dimension_numbers.offset_dims
+ collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
+ start_index_map = dimension_numbers.start_index_map
+
+ # Note: in JAX, index_vector_dim is always computed as below, cf. the
+ # documentation of the GatherDimensionNumbers class.
+ index_vector_dim = _rank(indices) - 1
+
+ # This case should never happen in JAX, due to the implicit construction of
+ # index_vector_dim, but is included for completeness.
+ if _rank(indices) < index_vector_dim or index_vector_dim < 0:
+ raise TypeError(
+ f"Gather index leaf dimension must be within [0, rank("
+ f"indices) + 1). rank(indices) is {_rank(indices)} and "
+ f"gather index leaf dimension is {index_vector_dim}."
+ )
+
+ # Start ValidateGatherDimensions
+ # In the error messages output by XLA, "offset_dims" is called "Output window
+ # dimensions" in error messages. For consistency's sake, our error messages
+ # stick to "offset_dims".
+ _is_sorted(offset_dims, "gather", "offset_dims")
+ _no_duplicate_dims(offset_dims, "gather", "offset_dims")
+
+ output_offset_dim_count = len(offset_dims)
+ output_shape_rank = len(offset_dims) + _rank(indices) - 1
+
+ for i in range(output_offset_dim_count):
+ offset_dim = offset_dims[i]
+ if offset_dim < 0 or offset_dim >= output_shape_rank:
+ raise TypeError(
+ f"Offset dimension {i} in gather op is out of bounds; "
+ f"got {offset_dim}, but should have been in "
+ f"[0, {output_shape_rank})"
+ )
+
+ if len(start_index_map) != indices.shape[index_vector_dim]:
+ raise TypeError(
+ f"Gather op has {len(start_index_map)} elements in "
+ f"start_index_map and the bound of dimension "
+ f"{index_vector_dim=} of indices is "
+ f"{indices.shape[index_vector_dim]}. These two "
+ f"numbers must be equal."
+ )
+
+ for i in range(len(start_index_map)):
+ operand_dim_for_start_index_i = start_index_map[i]
+ if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand):
+ raise TypeError(
+ f"Invalid start_index_map; domain is "
+ f"[0, {_rank(operand)}), got: "
+ f"{i}->{operand_dim_for_start_index_i}."
+ )
+
+ _no_duplicate_dims(start_index_map, "gather", "start_index_map")
+
+ # _is_sorted and _sorted_dims_in_range are checked in the opposite order
+ # compared to the XLA implementation. In cases when the input is not sorted
+ # AND there are problematic collapsed_slice_dims, the error message will thus
+ # be different.
+ _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
+ _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims")
+ _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
+ # End ValidateGatherDimensions
+
+ if _rank(operand) != len(slice_sizes):
+ raise TypeError(
+ f"Gather op must have one slice size for every input "
+ f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
+ f"input_shape.rank={_rank(operand)}"
+ )
+
+ if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
+ raise TypeError(
+ f"All components of the offset index in a gather op must "
+ f"either be a offset dimension or explicitly collapsed; "
+ f"got len(slice_sizes)={len(slice_sizes)}, "
+ f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
+ f"{collapsed_slice_dims}."
+ )
+
+ # This section contains a patch suggested to the upstream.
+ for i in range(len(slice_sizes)):
+ slice_size = slice_sizes[i]
+ corresponding_input_size = operand.shape[i]
+
+ if jax.core.is_constant_dim(corresponding_input_size):
+ if not (slice_size >= 0 and corresponding_input_size >= slice_size):
+ raise TypeError(
+ f"Slice size at index {i} in gather op is out of range, "
+ f"must be within [0, {corresponding_input_size} + 1), "
+ f"got {slice_size}."
+ )
+
+ for i in range(len(collapsed_slice_dims)):
+ bound = slice_sizes[collapsed_slice_dims[i]]
+ if bound != 1:
+ raise TypeError(
+ f"Gather op can only collapse slice dims with bound 1, "
+ f"but bound is {bound} for index "
+ f"{collapsed_slice_dims[i]} at position {i}."
+ )
+
+ return _gather_shape_computation(indices, dimension_numbers, slice_sizes)
diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/jax_extras/tracing.py
similarity index 57%
rename from frontend/catalyst/utils/jax_extras.py
rename to frontend/catalyst/jax_extras/tracing.py
index c3ed6fab33..2fb510ad14 100644
--- a/frontend/catalyst/utils/jax_extras.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -1,4 +1,4 @@
-# Copyright 2023 Xanadu Quantum Technologies Inc.
+# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This module isolates utility functions that depend on JAX low-level internals
-"""
+""" Jax extras module containing functions related to the Python program tracing """
+
from __future__ import annotations
from contextlib import ExitStack, contextmanager
@@ -22,8 +22,6 @@
from jax import ShapeDtypeStruct
from jax._src import state, util
from jax._src.core import _update_thread_local_jit_state
-from jax._src.dispatch import jaxpr_replicas
-from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex, register_lowering
from jax._src.interpreters.partial_eval import (
_input_type_to_tracers,
@@ -31,44 +29,21 @@
trace_to_jaxpr_dynamic2,
)
from jax._src.lax.control_flow import _initial_style_jaxpr, _initial_style_open_jaxpr
-from jax._src.lax.lax import _abstractify, xla
+from jax._src.lax.lax import _abstractify
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
_gather_lower,
- _gather_shape_computation,
- _is_sorted,
- _no_duplicate_dims,
- _rank,
- _sorted_dims_in_range,
standard_primitive,
)
from jax._src.linear_util import annotate
from jax._src.pjit import _extract_implicit_args, _flat_axes_specs
-from jax._src.sharding_impls import ReplicaAxisContext
from jax._src.source_info_util import current as jax_current
-from jax._src.source_info_util import new_name_stack
-from jax._src.util import partition_list, safe_map, unzip2, unzip3, wrap_name, wraps
+from jax._src.util import partition_list, safe_map, unzip2, unzip3, wraps
from jax.api_util import flatten_fun
-from jax.core import AbstractValue, ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType
+from jax.core import ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType
from jax.core import Primitive as JaxprPrimitive
-from jax.core import (
- ShapedArray,
- Trace,
- Tracer,
- concrete_aval,
- eval_jaxpr,
- gensym,
- thread_local_state,
-)
-from jax.interpreters.mlir import (
- AxisContext,
- LoweringParameters,
- ModuleContext,
- ir,
- lower_jaxpr_to_fun,
- lowerable_effects,
-)
+from jax.core import ShapedArray, Trace, eval_jaxpr, gensym, thread_local_state
from jax.interpreters.partial_eval import (
DynamicJaxprTrace,
DynamicJaxprTracer,
@@ -76,7 +51,7 @@
make_jaxpr_effects,
)
from jax.lax import convert_element_type
-from jax.linear_util import wrap_init
+from jax.extend.linear_util import wrap_init
from jax.tree_util import (
PyTreeDef,
tree_flatten,
@@ -86,6 +61,7 @@
)
from jaxlib.xla_extension import PyTreeRegistry
+from catalyst.jax_extras.patches import _gather_shape_rule_dynamic, get_aval2
from catalyst.utils.patching import Patcher
# pylint: disable=protected-access
@@ -108,7 +84,7 @@
"_abstractify",
"_initial_style_jaxpr",
"_input_type_to_tracers",
- "jaxpr_to_mlir",
+ "_module_name_regex",
"make_jaxpr_effects",
"make_jaxpr2",
"new_dynamic_main2",
@@ -347,6 +323,7 @@ def deduce_avals(f: Callable, args, kwargs):
"""Wraps the callable ``f`` into a WrappedFun container accepting collapsed flatten arguments
and returning expanded flatten results. Calculate input abstract values and output_tree promise.
The promise must be called after the resulting wrapped function is evaluated."""
+ # TODO: deprecate in favor of `deduce_signatures`
flat_args, in_tree = tree_flatten((args, kwargs))
abstracted_axes = None
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
@@ -358,134 +335,6 @@ def deduce_avals(f: Callable, args, kwargs):
return wffa, in_avals, keep_inputs, out_tree_promise
-def get_aval2(x):
- """An extended version of `jax.core.get_aval` which also accepts AbstractValues."""
- # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged
- if isinstance(x, AbstractValue):
- return x
- elif isinstance(x, Tracer):
- return x.aval
- else:
- return concrete_aval(x)
-
-
-def _no_clean_up_dead_vars(_eqn, _env, _last_used):
- """A stub to workaround the Jax ``KeyError 'a'`` bug during the lowering of Jaxpr programs to
- MLIR with the dynamic API enabled."""
- return None
-
-
-def jaxpr_to_mlir(func_name, jaxpr):
- """Lower a Jaxpr into an MLIR module.
-
- Args:
- func_name(str): function name
- jaxpr(Jaxpr): Jaxpr code to lower
-
- Returns:
- module: the MLIR module corresponding to ``func``
- context: the MLIR context corresponding
- """
-
- with Patcher(
- (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
- (jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars),
- ):
- nrep = jaxpr_replicas(jaxpr)
- effects = jax_ordered_effects.filter_in(jaxpr.effects)
- axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ()))
- name_stack = new_name_stack(wrap_name("ok", "jit"))
- module, context = custom_lower_jaxpr_to_module(
- func_name="jit_" + func_name,
- module_name=func_name,
- jaxpr=jaxpr,
- effects=effects,
- platform="cpu",
- axis_context=axis_context,
- name_stack=name_stack,
- )
-
- return module, context
-
-
-# pylint: disable=too-many-arguments
-def custom_lower_jaxpr_to_module(
- func_name: str,
- module_name: str,
- jaxpr: ClosedJaxpr,
- effects,
- platform: str,
- axis_context: AxisContext,
- name_stack,
- replicated_args=None,
- arg_shardings=None,
- result_shardings=None,
-):
- """Lowers a top-level jaxpr to an MHLO module.
-
- Handles the quirks of the argument/return value passing conventions of the
- runtime.
-
- This function has been modified from its original form in the JAX project at
- https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version
- released under the Apache License, Version 2.0, with the following copyright notice:
-
- Copyright 2021 The JAX Authors.
- """
-
- if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover
- raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}")
-
- assert platform == "cpu"
- assert arg_shardings is None
- assert result_shardings is None
-
- # MHLO channels need to start at 1
- channel_iter = 1
- # Create a keepalives list that will be mutated during the lowering.
- keepalives = []
- host_callbacks = []
- lowering_params = LoweringParameters()
- ctx = ModuleContext(
- backend_or_name=None,
- platforms=[platform],
- axis_context=axis_context,
- name_stack=name_stack,
- keepalives=keepalives,
- channel_iterator=channel_iter,
- host_callbacks=host_callbacks,
- lowering_parameters=lowering_params,
- )
- ctx.context.allow_unregistered_dialects = True
- with ctx.context, ir.Location.unknown(ctx.context):
- # register_dialect()
- # Remove module name characters that XLA would alter. This ensures that
- # XLA computation preserves the module name.
- module_name = _module_name_regex.sub("_", module_name)
- ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
- lower_jaxpr_to_fun(
- ctx,
- func_name,
- jaxpr,
- effects,
- public=True,
- create_tokens=True,
- replace_tokens_with_dummy=True,
- replicated_args=replicated_args,
- arg_shardings=arg_shardings,
- result_shardings=result_shardings,
- )
-
- for op in ctx.module.body.operations:
- func_name = str(op.name)
- is_entry_point = func_name.startswith('"jit_')
- if is_entry_point:
- continue
- op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage")
-
- return ctx.module, ctx.context
-
-
def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer:
"""Create a JAX tracer tracing an abstract value ``aval`, without specifying its source
primitive."""
@@ -549,133 +398,3 @@ def make_jaxpr_f(*args, **kwargs):
make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})"
return make_jaxpr_f
-
-
-def _gather_shape_rule_dynamic(
- operand,
- indices,
- *,
- dimension_numbers,
- slice_sizes,
- unique_indices,
- indices_are_sorted,
- mode,
- fill_value,
-): # pragma: no cover
- """Validates the well-formedness of the arguments to Gather. Compared to the original version,
- this implementation skips static shape checks if variable dimensions are used.
-
- This function has been modified from its original form in the JAX project at
- https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438
- version released under the Apache License, Version 2.0, with the following copyright notice:
-
- Copyright 2021 The JAX Authors.
- """
- # pylint: disable=unused-argument
- # pylint: disable=too-many-branches
- # pylint: disable=consider-using-enumerate
- # pylint: disable=chained-comparison
- offset_dims = dimension_numbers.offset_dims
- collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
- start_index_map = dimension_numbers.start_index_map
-
- # Note: in JAX, index_vector_dim is always computed as below, cf. the
- # documentation of the GatherDimensionNumbers class.
- index_vector_dim = _rank(indices) - 1
-
- # This case should never happen in JAX, due to the implicit construction of
- # index_vector_dim, but is included for completeness.
- if _rank(indices) < index_vector_dim or index_vector_dim < 0:
- raise TypeError(
- f"Gather index leaf dimension must be within [0, rank("
- f"indices) + 1). rank(indices) is {_rank(indices)} and "
- f"gather index leaf dimension is {index_vector_dim}."
- )
-
- # Start ValidateGatherDimensions
- # In the error messages output by XLA, "offset_dims" is called "Output window
- # dimensions" in error messages. For consistency's sake, our error messages
- # stick to "offset_dims".
- _is_sorted(offset_dims, "gather", "offset_dims")
- _no_duplicate_dims(offset_dims, "gather", "offset_dims")
-
- output_offset_dim_count = len(offset_dims)
- output_shape_rank = len(offset_dims) + _rank(indices) - 1
-
- for i in range(output_offset_dim_count):
- offset_dim = offset_dims[i]
- if offset_dim < 0 or offset_dim >= output_shape_rank:
- raise TypeError(
- f"Offset dimension {i} in gather op is out of bounds; "
- f"got {offset_dim}, but should have been in "
- f"[0, {output_shape_rank})"
- )
-
- if len(start_index_map) != indices.shape[index_vector_dim]:
- raise TypeError(
- f"Gather op has {len(start_index_map)} elements in "
- f"start_index_map and the bound of dimension "
- f"{index_vector_dim=} of indices is "
- f"{indices.shape[index_vector_dim]}. These two "
- f"numbers must be equal."
- )
-
- for i in range(len(start_index_map)):
- operand_dim_for_start_index_i = start_index_map[i]
- if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand):
- raise TypeError(
- f"Invalid start_index_map; domain is "
- f"[0, {_rank(operand)}), got: "
- f"{i}->{operand_dim_for_start_index_i}."
- )
-
- _no_duplicate_dims(start_index_map, "gather", "start_index_map")
-
- # _is_sorted and _sorted_dims_in_range are checked in the opposite order
- # compared to the XLA implementation. In cases when the input is not sorted
- # AND there are problematic collapsed_slice_dims, the error message will thus
- # be different.
- _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
- _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims")
- _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
- # End ValidateGatherDimensions
-
- if _rank(operand) != len(slice_sizes):
- raise TypeError(
- f"Gather op must have one slice size for every input "
- f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
- f"input_shape.rank={_rank(operand)}"
- )
-
- if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
- raise TypeError(
- f"All components of the offset index in a gather op must "
- f"either be a offset dimension or explicitly collapsed; "
- f"got len(slice_sizes)={len(slice_sizes)}, "
- f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
- f"{collapsed_slice_dims}."
- )
-
- # This section contains a patch suggested to the upstream.
- for i in range(len(slice_sizes)):
- slice_size = slice_sizes[i]
- corresponding_input_size = operand.shape[i]
-
- if jax.core.is_constant_dim(corresponding_input_size):
- if not (slice_size >= 0 and corresponding_input_size >= slice_size):
- raise TypeError(
- f"Slice size at index {i} in gather op is out of range, "
- f"must be within [0, {corresponding_input_size} + 1), "
- f"got {slice_size}."
- )
-
- for i in range(len(collapsed_slice_dims)):
- bound = slice_sizes[collapsed_slice_dims[i]]
- if bound != 1:
- raise TypeError(
- f"Gather op can only collapse slice dims with bound 1, "
- f"but bound is {bound} for index "
- f"{collapsed_slice_dims[i]} at position {i}."
- )
-
- return _gather_shape_computation(indices, dimension_numbers, slice_sizes)
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 4748a46a15..cd469a38f7 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -27,6 +27,28 @@
from pennylane.tape import QuantumTape
import catalyst
+from catalyst.jax_extras import (
+ ClosedJaxpr,
+ DynshapedJaxpr,
+ DynamicJaxprTrace,
+ DynamicJaxprTracer,
+ PyTreeDef,
+ PyTreeRegistry,
+ ShapedArray,
+ _abstractify,
+ _input_type_to_tracers,
+ convert_element_type,
+ deduce_avals,
+ eval_jaxpr,
+ jaxpr_to_mlir,
+ make_jaxpr2,
+ sort_eqns,
+ transient_jax_config,
+ tree_flatten,
+ tree_structure,
+ tree_unflatten,
+ wrap_init,
+)
from catalyst.jax_primitives import (
AbstractQreg,
compbasis_p,
@@ -58,28 +80,6 @@
JaxTracingContext,
)
from catalyst.utils.exceptions import CompileError
-from catalyst.utils.jax_extras import (
- ClosedJaxpr,
- DynamicJaxprTrace,
- DynamicJaxprTracer,
- DynshapedJaxpr,
- PyTreeDef,
- PyTreeRegistry,
- ShapedArray,
- _abstractify,
- _input_type_to_tracers,
- convert_element_type,
- deduce_avals,
- eval_jaxpr,
- jaxpr_to_mlir,
- make_jaxpr2,
- sort_eqns,
- transient_jax_config,
- tree_flatten,
- tree_structure,
- tree_unflatten,
- wrap_init,
-)
class Function:
diff --git a/frontend/catalyst/pennylane_extensions.py b/frontend/catalyst/pennylane_extensions.py
index 176c747775..e141439359 100644
--- a/frontend/catalyst/pennylane_extensions.py
+++ b/frontend/catalyst/pennylane_extensions.py
@@ -44,6 +44,21 @@
from pennylane.tape import QuantumTape
import catalyst
+from catalyst.jax_extras import ( # infer_output_type3,
+ ClosedJaxpr,
+ DynamicJaxprTracer,
+ Jaxpr,
+ ShapedArray,
+ _initial_style_jaxpr,
+ _input_type_to_tracers,
+ convert_constvars_jaxpr,
+ deduce_avals,
+ get_implicit_and_explicit_flat_args,
+ initial_style_jaxprs_with_common_consts1,
+ initial_style_jaxprs_with_common_consts2,
+ new_inner_tracer,
+ unzip2,
+)
from catalyst.jax_primitives import (
AbstractQreg,
GradParams,
@@ -65,7 +80,6 @@
HybridOp,
HybridOpRegion,
QRegPromise,
- deduce_avals,
has_nested_tapes,
trace_quantum_function,
trace_quantum_tape,
@@ -78,20 +92,6 @@
JaxTracingContext,
)
from catalyst.utils.exceptions import DifferentiableCompileError
-from catalyst.utils.jax_extras import (
- ClosedJaxpr,
- DynamicJaxprTracer,
- Jaxpr,
- ShapedArray,
- _initial_style_jaxpr,
- _input_type_to_tracers,
- convert_constvars_jaxpr,
- get_implicit_and_explicit_flat_args,
- initial_style_jaxprs_with_common_consts1,
- initial_style_jaxprs_with_common_consts2,
- new_inner_tracer,
- unzip2,
-)
from catalyst.utils.runtime import extract_backend_info, get_lib_path
diff --git a/frontend/catalyst/tracing/contexts.py b/frontend/catalyst/tracing/contexts.py
index 7f388a82fe..749c8b0f10 100644
--- a/frontend/catalyst/tracing/contexts.py
+++ b/frontend/catalyst/tracing/contexts.py
@@ -32,8 +32,8 @@
from jax.core import find_top_trace
from pennylane.queuing import QueuingManager
+from catalyst.jax_extras import new_dynamic_main2
from catalyst.utils.exceptions import CompileError
-from catalyst.utils.jax_extras import new_dynamic_main2
class EvaluationMode(Enum):
diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py
index 3d730fe78e..0396a95592 100644
--- a/frontend/catalyst/tracing/type_signatures.py
+++ b/frontend/catalyst/tracing/type_signatures.py
@@ -27,7 +27,7 @@
from jax.api_util import shaped_abstractify
from jax.tree_util import tree_flatten, tree_unflatten
-from catalyst.utils.jax_extras import get_aval2
+from catalyst.jax_extras import get_aval2
from catalyst.utils.patching import Patcher
diff --git a/frontend/test/pytest/test_cuda_integration.py b/frontend/test/pytest/test_cuda_integration.py
index 2d12a6b89e..910eb21ab6 100644
--- a/frontend/test/pytest/test_cuda_integration.py
+++ b/frontend/test/pytest/test_cuda_integration.py
@@ -27,6 +27,7 @@
# when we are running kokkos. Importing CUDA before running any kokkos
# kernel pollutes the environment and will create a segfault.
# pylint: disable=import-outside-toplevel
+# pylint: disable=too-many-public-methods
@pytest.mark.cuda
@@ -49,13 +50,12 @@ def circuit_foo():
def test_qjit_cuda_remove_host_context(self):
"""Test removing the host context."""
- from catalyst.cuda import SoftwareQQPP
from catalyst.cuda.catalyst_to_cuda_interpreter import (
QJIT_CUDAQ,
remove_host_context,
)
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit_foo():
return qml.state()
@@ -65,10 +65,9 @@ def circuit_foo():
def test_qjit_catalyst_to_cuda_jaxpr(self):
"""Assert that catalyst_to_cuda returns something."""
- from catalyst.cuda import SoftwareQQPP
from catalyst.cuda.catalyst_to_cuda_interpreter import interpret
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit_foo():
return qml.state()
@@ -78,12 +77,11 @@ def circuit_foo():
def test_measurement_return(self):
"""Test the measurement code is added."""
- from catalyst.cuda import SoftwareQQPP
from catalyst.cuda.catalyst_to_cuda_interpreter import interpret
with pytest.raises(NotImplementedError, match="cannot return measurements directly"):
- @qml.qnode(SoftwareQQPP(wires=1, shots=30))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=30))
def circuit():
qml.RX(jnp.pi / 4, wires=[0])
return measure(0)
@@ -93,10 +91,9 @@ def circuit():
def test_measurement_side_effect(self):
"""Test the measurement code is added."""
- from catalyst.cuda import SoftwareQQPP
from catalyst.cuda.catalyst_to_cuda_interpreter import interpret
- @qml.qnode(SoftwareQQPP(wires=1, shots=30))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=30))
def circuit():
qml.RX(jnp.pi / 4, wires=[0])
measure(0)
@@ -106,9 +103,8 @@ def circuit():
def test_pytrees(self):
"""Test that we can return a dictionary."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit_a(a):
qml.RX(a, wires=[0])
return {"a": qml.state()}
@@ -126,9 +122,8 @@ def circuit_b(a):
def test_cuda_device(self):
"""Test SoftwareQQPP."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a):
qml.RX(a, wires=[0])
return qml.state()
@@ -146,9 +141,8 @@ def circuit_lightning(a):
def test_samples(self):
"""Test SoftwareQQPP."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1, shots=100))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=100))
def circuit(a):
qml.RX(a, wires=[0])
return qml.sample()
@@ -166,9 +160,8 @@ def circuit_lightning(a):
def test_counts(self):
"""Test SoftwareQQPP."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1, shots=100))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1, shots=100))
def circuit(a):
qml.RX(a, wires=[0])
return qml.counts()
@@ -186,9 +179,8 @@ def circuit_lightning(a):
def test_qjit_cuda_device(self):
"""Test SoftwareQQPP."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a):
qml.RX(a, wires=[0])
return qml.state()
@@ -206,9 +198,8 @@ def circuit_lightning(a):
def test_abstract_variable(self):
"""Test abstract variable."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a: float):
qml.RX(a, wires=[0])
return qml.state()
@@ -226,9 +217,8 @@ def circuit_lightning(a):
def test_arithmetic(self):
"""Test arithmetic."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a):
qml.RX(a / 2, wires=[0])
return qml.state()
@@ -246,9 +236,8 @@ def circuit_lightning(a):
def test_multiple_values(self):
"""Test multiple_values."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(params):
x, y = jax.numpy.array_split(params, 2)
qml.RX(x[0], wires=[0])
@@ -272,7 +261,7 @@ def circuit_lightning(params):
def test_cuda_device_entry_point(self):
"""Test the entry point for SoftwareQQPP"""
- @qml.qnode(qml.device("software.qpp", wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a):
qml.RX(a, wires=[0])
return {"a": qml.state()}
@@ -293,7 +282,7 @@ def test_cuda_device_entry_point_compiler(self):
"""Test the entry point for cudaq.qjit"""
@qml.qjit(compiler="cuda_quantum")
- @qml.qnode(qml.device("cudaq", wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit(a):
qml.RX(a, wires=[0])
return {"a": qml.state()}
@@ -302,9 +291,8 @@ def circuit(a):
def test_expval(self):
"""Test multiple_values."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=1))
+ @qml.qnode(qml.device("softwareq.qpp", wires=1))
def circuit():
qml.RX(jnp.pi / 2, wires=[0])
return qml.expval(qml.PauliZ(0))
@@ -322,9 +310,8 @@ def circuit_catalyst():
def test_expval_2(self):
"""Test multiple_values."""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
qml.RY(jnp.pi / 4, wires=[1])
return qml.expval(qml.PauliZ(1) + qml.PauliX(1))
@@ -343,9 +330,7 @@ def circuit_catalyst():
def test_adjoint(self):
"""Test adjoint."""
- from catalyst.cuda import SoftwareQQPP
-
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
def f(theta):
qml.RX(theta / 23, wires=[0])
@@ -380,9 +365,7 @@ def f(theta):
def test_control_ry(self):
"""Test control ry."""
- from catalyst.cuda import SoftwareQQPP
-
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
qml.Hadamard(wires=[0])
qml.CRY(jnp.pi / 2, wires=[0, 1])
@@ -403,9 +386,7 @@ def circuit_catalyst():
def test_swap(self):
"""Test swap."""
- from catalyst.cuda import SoftwareQQPP
-
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
qml.RX(jnp.pi / 3, wires=[0])
qml.SWAP(wires=[0, 1])
@@ -426,9 +407,7 @@ def circuit_catalyst():
def test_entanglement(self):
"""Test swap."""
- from catalyst.cuda import SoftwareQQPP
-
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
qml.Hadamard(wires=[0])
qml.CNOT(wires=[0, 1])
@@ -446,11 +425,47 @@ def circuit_catalyst():
expected = catalyst_compiled()
assert_allclose(expected, observed)
+ def test_cswap(self):
+ """Test cswap."""
+
+ @qml.qnode(qml.device("softwareq.qpp", wires=3))
+ def circuit():
+ qml.Hadamard(wires=[0])
+ qml.RX(jnp.pi / 7, wires=[1])
+ qml.CSWAP(wires=[0, 1, 2])
+ return qml.state()
+
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ def circuit_catalyst():
+ qml.Hadamard(wires=[0])
+ qml.RX(jnp.pi / 7, wires=[1])
+ qml.CSWAP(wires=[0, 1, 2])
+ return qml.state()
+
+ cuda_compiled = catalyst.cuda.qjit(circuit)
+ observed = cuda_compiled()
+ catalyst_compiled = qjit(circuit_catalyst)
+ expected = catalyst_compiled()
+ assert_allclose(expected, observed)
+
+ def test_state_is_jax_array(self):
+ """Test return type for state."""
+
+ @qml.qnode(qml.device("softwareq.qpp", wires=3))
+ def circuit():
+ qml.Hadamard(wires=[0])
+ qml.RX(jnp.pi / 7, wires=[1])
+ qml.CSWAP(wires=[0, 1, 2])
+ return qml.state()
+
+ cuda_compiled = catalyst.cuda.qjit(circuit)
+ observed = cuda_compiled()
+ assert isinstance(observed, jax.Array)
+
def test_error_message_using_host_context(self):
"""Test error message"""
- from catalyst.cuda import SoftwareQQPP
- @qml.qnode(SoftwareQQPP(wires=2))
+ @qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit(x):
qml.Hadamard(wires=[0])
qml.CNOT(wires=[0, 1])
@@ -464,6 +479,28 @@ def wrapper(y):
with pytest.raises(CompileError, match="Cannot translate tapes with context"):
catalyst.cuda.qjit(wrapper)(1.0)
+    def test_samples_multiple_wires(self):
+ """Samples with more than one wire."""
+
+ from catalyst.cuda import qjit as cjit
+
+ @qjit
+ @qml.qnode(qml.device("lightning.qubit", wires=2, shots=10))
+ def circuit1(a):
+ qml.RX(a, wires=0)
+ return qml.sample()
+
+ expected = circuit1(3.14)
+
+ @cjit
+ @qml.qnode(qml.device("softwareq.qpp", wires=2, shots=10))
+ def circuit2(a):
+ qml.RX(a, wires=0)
+ return qml.sample()
+
+ observed = circuit2(3.14)
+ assert_allclose(expected, observed)
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/frontend/test/pytest/test_jax_config.py b/frontend/test/pytest/test_jax_config.py
index 0ae72c5a46..498367521c 100644
--- a/frontend/test/pytest/test_jax_config.py
+++ b/frontend/test/pytest/test_jax_config.py
@@ -16,7 +16,7 @@
import jax
-from catalyst.utils.jax_extras import transient_jax_config
+from catalyst.jax_extras import transient_jax_config
def test_transient_jax_config():
diff --git a/setup.py b/setup.py
index c93d92198e..4f9f35162d 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
import platform
import subprocess
from distutils import sysconfig
-from os import environ, path
+from os import path
import numpy as np
from pybind11.setup_helpers import intree_extensions
@@ -33,43 +33,35 @@
version = f.readlines()[-1].split()[-1].strip("\"'")
with open(".dep-versions") as f:
- jax_version = [line[4:].strip() for line in f.readlines() if "jax=" in line][0]
+ lines = f.readlines()
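+    # Each line of .dep-versions pins a dependency as `name=value`,
+    # e.g. `pennylane=<commit sha>`.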
+ jax_version = [line[4:].strip() for line in lines if "jax=" in line][0]
+ pl_str = "pennylane="
+ pl_str_length = len(pl_str)
+ pl_version = [line[pl_str_length:].strip() for line in lines if pl_str in line][0]
-pl_version = environ.get("PL_VERSION", ">=0.32,<=0.34")
requirements = [
- f"pennylane{pl_version}",
+ f"pennylane @ git+https://github.com/pennylaneai/pennylane@{pl_version}",
f"jax=={jax_version}",
f"jaxlib=={jax_version}",
"tomlkit;python_version<'3.11'",
"scipy",
]
-# TODO: Once PL version 0.35 is released:
-# * remove this special handling
-# * make pennylane>=0.35 a requirement
-# * Close this ticket https://github.com/PennyLaneAI/catalyst/issues/494
-one_compiler_per_distribution = pl_version == ">=0.32,<=0.34"
-if one_compiler_per_distribution:
- entry_points = {
- "pennylane.plugins": "softwareq.qpp = catalyst.cuda:SoftwareQQPP",
- "pennylane.compilers": [
- "context = catalyst.tracing.contexts:EvaluationContext",
- "ops = catalyst:pennylane_extensions",
- "qjit = catalyst:qjit",
- ],
- }
-else:
- entry_points = {
- "pennylane.plugins": "softwareq.qpp = catalyst.cuda:SoftwareQQPP",
- "pennylane.compilers": [
- "catalyst.context = catalyst.tracing.contexts:EvaluationContext",
- "catalyst.ops = catalyst:pennylane_extensions",
- "catalyst.qjit = catalyst:qjit",
- "cuda_quantum.context = catalyst.cuda:EvaluationContext",
- "cuda_quantum.ops = catalyst.cuda:pennylane_extensions",
- "cuda_quantum.qjit = catalyst.cuda:qjit",
- ],
- }
+entry_points = {
+ "pennylane.plugins": [
+ "softwareq.qpp = catalyst.cuda:SoftwareQQPP",
+ "nvidia.statevec = catalyst.cuda:NvidiaCuStateVec",
+ "nvidia.tensornet = catalyst.cuda:NvidiaCuTensorNet",
+ ],
+ "pennylane.compilers": [
+ "catalyst.context = catalyst.tracing.contexts:EvaluationContext",
+ "catalyst.ops = catalyst:pennylane_extensions",
+ "catalyst.qjit = catalyst:qjit",
+ "cuda_quantum.context = catalyst.tracing.contexts:EvaluationContext",
+ "cuda_quantum.ops = catalyst:pennylane_extensions",
+ "cuda_quantum.qjit = catalyst.cuda:qjit",
+ ],
+}
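+# PennyLane discovers these via the "pennylane.plugins" and "pennylane.compilers" entry
+# points: e.g. qml.device("softwareq.qpp", wires=1) resolves to catalyst.cuda:SoftwareQQPP,
+# and qml.qjit(compiler="cuda_quantum") dispatches to catalyst.cuda:qjit.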
classifiers = [
"Environment :: Console",