diff --git a/doc/changelog.md b/doc/changelog.md
index db3069494f..af5c2dcde7 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -92,11 +92,14 @@
annotations.
[(#751)](https://github.com/PennyLaneAI/catalyst/pull/751)
-* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
- class that implements the decorator's logic. This prevents having to excessively define
+* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
+ class that implements the decorator's logic. This prevents having to excessively define
functions in a nested fashion.
[(#758)](https://github.com/PennyLaneAI/catalyst/pull/758)
+* Catalyst tests now manipulate device capabilities rather than text configurations files.
+ [(#712)](https://github.com/PennyLaneAI/catalyst/pull/712)
+
Breaking changes
* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
@@ -198,11 +201,12 @@
This release contains contributions from (in alphabetical order):
David Ittah,
-Mehrdad Malekmohammadi,
Erick Ochoa,
+Haochen Paul Wang,
Lee James O'Riordan,
+Mehrdad Malekmohammadi,
Raul Torres,
-Haochen Paul Wang.
+Sergei Mironov.
# Release 0.6.0
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index afcbb6429a..89792c0af7 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -33,7 +33,7 @@
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory
-from catalyst.utils.runtime import get_lib_path
+from catalyst.utils.runtime_environment import get_lib_path
package_root = os.path.dirname(__file__)
diff --git a/frontend/catalyst/device/__init__.py b/frontend/catalyst/device/__init__.py
index 6b9c1bd55c..a830cf7ec4 100644
--- a/frontend/catalyst/device/__init__.py
+++ b/frontend/catalyst/device/__init__.py
@@ -16,9 +16,18 @@
Internal API for the device module.
"""
-from catalyst.device.qjit_device import QJITDevice, QJITDeviceNewAPI
+from catalyst.device.qjit_device import (
+ BackendInfo,
+ QJITDevice,
+ QJITDeviceNewAPI,
+ extract_backend_info,
+ validate_device_capabilities,
+)
__all__ = (
"QJITDevice",
"QJITDeviceNewAPI",
+ "BackendInfo",
+ "extract_backend_info",
+ "validate_device_capabilities",
)
diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py
index c99f5db9c0..24be391cdc 100644
--- a/frontend/catalyst/device/qjit_device.py
+++ b/frontend/catalyst/device/qjit_device.py
@@ -16,10 +16,14 @@
This module contains device stubs for the old and new PennyLane device API, which facilitate
the application of decomposition and other device pre-processing routines.
"""
-
+import os
+import pathlib
+import platform
+import re
from copy import deepcopy
+from dataclasses import dataclass
from functools import partial
-from typing import Optional, Set
+from typing import Any, Dict, Optional, Set
import pennylane as qml
from pennylane.measurements import MidMeasureMP
@@ -32,13 +36,10 @@
)
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
-from catalyst.utils.runtime import BackendInfo, device_get_toml_config
+from catalyst.utils.runtime_environment import get_lib_path
from catalyst.utils.toml import (
DeviceCapabilities,
OperationProperties,
- ProgramFeatures,
- TOMLDocument,
- get_device_capabilities,
intersect_operations,
pennylane_operation_set,
)
@@ -84,6 +85,87 @@
for op in RUNTIME_OPERATIONS
}
+# TODO: This should be removed after implementing `get_c_interface`
+# for the following backend devices:
+SUPPORTED_RT_DEVICES = {
+ "lightning.qubit": ("LightningSimulator", "librtd_lightning"),
+ "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"),
+ "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"),
+ "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"),
+}
+
+
+@dataclass
+class BackendInfo:
+ """Backend information"""
+
+ device_name: str
+ c_interface_name: str
+ lpath: str
+ kwargs: Dict[str, Any]
+
+
+def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo:
+ """Extract the backend info from a quantum device. The device is expected to carry a reference
+ to a valid TOML config file."""
+ # pylint: disable=too-many-branches
+
+ dname = device.name
+ if isinstance(device, qml.Device):
+ dname = device.short_name
+
+ device_name = ""
+ device_lpath = ""
+ device_kwargs = {}
+
+ if dname in SUPPORTED_RT_DEVICES:
+ # Support backend devices without `get_c_interface`
+ device_name = SUPPORTED_RT_DEVICES[dname][0]
+ device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR")
+ sys_platform = platform.system()
+
+ if sys_platform == "Linux":
+ device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so")
+ elif sys_platform == "Darwin": # pragma: no cover
+ device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib")
+ else: # pragma: no cover
+ raise NotImplementedError(f"Platform not supported: {sys_platform}")
+ elif hasattr(device, "get_c_interface"):
+ # Support third party devices with `get_c_interface`
+ device_name, device_lpath = device.get_c_interface()
+ else:
+ raise CompileError(f"The {dname} device does not provide C interface for compilation.")
+
+ if not pathlib.Path(device_lpath).is_file():
+ raise CompileError(f"Device at {device_lpath} cannot be found!")
+
+ if hasattr(device, "shots"):
+ if isinstance(device, qml.Device):
+ device_kwargs["shots"] = device.shots if device.shots else 0
+ else:
+ # TODO: support shot vectors
+ device_kwargs["shots"] = device.shots.total_shots if device.shots else 0
+
+ if dname == "braket.local.qubit": # pragma: no cover
+ device_kwargs["device_type"] = dname
+ device_kwargs["backend"] = (
+ # pylint: disable=protected-access
+ device._device._delegate.DEVICE_ID
+ )
+ elif dname == "braket.aws.qubit": # pragma: no cover
+ device_kwargs["device_type"] = dname
+ device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
+ if device._s3_folder: # pylint: disable=protected-access
+ device_kwargs["s3_destination_folder"] = str(
+ device._s3_folder # pylint: disable=protected-access
+ )
+
+ for k, v in capabilities.options.items():
+ if hasattr(device, v):
+ device_kwargs[k] = getattr(device, v)
+
+ return BackendInfo(dname, device_name, device_lpath, device_kwargs)
+
def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]:
"""Calculate the set of supported quantum gates for the QJIT device from the gates
@@ -165,7 +247,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S
def __init__(
self,
- target_config: TOMLDocument,
+ original_device_capabilities: DeviceCapabilities,
shots=None,
wires=None,
backend: Optional[BackendInfo] = None,
@@ -175,23 +257,18 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
- device_name = backend.device_name if backend else "default"
- program_features = ProgramFeatures(shots is not None)
- target_device_capabilities = get_device_capabilities(
- target_config, program_features, device_name
- )
- self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
+ self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)
@property
def operations(self) -> Set[str]:
"""Get the device operations using PennyLane's syntax"""
- return pennylane_operation_set(self.capabilities.native_ops)
+ return pennylane_operation_set(self.qjit_capabilities.native_ops)
@property
def observables(self) -> Set[str]:
"""Get the device observables"""
- return pennylane_operation_set(self.capabilities.native_obs)
+ return pennylane_operation_set(self.qjit_capabilities.native_obs)
def apply(self, operations, **kwargs):
"""
@@ -270,6 +347,7 @@ class QJITDeviceNewAPI(qml.devices.Device):
def __init__(
self,
original_device,
+ original_device_capabilities: DeviceCapabilities,
backend: Optional[BackendInfo] = None,
):
self.original_device = original_device
@@ -285,29 +363,23 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
- device_name = backend.device_name if backend else "default"
- target_config = device_get_toml_config(original_device)
- program_features = ProgramFeatures(original_device.shots is not None)
- target_device_capabilities = get_device_capabilities(
- target_config, program_features, device_name
- )
- self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
+ self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)
@property
def operations(self) -> Set[str]:
"""Get the device operations"""
- return pennylane_operation_set(self.capabilities.native_ops)
+ return pennylane_operation_set(self.qjit_capabilities.native_ops)
@property
def observables(self) -> Set[str]:
"""Get the device observables"""
- return pennylane_operation_set(self.capabilities.native_obs)
+ return pennylane_operation_set(self.qjit_capabilities.native_obs)
@property
def measurement_processes(self) -> Set[str]:
"""Get the device measurement processes"""
- return self.capabilities.measurement_processes
+ return self.qjit_capabilities.measurement_processes
def preprocess(
self,
@@ -334,3 +406,101 @@ def execute(self, circuits, execution_config):
Raises: RuntimeError
"""
raise RuntimeError("QJIT devices cannot execute tapes.")
+
+
+def filter_out_adjoint(operations):
+ """Remove Adjoint from operations.
+
+ Args:
+ operations (List[Str]): List of strings with names of supported operations
+
+ Returns:
+ List: A list of strings with names of supported operations with Adjoint and C gates
+ removed.
+ """
+ adjoint = re.compile(r"^Adjoint\(.*\)$")
+
+ def is_not_adj(op):
+ return not re.match(adjoint, op)
+
+ operations_no_adj = filter(is_not_adj, operations)
+ return set(operations_no_adj)
+
+
+def check_no_overlap(*args, device_name):
+ """Check items in *args are mutually exclusive.
+
+ Args:
+ *args (List[Str]): List of strings.
+ device_name (str): Device name for error reporting.
+
+ Raises:
+ CompileError
+ """
+ set_of_sets = [set(arg) for arg in args]
+ union = set.union(*set_of_sets)
+ len_of_sets = [len(arg) for arg in args]
+ if sum(len_of_sets) == len(union):
+ return
+
+ overlaps = set()
+ for s in set_of_sets:
+ overlaps.update(s - union)
+ union = union - s
+
+ msg = f"Device '{device_name}' has overlapping gates: {overlaps}"
+ raise CompileError(msg)
+
+
+def validate_device_capabilities(
+ device: qml.QubitDevice, device_capabilities: DeviceCapabilities
+) -> None:
+ """Validate configuration document against the device attributes.
+ Raise CompileError in case of mismatch:
+ * If device is not qjit-compatible.
+ * If configuration file does not exists.
+ * If decomposable, matrix, and native gates have some overlap.
+ * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and
+ ``device.observables``.
+
+ Args:
+ device (qml.Device): An instance of a quantum device.
+ config (TOMLDocument): A TOML document representation.
+
+ Raises: CompileError
+ """
+
+ if not device_capabilities.qjit_compatible_flag:
+ raise CompileError(
+ f"Attempting to compile program for incompatible device '{device.name}': "
+ f"Config is not marked as qjit-compatible"
+ )
+
+ device_name = device.short_name if isinstance(device, qml.Device) else device.name
+
+ native = pennylane_operation_set(device_capabilities.native_ops)
+ decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops)
+ matrix = pennylane_operation_set(device_capabilities.to_matrix_ops)
+
+ check_no_overlap(native, decomposable, matrix, device_name=device_name)
+
+ if hasattr(device, "operations") and hasattr(device, "observables"):
+ # For gates, we require strict match
+ device_gates = filter_out_adjoint(set(device.operations))
+ spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable))
+ if device_gates != spec_gates:
+ raise CompileError(
+ "Gates in qml.device.operations and specification file do not match.\n"
+ f"Gates that present only in the device: {device_gates - spec_gates}\n"
+ f"Gates that present only in spec: {spec_gates - device_gates}\n"
+ )
+
+ # For observables, we do not have `non-native` section in the config, so we check that
+ # device data supercedes the specification.
+ device_observables = set(device.observables)
+ spec_observables = pennylane_operation_set(device_capabilities.native_obs)
+ if (spec_observables - device_observables) != set():
+ raise CompileError(
+ "Observables in qml.device.observables and specification file do not match.\n"
+ f"Observables that present only in spec: {spec_observables - device_observables}\n"
+ )
diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py
index b09bc9f317..045d947c44 100644
--- a/frontend/catalyst/qfunc.py
+++ b/frontend/catalyst/qfunc.py
@@ -22,7 +22,13 @@
from jax.core import eval_jaxpr
from jax.tree_util import tree_flatten, tree_unflatten
-from catalyst.device import QJITDevice, QJITDeviceNewAPI
+from catalyst.device import (
+ BackendInfo,
+ QJITDevice,
+ QJITDeviceNewAPI,
+ extract_backend_info,
+ validate_device_capabilities,
+)
from catalyst.jax_extras import (
deduce_avals,
get_implicit_and_explicit_flat_args,
@@ -30,13 +36,11 @@
)
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
-from catalyst.utils.runtime import (
- BackendInfo,
- device_get_toml_config,
- extract_backend_info,
- validate_config_with_device,
+from catalyst.utils.toml import (
+ DeviceCapabilities,
+ ProgramFeatures,
+ get_device_capabilities,
)
-from catalyst.utils.toml import TOMLDocument
class QFunc:
@@ -54,26 +58,35 @@ def __new__(cls):
raise NotImplementedError() # pragma: no-cover
@staticmethod
- def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo:
+ def extract_backend_info(
+ device: qml.QubitDevice, capabilities: DeviceCapabilities
+ ) -> BackendInfo:
"""Wrapper around extract_backend_info in the runtime module."""
- return extract_backend_info(device, config)
+ return extract_backend_info(device, capabilities)
- # pylint: disable=no-member
+ # pylint: disable=no-member, attribute-defined-outside-init
def __call__(self, *args, **kwargs):
assert isinstance(self, qml.QNode)
- config = device_get_toml_config(self.device)
- validate_config_with_device(self.device, config)
- backend_info = QFunc.extract_backend_info(self.device, config)
+ # TODO: Move the capability loading and validation to the device constructor when the
+ # support for old device api is dropped.
+ program_features = ProgramFeatures(self.device.shots is not None)
+ device_capabilities = get_device_capabilities(self.device, program_features)
+ backend_info = QFunc.extract_backend_info(self.device, device_capabilities)
+
+ # Validate decive operations against the declared capabilities
+ validate_device_capabilities(self.device, device_capabilities)
if isinstance(self.device, qml.devices.Device):
- device = QJITDeviceNewAPI(self.device, backend_info)
+ qjit_device = QJITDeviceNewAPI(self.device, device_capabilities, backend_info)
else:
- device = QJITDevice(config, self.device.shots, self.device.wires, backend_info)
+ qjit_device = QJITDevice(
+ device_capabilities, self.device.shots, self.device.wires, backend_info
+ )
def _eval_quantum(*args):
closed_jaxpr, out_type, out_tree = trace_quantum_function(
- self.func, device, args, kwargs, qnode=self
+ self.func, qjit_device, args, kwargs, qnode=self
)
args_expanded = get_implicit_and_explicit_flat_args(None, *args)
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)
diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py
index 7201f1dfea..c53a6d0c55 100644
--- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py
+++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py
@@ -39,6 +39,7 @@
import pennylane as qml
from jax.tree_util import tree_unflatten
+from catalyst.device import BackendInfo
from catalyst.jax_primitives import (
AbstractObs,
adjoint_p,
@@ -75,7 +76,6 @@
from catalyst.qfunc import QFunc
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
-from catalyst.utils.runtime import BackendInfo
from .primitives import (
cuda_inst,
@@ -820,7 +820,7 @@ def get_jaxpr(self, *args):
an MLIR module
"""
- def cudaq_backend_info(device, _config) -> BackendInfo:
+ def cudaq_backend_info(device, _capabilities) -> BackendInfo:
"""The extract_backend_info should not be run by the cuda compiler as it is
catalyst-specific. We need to make this API a bit nicer for third-party compilers.
"""
diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py
deleted file mode 100644
index 1fb3aed313..0000000000
--- a/frontend/catalyst/utils/runtime.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright 2023 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Runtime utility methods.
-"""
-
-# pylint: disable=too-many-branches
-
-import os
-import pathlib
-import platform
-import re
-from dataclasses import dataclass
-from typing import Any, Dict
-
-import pennylane as qml
-
-from catalyst._configuration import INSTALLED
-from catalyst.utils.exceptions import CompileError
-from catalyst.utils.toml import (
- ProgramFeatures,
- TOMLDocument,
- get_device_capabilities,
- pennylane_operation_set,
- read_toml_file,
-)
-
-package_root = os.path.dirname(__file__)
-
-
-# Default paths to dep libraries
-DEFAULT_LIB_PATHS = {
- "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"),
- "runtime": os.path.join(package_root, "../../../runtime/build/lib"),
- "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"),
- "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"),
-}
-
-
-# TODO: This should be removed after implementing `get_c_interface`
-# for the following backend devices:
-SUPPORTED_RT_DEVICES = {
- "lightning.qubit": ("LightningSimulator", "librtd_lightning"),
- "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"),
- "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"),
- "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"),
-}
-
-
-def get_lib_path(project, env_var):
- """Get the library path."""
- if INSTALLED:
- return os.path.join(package_root, "..", "lib") # pragma: no cover
- return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, ""))
-
-
-def check_no_overlap(*args, device_name):
- """Check items in *args are mutually exclusive.
-
- Args:
- *args (List[Str]): List of strings.
- device_name (str): Device name for error reporting.
-
- Raises:
- CompileError
- """
- set_of_sets = [set(arg) for arg in args]
- union = set.union(*set_of_sets)
- len_of_sets = [len(arg) for arg in args]
- if sum(len_of_sets) == len(union):
- return
-
- overlaps = set()
- for s in set_of_sets:
- overlaps.update(s - union)
- union = union - s
-
- msg = f"Device '{device_name}' has overlapping gates: {overlaps}"
- raise CompileError(msg)
-
-
-def filter_out_adjoint(operations):
- """Remove Adjoint from operations.
-
- Args:
- operations (List[Str]): List of strings with names of supported operations
-
- Returns:
- List: A list of strings with names of supported operations with Adjoint and C gates
- removed.
- """
- adjoint = re.compile(r"^Adjoint\(.*\)$")
-
- def is_not_adj(op):
- return not re.match(adjoint, op)
-
- operations_no_adj = filter(is_not_adj, operations)
- return set(operations_no_adj)
-
-
-def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -> None:
- """Validate configuration document against the device attributes.
- Raise CompileError in case of mismatch:
- * If device is not qjit-compatible.
- * If configuration file does not exists.
- * If decomposable, matrix, and native gates have some overlap.
- * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and
- ``device.observables``.
-
- Args:
- device (qml.Device): An instance of a quantum device.
- config (TOMLDocument): A TOML document representation.
-
- Raises: CompileError
- """
-
- if not config["compilation"]["qjit_compatible"]:
- raise CompileError(
- f"Attempting to compile program for incompatible device '{device.name}': "
- f"Config is not marked as qjit-compatible"
- )
-
- device_name = device.short_name if isinstance(device, qml.Device) else device.name
- program_features = ProgramFeatures(device.shots is not None)
- device_capabilities = get_device_capabilities(config, program_features, device_name)
-
- native = pennylane_operation_set(device_capabilities.native_ops)
- decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops)
- matrix = pennylane_operation_set(device_capabilities.to_matrix_ops)
-
- check_no_overlap(native, decomposable, matrix, device_name=device_name)
-
- if hasattr(device, "operations") and hasattr(device, "observables"):
- # For gates, we require strict match
- device_gates = filter_out_adjoint(set(device.operations))
- spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable))
- if device_gates != spec_gates:
- raise CompileError(
- "Gates in qml.device.operations and specification file do not match.\n"
- f"Gates that present only in the device: {device_gates - spec_gates}\n"
- f"Gates that present only in spec: {spec_gates - device_gates}\n"
- )
-
- # For observables, we do not have `non-native` section in the config, so we check that
- # device data supercedes the specification.
- device_observables = set(device.observables)
- spec_observables = pennylane_operation_set(device_capabilities.native_obs)
- if (spec_observables - device_observables) != set():
- raise CompileError(
- "Observables in qml.device.observables and specification file do not match.\n"
- f"Observables that present only in spec: {spec_observables - device_observables}\n"
- )
-
-
-def device_get_toml_config(device) -> TOMLDocument:
- """Get the contents of the device config file."""
- if hasattr(device, "config"):
- # The expected case: device specifies its own config.
- toml_file = device.config
- else:
- # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file
- # field.
- device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR"))
-
- name = device.short_name if isinstance(device, qml.Device) else device.name
- # The toml files name convention we follow is to replace
- # the dots with underscores in the device short name.
- toml_file_name = name.replace(".", "_") + ".toml"
- # And they are currently saved in the following directory.
- toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name
-
- try:
- config = read_toml_file(toml_file)
- except FileNotFoundError as e:
- raise CompileError(
- "Attempting to compile program for incompatible device: "
- f"Config file ({toml_file}) does not exist"
- ) from e
-
- return config
-
-
-@dataclass
-class BackendInfo:
- """Backend information"""
-
- device_name: str
- c_interface_name: str
- lpath: str
- kwargs: Dict[str, Any]
-
-
-def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo:
- """Extract the backend info from a quantum device. The device is expected to carry a reference
- to a valid TOML config file."""
-
- dname = device.name
- if isinstance(device, qml.Device):
- dname = device.short_name
-
- device_name = ""
- device_lpath = ""
- device_kwargs = {}
-
- if dname in SUPPORTED_RT_DEVICES:
- # Support backend devices without `get_c_interface`
- device_name = SUPPORTED_RT_DEVICES[dname][0]
- device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR")
- sys_platform = platform.system()
-
- if sys_platform == "Linux":
- device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so")
- elif sys_platform == "Darwin": # pragma: no cover
- device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib")
- else: # pragma: no cover
- raise NotImplementedError(f"Platform not supported: {sys_platform}")
- elif hasattr(device, "get_c_interface"):
- # Support third party devices with `get_c_interface`
- device_name, device_lpath = device.get_c_interface()
- else:
- raise CompileError(f"The {dname} device does not provide C interface for compilation.")
-
- if not pathlib.Path(device_lpath).is_file():
- raise CompileError(f"Device at {device_lpath} cannot be found!")
-
- if hasattr(device, "shots"):
- if isinstance(device, qml.Device):
- device_kwargs["shots"] = device.shots if device.shots else 0
- else:
- # TODO: support shot vectors
- device_kwargs["shots"] = device.shots.total_shots if device.shots else 0
-
- if dname == "braket.local.qubit": # pragma: no cover
- device_kwargs["device_type"] = dname
- device_kwargs["backend"] = (
- # pylint: disable=protected-access
- device._device._delegate.DEVICE_ID
- )
- elif dname == "braket.aws.qubit": # pragma: no cover
- device_kwargs["device_type"] = dname
- device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
- if device._s3_folder: # pylint: disable=protected-access
- device_kwargs["s3_destination_folder"] = str(
- device._s3_folder # pylint: disable=protected-access
- )
-
- options = config.get("options", {})
- for k, v in options.items():
- if hasattr(device, v):
- device_kwargs[k] = getattr(device, v)
-
- return BackendInfo(dname, device_name, device_lpath, device_kwargs)
diff --git a/frontend/catalyst/utils/runtime_environment.py b/frontend/catalyst/utils/runtime_environment.py
new file mode 100644
index 0000000000..a75b466809
--- /dev/null
+++ b/frontend/catalyst/utils/runtime_environment.py
@@ -0,0 +1,38 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility code for keeping paths
+"""
+import os
+import os.path
+
+from catalyst._configuration import INSTALLED
+
+package_root = os.path.dirname(__file__)
+
+# Default paths to dep libraries
+DEFAULT_LIB_PATHS = {
+ "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"),
+ "runtime": os.path.join(package_root, "../../../runtime/build/lib"),
+ "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"),
+ "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"),
+}
+
+
+def get_lib_path(project, env_var):
+ """Get the library path."""
+ if INSTALLED:
+ return os.path.join(package_root, "..", "lib") # pragma: no cover
+ return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, ""))
diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py
index 8044962889..1806d23c2c 100644
--- a/frontend/catalyst/utils/toml.py
+++ b/frontend/catalyst/utils/toml.py
@@ -16,12 +16,16 @@
"""
import importlib.util
+import pathlib
from dataclasses import dataclass
from functools import reduce
from itertools import repeat
-from typing import Any, Dict, List, Set
+from typing import Any, Dict, List, Optional, Set
+
+import pennylane as qml
from catalyst.utils.exceptions import CompileError
+from catalyst.utils.runtime_environment import get_lib_path
# TODO:
# Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit
@@ -33,8 +37,7 @@
tomlkit = importlib.util.find_spec("tomlkit")
# We need at least one of these to make sure we can read toml files.
if tomllib is None and tomlkit is None: # pragma: nocover
- msg = "Either tomllib or tomlkit need to be installed."
- raise ImportError(msg)
+ raise ImportError("Either tomllib or tomlkit need to be installed.")
# Give preference to tomllib
if tomllib: # pragma: nocover
@@ -82,9 +85,11 @@ class DeviceCapabilities: # pylint: disable=too-many-instance-attributes
to_matrix_ops: Dict[str, OperationProperties]
native_obs: Dict[str, OperationProperties]
measurement_processes: Set[str]
+ qjit_compatible_flag: bool
mid_circuit_measurement_flag: bool
runtime_code_generation_flag: bool
dynamic_qubit_management_flag: bool
+ options: Dict[str, bool]
def intersect_operations(
@@ -112,11 +117,16 @@ class ProgramFeatures:
shots_present: bool
-def check_compilation_flag(config: TOMLDocument, flag_name: str) -> bool:
- """Checks the flag in the toml document 'compilation' section."""
+def get_compilation_flag(config: TOMLDocument, flag_name: str) -> bool:
+ """Get the flag in the toml document 'compilation' section."""
return bool(config.get("compilation", {}).get(flag_name, False))
+def get_options(config: TOMLDocument) -> Dict[str, str]:
+ """Get custom options sections"""
+ return {str(k): str(v) for k, v in config.get("options", {}).items()}
+
+
def check_quantum_control_flag(config: TOMLDocument) -> bool:
"""Check the control flag. Only exists in toml config schema 1"""
schema = int(config["schema"])
@@ -300,7 +310,7 @@ def patch_schema1_collections(
for op, props in native_gate_props.items():
props.controllable = op not in gates_to_be_decomposed_if_controlled
- supports_adjoint = check_compilation_flag(config, "quantum_adjoint")
+ supports_adjoint = get_compilation_flag(config, "quantum_adjoint")
if supports_adjoint:
# Makr all gates as invertibles
for props in native_gate_props.values():
@@ -326,10 +336,54 @@ def patch_schema1_collections(
decomp_props.pop("ControlledPhaseShift")
+def get_device_toml_config(device) -> TOMLDocument:
+ """Get the contents of the device config file."""
+ if hasattr(device, "config"):
+ # The expected case: device specifies its own config.
+ toml_file = device.config
+ else:
+ # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file
+ # field.
+ device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR"))
+
+ name = device.short_name if isinstance(device, qml.Device) else device.name
+ # The toml files name convention we follow is to replace
+ # the dots with underscores in the device short name.
+ toml_file_name = name.replace(".", "_") + ".toml"
+ # And they are currently saved in the following directory.
+ toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name
+
+ try:
+ config = read_toml_file(toml_file)
+ except FileNotFoundError as e:
+ raise CompileError(
+ "Attempting to compile program for incompatible device: "
+ f"Config file ({toml_file}) does not exist"
+ ) from e
+
+ return config
+
+
def get_device_capabilities(
+ device, program_features: Optional[ProgramFeatures] = None
+) -> DeviceCapabilities:
+ """Get or load DeviceCapabilities structure from device"""
+
+ if hasattr(device, "qjit_capabilities"):
+ return device.qjit_capabilities
+ else:
+ program_features = (
+ program_features if program_features else ProgramFeatures(device.shots is not None)
+ )
+ device_name = device.short_name if isinstance(device, qml.Device) else device.name
+ device_config = get_device_toml_config(device)
+ return load_device_capabilities(device_config, program_features, device_name)
+
+
+def load_device_capabilities(
config: TOMLDocument, program_features: ProgramFeatures, device_name: str
) -> DeviceCapabilities:
- """Load TOML document into the DeviceCapabilities structure"""
+ """Load device capabilities from device config"""
schema = int(config["schema"])
@@ -369,7 +423,9 @@ def get_device_capabilities(
to_matrix_ops=matrix_decomp_props,
native_obs=observable_props,
measurement_processes=measurements_props,
- mid_circuit_measurement_flag=check_compilation_flag(config, "mid_circuit_measurement"),
- runtime_code_generation_flag=check_compilation_flag(config, "runtime_code_generation"),
- dynamic_qubit_management_flag=check_compilation_flag(config, "dynamic_qubit_management"),
+ qjit_compatible_flag=get_compilation_flag(config, "qjit_compatible"),
+ mid_circuit_measurement_flag=get_compilation_flag(config, "mid_circuit_measurement"),
+ runtime_code_generation_flag=get_compilation_flag(config, "runtime_code_generation"),
+ dynamic_qubit_management_flag=get_compilation_flag(config, "dynamic_qubit_management"),
+ options=get_options(config),
)
diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py
index c0375a4356..2c1c0ce57b 100644
--- a/frontend/test/lit/test_decomposition.py
+++ b/frontend/test/lit/test_decomposition.py
@@ -15,13 +15,17 @@
# RUN: %PYTHON %s | FileCheck %s
# pylint: disable=line-too-long
-import os
-import tempfile
+from copy import deepcopy
import jax
import pennylane as qml
from catalyst import cond, for_loop, measure, qjit, while_loop
+from catalyst.utils.toml import (
+ ProgramFeatures,
+ get_device_capabilities,
+ pennylane_operation_set,
+)
def get_custom_device_without(num_wires, discards):
@@ -37,8 +41,6 @@ class CustomDevice(qml.QubitDevice):
author = "Tester"
lightning_device = qml.device("lightning.qubit", wires=0)
- operations = lightning_device.operations.copy() - discards
- observables = lightning_device.observables.copy()
config = None
backend_name = "default"
@@ -47,54 +49,57 @@ class CustomDevice(qml.QubitDevice):
def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
- self.toml_file = None
+ program_features = ProgramFeatures(shots_present=self.shots is not None)
+ lightning_capabilities = get_device_capabilities(
+ self.lightning_device, program_features
+ )
+ custom_capabilities = deepcopy(lightning_capabilities)
+ for gate in discards:
+ if gate in custom_capabilities.native_ops:
+ custom_capabilities.native_ops.pop(gate)
+ if gate in custom_capabilities.to_decomp_ops:
+ custom_capabilities.to_decomp_ops.pop(gate)
+ if gate in custom_capabilities.to_matrix_ops:
+ custom_capabilities.to_matrix_ops.pop(gate)
+ self.qjit_capabilities = custom_capabilities
def apply(self, operations, **kwargs):
"""Unused"""
raise RuntimeError("Only C/C++ interface is defined")
- def __enter__(self, *args, **kwargs):
- lightning_toml = self.lightning_device.config
- with open(lightning_toml, mode="r", encoding="UTF-8") as f:
- toml_contents = f.readlines()
+ @property
+ def operations(self):
+ """Return operations using PennyLane's C(.) syntax"""
+ return (
+ pennylane_operation_set(self.qjit_capabilities.native_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops)
+ )
- # TODO: update once schema 2 is merged
- updated_toml_contents = []
- for line in toml_contents:
- if any(f'"{gate}",' in line for gate in discards):
- continue
- updated_toml_contents.append(line)
-
- self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
- self.toml_file.writelines(updated_toml_contents)
- self.toml_file.close() # close for now without deleting
-
- self.config = self.toml_file.name
- return self
-
- def __exit__(self, *args, **kwargs):
- os.unlink(self.toml_file.name)
- self.config = None
+ @property
+ def observables(self):
+ """Return PennyLane observables"""
+ return pennylane_operation_set(self.qjit_capabilities.native_obs)
return CustomDevice(wires=num_wires)
def test_decompose_multicontrolledx():
"""Test decomposition of MultiControlledX."""
- with get_custom_device_without(5, {"MultiControlledX"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_decompose_multicontrolled_x1
- def decompose_multicontrolled_x1(theta: float):
- qml.RX(theta, wires=[0])
- # CHECK-NOT: name = "MultiControlledX"
- # CHECK: quantum.unitary
- # CHECK-NOT: name = "MultiControlledX"
- qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
- return qml.state()
+ dev = get_custom_device_without(5, {"MultiControlledX"})
- print(decompose_multicontrolled_x1.mlir)
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_decompose_multicontrolled_x1
+ def decompose_multicontrolled_x1(theta: float):
+ qml.RX(theta, wires=[0])
+ # CHECK-NOT: name = "MultiControlledX"
+ # CHECK: quantum.unitary
+ # CHECK-NOT: name = "MultiControlledX"
+ qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
+ return qml.state()
+
+ print(decompose_multicontrolled_x1.mlir)
test_decompose_multicontrolledx()
@@ -102,25 +107,25 @@ def decompose_multicontrolled_x1(theta: float):
def test_decompose_multicontrolledx_in_conditional():
"""Test decomposition of MultiControlledX in conditional."""
- with get_custom_device_without(5, {"MultiControlledX"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: @jit_decompose_multicontrolled_x2
- def decompose_multicontrolled_x2(theta: float, n: int):
- qml.RX(theta, wires=[0])
-
- # CHECK-NOT: name = "MultiControlledX"
- # CHECK: quantum.unitary
- # CHECK-NOT: name = "MultiControlledX"
- @cond(n > 1)
- def cond_fn():
- qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
+ dev = get_custom_device_without(5, {"MultiControlledX"})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: @jit_decompose_multicontrolled_x2
+ def decompose_multicontrolled_x2(theta: float, n: int):
+ qml.RX(theta, wires=[0])
+
+ # CHECK-NOT: name = "MultiControlledX"
+ # CHECK: quantum.unitary
+ # CHECK-NOT: name = "MultiControlledX"
+ @cond(n > 1)
+ def cond_fn():
+ qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
- cond_fn()
- return qml.state()
+ cond_fn()
+ return qml.state()
- print(decompose_multicontrolled_x2.mlir)
+ print(decompose_multicontrolled_x2.mlir)
test_decompose_multicontrolledx_in_conditional()
@@ -128,26 +133,26 @@ def cond_fn():
def test_decompose_multicontrolledx_in_while_loop():
"""Test decomposition of MultiControlledX in while loop."""
- with get_custom_device_without(5, {"MultiControlledX"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: @jit_decompose_multicontrolled_x3
- def decompose_multicontrolled_x3(theta: float, n: int):
- qml.RX(theta, wires=[0])
-
- # CHECK-NOT: name = "MultiControlledX"
- # CHECK: quantum.unitary
- # CHECK-NOT: name = "MultiControlledX"
- @while_loop(lambda v: v[0] < 10)
- def loop(v):
- qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
- return v[0] + 1, v[1]
+ dev = get_custom_device_without(5, {"MultiControlledX"})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: @jit_decompose_multicontrolled_x3
+ def decompose_multicontrolled_x3(theta: float, n: int):
+ qml.RX(theta, wires=[0])
+
+ # CHECK-NOT: name = "MultiControlledX"
+ # CHECK: quantum.unitary
+ # CHECK-NOT: name = "MultiControlledX"
+ @while_loop(lambda v: v[0] < 10)
+ def loop(v):
+ qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
+ return v[0] + 1, v[1]
- loop((0, n))
- return qml.state()
+ loop((0, n))
+ return qml.state()
- print(decompose_multicontrolled_x3.mlir)
+ print(decompose_multicontrolled_x3.mlir)
test_decompose_multicontrolledx_in_while_loop()
@@ -155,25 +160,25 @@ def loop(v):
def test_decompose_multicontrolledx_in_for_loop():
"""Test decomposition of MultiControlledX in for loop."""
- with get_custom_device_without(5, {"MultiControlledX"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: @jit_decompose_multicontrolled_x4
- def decompose_multicontrolled_x4(theta: float, n: int):
- qml.RX(theta, wires=[0])
-
- # CHECK-NOT: name = "MultiControlledX"
- # CHECK: quantum.unitary
- # CHECK-NOT: name = "MultiControlledX"
- @for_loop(0, n, 1)
- def loop(_):
- qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
+ dev = get_custom_device_without(5, {"MultiControlledX"})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: @jit_decompose_multicontrolled_x4
+ def decompose_multicontrolled_x4(theta: float, n: int):
+ qml.RX(theta, wires=[0])
+
+ # CHECK-NOT: name = "MultiControlledX"
+ # CHECK: quantum.unitary
+ # CHECK-NOT: name = "MultiControlledX"
+ @for_loop(0, n, 1)
+ def loop(_):
+ qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4])
- loop()
- return qml.state()
+ loop()
+ return qml.state()
- print(decompose_multicontrolled_x4.mlir)
+ print(decompose_multicontrolled_x4.mlir)
test_decompose_multicontrolledx_in_for_loop()
@@ -181,29 +186,29 @@ def loop(_):
def test_decompose_rot():
"""Test decomposition of Rot gate."""
- with get_custom_device_without(1, {"Rot", "C(Rot)"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_decompose_rot
- def decompose_rot(phi: float, theta: float, omega: float):
- # CHECK-NOT: name = "Rot"
- # CHECK: [[phi:%.+]] = tensor.extract %arg0
- # CHECK-NOT: name = "Rot"
- # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]])
- # CHECK-NOT: name = "Rot"
- # CHECK: [[theta:%.+]] = tensor.extract %arg1
- # CHECK-NOT: name = "Rot"
- # CHECK: {{%.+}} = quantum.custom "RY"([[theta]])
- # CHECK-NOT: name = "Rot"
- # CHECK: [[omega:%.+]] = tensor.extract %arg2
- # CHECK-NOT: name = "Rot"
- # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]])
- # CHECK-NOT: name = "Rot"
- qml.Rot(phi, theta, omega, wires=0)
- return measure(wires=0)
-
- print(decompose_rot.mlir)
+ dev = get_custom_device_without(1, {"Rot", "C(Rot)"})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_decompose_rot
+ def decompose_rot(phi: float, theta: float, omega: float):
+ # CHECK-NOT: name = "Rot"
+ # CHECK: [[phi:%.+]] = tensor.extract %arg0
+ # CHECK-NOT: name = "Rot"
+ # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]])
+ # CHECK-NOT: name = "Rot"
+ # CHECK: [[theta:%.+]] = tensor.extract %arg1
+ # CHECK-NOT: name = "Rot"
+ # CHECK: {{%.+}} = quantum.custom "RY"([[theta]])
+ # CHECK-NOT: name = "Rot"
+ # CHECK: [[omega:%.+]] = tensor.extract %arg2
+ # CHECK-NOT: name = "Rot"
+ # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]])
+ # CHECK-NOT: name = "Rot"
+ qml.Rot(phi, theta, omega, wires=0)
+ return measure(wires=0)
+
+ print(decompose_rot.mlir)
test_decompose_rot()
@@ -211,21 +216,21 @@ def decompose_rot(phi: float, theta: float, omega: float):
def test_decompose_s():
"""Test decomposition of S gate."""
- with get_custom_device_without(1, {"S", "C(S)"}) as dev:
+ dev = get_custom_device_without(1, {"S", "C(S)"})
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_decompose_s
- def decompose_s():
- # CHECK-NOT: name="S"
- # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64
- # CHECK-NOT: name = "S"
- # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]])
- # CHECK-NOT: name = "S"
- qml.S(wires=0)
- return measure(wires=0)
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_decompose_s
+ def decompose_s():
+ # CHECK-NOT: name="S"
+ # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64
+ # CHECK-NOT: name = "S"
+ # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]])
+ # CHECK-NOT: name = "S"
+ qml.S(wires=0)
+ return measure(wires=0)
- print(decompose_s.mlir)
+ print(decompose_s.mlir)
test_decompose_s()
@@ -233,21 +238,21 @@ def decompose_s():
def test_decompose_qubitunitary():
"""Test decomposition of QubitUnitary"""
- with get_custom_device_without(1, {"QubitUnitary"}) as dev:
+ dev = get_custom_device_without(1, {"QubitUnitary"})
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_decompose_qubit_unitary
- def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):
- # CHECK-NOT: name = "QubitUnitary"
- # CHECK: quantum.custom "RZ"
- # CHECK: quantum.custom "RY"
- # CHECK: quantum.custom "RZ"
- # CHECK-NOT: name = "QubitUnitary"
- qml.QubitUnitary(U, wires=0)
- return measure(wires=0)
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_decompose_qubit_unitary
+ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):
+ # CHECK-NOT: name = "QubitUnitary"
+ # CHECK: quantum.custom "RZ"
+ # CHECK: quantum.custom "RY"
+ # CHECK: quantum.custom "RZ"
+ # CHECK-NOT: name = "QubitUnitary"
+ qml.QubitUnitary(U, wires=0)
+ return measure(wires=0)
- print(decompose_qubit_unitary.mlir)
+ print(decompose_qubit_unitary.mlir)
test_decompose_qubitunitary()
@@ -255,47 +260,47 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):
def test_decompose_singleexcitationplus():
"""Test decomposition of single excitation plus."""
- with get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_decompose_singleexcitationplus
- def decompose_singleexcitationplus(theta: float):
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00>
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX"
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX"
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]]
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0
- # CHECK-NOT: name = "SingleExcitationPlus"
- # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0
- # CHECK-NOT: name = "SingleExcitationPlus"
- qml.SingleExcitationPlus(theta, wires=[0, 1])
- return measure(wires=0)
-
- print(decompose_singleexcitationplus.mlir)
+ dev = get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_decompose_singleexcitationplus
+ def decompose_singleexcitationplus(theta: float):
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00>
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX"
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX"
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]]
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0
+ # CHECK-NOT: name = "SingleExcitationPlus"
+ qml.SingleExcitationPlus(theta, wires=[0, 1])
+ return measure(wires=0)
+
+ print(decompose_singleexcitationplus.mlir)
test_decompose_singleexcitationplus()
diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py
index e7125e7eb7..15cc876f93 100644
--- a/frontend/test/lit/test_quantum_control.py
+++ b/frontend/test/lit/test_quantum_control.py
@@ -15,13 +15,18 @@
# RUN: %PYTHON %s | FileCheck %s
""" Test the lowering cases involving quantum control """
-import os
-import tempfile
+from copy import deepcopy
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit
+from catalyst.utils.toml import (
+ OperationProperties,
+ ProgramFeatures,
+ get_device_capabilities,
+ pennylane_operation_set,
+)
def get_custom_qjit_device(num_wires, discards, additions):
@@ -37,70 +42,61 @@ class CustomDevice(qml.QubitDevice):
author = "Tester"
lightning_device = qml.device("lightning.qubit", wires=0)
- operations = lightning_device.operations.copy() - discards | additions
- observables = lightning_device.observables.copy()
- config = None
backend_name = "default"
backend_lib = "default"
backend_kwargs = {}
def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
- self.toml_file = None
+ program_features = ProgramFeatures(shots_present=shots is not None)
+ lightning_capabilities = get_device_capabilities(
+ self.lightning_device, program_features
+ )
+ custom_capabilities = deepcopy(lightning_capabilities)
+ for gate in discards:
+ custom_capabilities.native_ops.pop(gate)
+ custom_capabilities.native_ops.update(additions)
+ self.qjit_capabilities = custom_capabilities
+
+ @property
+ def operations(self):
+ """Get PennyLane operations."""
+ return (
+ pennylane_operation_set(self.qjit_capabilities.native_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops)
+ )
+
+ @property
+ def observables(self):
+ """Get PennyLane observables."""
+ return pennylane_operation_set(self.qjit_capabilities.native_obs)
def apply(self, operations, **kwargs):
"""Unused"""
raise RuntimeError("Only C/C++ interface is defined")
- def __enter__(self, *args, **kwargs):
- lightning_toml = self.lightning_device.config
- with open(lightning_toml, mode="r", encoding="UTF-8") as f:
- toml_contents = f.readlines()
-
- # TODO: update once schema 2 is merged
- updated_toml_contents = []
- for line in toml_contents:
- if any(f'"{gate}",' in line for gate in discards):
- continue
-
- updated_toml_contents.append(line)
- if "native = [" in line:
- for gate in additions:
- if not gate.startswith("C("):
- updated_toml_contents.append(f' "{gate}",\n')
-
- self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
- self.toml_file.writelines(updated_toml_contents)
- self.toml_file.close() # close for now without deleting
-
- self.config = self.toml_file.name
- return self
-
- def __exit__(self, *args, **kwargs):
- os.unlink(self.toml_file.name)
- self.config = None
-
return CustomDevice(wires=num_wires)
def test_named_controlled():
"""Test that named-controlled operations are passed as-is."""
- with get_custom_qjit_device(2, set(), set()) as dev:
+ dev = get_custom_qjit_device(2, set(), set())
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_named_controlled
- def named_controlled():
- # CHECK: quantum.custom "CNOT"
- qml.CNOT(wires=[0, 1])
- # CHECK: quantum.custom "CY"
- qml.CY(wires=[0, 1])
- # CHECK: quantum.custom "CZ"
- qml.CZ(wires=[0, 1])
- return qml.state()
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_named_controlled
+ def named_controlled():
+ # CHECK: quantum.custom "CNOT"
+ qml.CNOT(wires=[0, 1])
+ # CHECK: quantum.custom "CY"
+ qml.CY(wires=[0, 1])
+ # CHECK: quantum.custom "CZ"
+ qml.CZ(wires=[0, 1])
+ return qml.state()
- print(named_controlled.mlir)
+ print(named_controlled.mlir)
test_named_controlled()
@@ -108,19 +104,19 @@ def named_controlled():
def test_native_controlled_custom():
"""Test native control of a custom operation."""
- with get_custom_qjit_device(3, {"CRot"}, {"Rot", "C(Rot)"}) as dev:
+ dev = get_custom_qjit_device(3, set(), {"Rot": OperationProperties(True, True, False)})
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_native_controlled
- def native_controlled():
- # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot"
- # CHECK-SAME: ctrls
- # CHECK-SAME: ctrlvals(%true, %true)
- qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2])
- return qml.state()
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_native_controlled
+ def native_controlled():
+ # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot"
+ # CHECK-SAME: ctrls
+ # CHECK-SAME: ctrlvals(%true, %true)
+ qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2])
+ return qml.state()
- print(native_controlled.mlir)
+ print(native_controlled.mlir)
test_native_controlled_custom()
@@ -128,31 +124,31 @@ def native_controlled():
def test_native_controlled_unitary():
"""Test native control of the unitary operation."""
- with get_custom_qjit_device(4, set(), set()) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_native_controlled_unitary
- def native_controlled_unitary():
- # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary
- # CHECK-SAME: ctrls
- # CHECK-SAME: ctrlvals(%true, %true, %true)
- qml.ctrl(
- qml.QubitUnitary(
- jnp.array(
- [
- [0.70710678 + 0.0j, 0.70710678 + 0.0j],
- [0.70710678 + 0.0j, -0.70710678 + 0.0j],
- ],
- dtype=jnp.complex128,
- ),
- wires=[0],
+ dev = get_custom_qjit_device(4, set(), set())
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_native_controlled_unitary
+ def native_controlled_unitary():
+ # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary
+ # CHECK-SAME: ctrls
+ # CHECK-SAME: ctrlvals(%true, %true, %true)
+ qml.ctrl(
+ qml.QubitUnitary(
+ jnp.array(
+ [
+ [0.70710678 + 0.0j, 0.70710678 + 0.0j],
+ [0.70710678 + 0.0j, -0.70710678 + 0.0j],
+ ],
+ dtype=jnp.complex128,
),
- control=[1, 2, 3],
- )
- return qml.state()
+ wires=[0],
+ ),
+ control=[1, 2, 3],
+ )
+ return qml.state()
- print(native_controlled_unitary.mlir)
+ print(native_controlled_unitary.mlir)
test_native_controlled_unitary()
@@ -160,19 +156,19 @@ def native_controlled_unitary():
def test_native_controlled_multirz():
"""Test native control of the multirz operation."""
- with get_custom_qjit_device(3, set(), {"C(MultiRZ)"}) as dev:
-
- @qjit(target="mlir")
- @qml.qnode(dev)
- # CHECK-LABEL: public @jit_native_controlled_multirz
- def native_controlled_multirz():
- # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz
- # CHECK-SAME: ctrls
- # CHECK-SAME: ctrlvals(%true)
- qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1])
- return qml.state()
-
- print(native_controlled_multirz.mlir)
+ dev = get_custom_qjit_device(3, set(), {"MultiRZ": OperationProperties(True, True, True)})
+
+ @qjit(target="mlir")
+ @qml.qnode(dev)
+ # CHECK-LABEL: public @jit_native_controlled_multirz
+ def native_controlled_multirz():
+ # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz
+ # CHECK-SAME: ctrls
+ # CHECK-SAME: ctrlvals(%true)
+ qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1])
+ return qml.state()
+
+ print(native_controlled_multirz.mlir)
test_native_controlled_multirz()
diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py
index 1bc79225bc..50488da984 100644
--- a/frontend/test/pytest/test_config_functions.py
+++ b/frontend/test/pytest/test_config_functions.py
@@ -21,13 +21,9 @@
import pennylane as qml
import pytest
-from catalyst.device import QJITDevice
+from catalyst.device import QJITDevice, validate_device_capabilities
+from catalyst.device.qjit_device import check_no_overlap
from catalyst.utils.exceptions import CompileError
-from catalyst.utils.runtime import (
- check_no_overlap,
- get_device_capabilities,
- validate_config_with_device,
-)
from catalyst.utils.toml import (
DeviceCapabilities,
ProgramFeatures,
@@ -36,6 +32,7 @@
get_decomposable_gates,
get_matrix_decomposable_gates,
get_native_ops,
+ load_device_capabilities,
pennylane_operation_set,
read_toml_file,
)
@@ -76,33 +73,30 @@ def get_test_device_capabilities(
) -> DeviceCapabilities:
"""Parse test config into the DeviceCapabilities structure"""
config = get_test_config(config_text)
- device_capabilities = get_device_capabilities(config, program_features, "dummy")
+ device_capabilities = load_device_capabilities(config, program_features, "dummy")
return device_capabilities
@pytest.mark.parametrize("schema", ALL_SCHEMAS)
-def test_validate_config_with_device(schema):
+def test_config_qjit_incompatible_device(schema):
"""Test error is raised if checking for qjit compatibility and field is false in toml file."""
- with TemporaryDirectory() as d:
- toml_file = join(d, "test.toml")
- with open(toml_file, "w", encoding="utf-8") as f:
- f.write(
- dedent(
- f"""
- schema = {schema}
- [compilation]
- qjit_compatible = false
- """
- )
- )
+ device_capabilities = get_test_device_capabilities(
+ ProgramFeatures(False),
+ dedent(
+ f"""
+ schema = {schema}
+ [compilation]
+ qjit_compatible = false
+ """
+ ),
+ )
- config = read_toml_file(toml_file)
- name = DeviceToBeTested.name
- with pytest.raises(
- CompileError,
- match=f"Attempting to compile program for incompatible device '{name}'",
- ):
- validate_config_with_device(DeviceToBeTested(), config)
+ name = DeviceToBeTested.name
+ with pytest.raises(
+ CompileError,
+ match=f"Attempting to compile program for incompatible device '{name}'",
+ ):
+ validate_device_capabilities(DeviceToBeTested(), device_capabilities)
def test_get_observables_schema1():
@@ -384,7 +378,8 @@ def test_config_invalid_condition_duplicate(shots):
def test_config_qjit_device_operations():
"""Check the gate condition handling logic"""
- config = get_test_config(
+ capabilities = get_test_device_capabilities(
+ ProgramFeatures(False),
dedent(
r"""
schema = 2
@@ -395,7 +390,7 @@ def test_config_qjit_device_operations():
"""
),
)
- qjit_device = QJITDevice(config, shots=1000, wires=2)
+ qjit_device = QJITDevice(capabilities, shots=1000, wires=2)
assert "PauliX" in qjit_device.operations
assert "PauliY" in qjit_device.observables
diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py
index ddb385f83f..74e47bde0a 100644
--- a/frontend/test/pytest/test_custom_devices.py
+++ b/frontend/test/pytest/test_custom_devices.py
@@ -20,8 +20,9 @@
from catalyst import measure, qjit
from catalyst.compiler import get_lib_path
+from catalyst.device import extract_backend_info
from catalyst.utils.exceptions import CompileError
-from catalyst.utils.runtime import device_get_toml_config, extract_backend_info
+from catalyst.utils.toml import get_device_capabilities
# These have to match the ones in the configuration file.
OPERATIONS = [
@@ -165,8 +166,8 @@ def get_c_interface():
return "DummyDevice", lib_path
device = DummyDevice(wires=1)
- config = device_get_toml_config(device)
- backend_info = extract_backend_info(device, config)
+ capabilities = get_device_capabilities(device)
+ backend_info = extract_backend_info(device, capabilities)
assert backend_info.kwargs["option1"] == 42
assert "option2" not in backend_info.kwargs
diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py
index 371d9e4455..d9d70eb67e 100644
--- a/frontend/test/pytest/test_debug.py
+++ b/frontend/test/pytest/test_debug.py
@@ -22,7 +22,7 @@
from catalyst.compiler import CompileOptions, Compiler
from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage
from catalyst.utils.exceptions import CompileError
-from catalyst.utils.runtime import get_lib_path
+from catalyst.utils.runtime_environment import get_lib_path
class TestDebugPrint:
diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py
index 469c03d4a5..688e0bfc33 100644
--- a/frontend/test/pytest/test_decomposition.py
+++ b/frontend/test/pytest/test_decomposition.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import tempfile
+from copy import deepcopy
import pennylane as qml
import pytest
from jax import numpy as jnp
from catalyst import CompileError, ctrl, measure, qjit
+from catalyst.utils.toml import (
+ ProgramFeatures,
+ get_device_capabilities,
+ pennylane_operation_set,
+)
class CustomDevice(qml.QubitDevice):
@@ -32,75 +36,58 @@ class CustomDevice(qml.QubitDevice):
author = "Tester"
lightning_device = qml.device("lightning.qubit", wires=0)
- operations = lightning_device.operations.copy() - {
- "MultiControlledX",
- "Rot",
- "S",
- "C(Rot)",
- "C(S)",
- }
- observables = lightning_device.observables.copy()
-
- config = None
+
backend_name = "default"
backend_lib = "default"
backend_kwargs = {}
def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
- self.toml_file = None
+ program_features = ProgramFeatures(shots_present=self.shots is not None)
+ lightning_capabilities = get_device_capabilities(self.lightning_device, program_features)
+ custom_capabilities = deepcopy(lightning_capabilities)
+ custom_capabilities.native_ops.pop("Rot")
+ custom_capabilities.native_ops.pop("S")
+ custom_capabilities.to_decomp_ops.pop("MultiControlledX")
+ self.qjit_capabilities = custom_capabilities
def apply(self, operations, **kwargs):
"""Unused"""
raise RuntimeError("Only C/C++ interface is defined")
- def __enter__(self, *args, **kwargs):
- lightning_toml = self.lightning_device.config
- with open(lightning_toml, mode="r", encoding="UTF-8") as f:
- toml_contents = f.readlines()
-
- # TODO: update once schema 2 is merged
- updated_toml_contents = []
- for line in toml_contents:
- if '"MultiControlledX",' in line:
- continue
- if '"Rot",' in line:
- continue
- if '"S",' in line:
- continue
-
- updated_toml_contents.append(line)
-
- self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
- self.toml_file.writelines(updated_toml_contents)
- self.toml_file.close() # close for now without deleting
+ @property
+ def operations(self):
+ """Get PennyLane operations."""
+ return (
+ pennylane_operation_set(self.qjit_capabilities.native_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops)
+ | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops)
+ )
- self.config = self.toml_file.name
- return self
-
- def __exit__(self, *args, **kwargs):
- os.unlink(self.toml_file.name)
- self.config = None
+ @property
+ def observables(self):
+ """Get PennyLane observables."""
+ return pennylane_operation_set(self.qjit_capabilities.native_obs)
@pytest.mark.parametrize("param,expected", [(0.0, True), (jnp.pi, False)])
def test_decomposition(param, expected):
- with CustomDevice(wires=2) as dev:
-
- @qjit
- @qml.qnode(dev)
- def mid_circuit(x: float):
- qml.Hadamard(wires=0)
- qml.Rot(0, 0, x, wires=0)
- qml.Hadamard(wires=0)
- m = measure(wires=0)
- b = m ^ 0x1
- qml.Hadamard(wires=1)
- qml.Rot(0, 0, b * jnp.pi, wires=1)
- qml.Hadamard(wires=1)
- return measure(wires=1)
-
- assert mid_circuit(param) == expected
+ dev = CustomDevice(wires=2)
+
+ @qjit
+ @qml.qnode(dev)
+ def mid_circuit(x: float):
+ qml.Hadamard(wires=0)
+ qml.Rot(0, 0, x, wires=0)
+ qml.Hadamard(wires=0)
+ m = measure(wires=0)
+ b = m ^ 0x1
+ qml.Hadamard(wires=1)
+ qml.Rot(0, 0, b * jnp.pi, wires=1)
+ qml.Hadamard(wires=1)
+ return measure(wires=1)
+
+ assert mid_circuit(param) == expected
class TestControlledDecomposition:
diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py
index 11eee25a7a..12b8a43010 100644
--- a/frontend/test/pytest/test_device_api.py
+++ b/frontend/test/pytest/test_device_api.py
@@ -25,9 +25,9 @@
from catalyst import qjit
from catalyst.compiler import get_lib_path
-from catalyst.device import QJITDeviceNewAPI
+from catalyst.device import QJITDeviceNewAPI, extract_backend_info
from catalyst.tracing.contexts import EvaluationContext, EvaluationMode
-from catalyst.utils.runtime import device_get_toml_config, extract_backend_info
+from catalyst.utils.toml import ProgramFeatures, get_device_capabilities
class DummyDevice(Device):
@@ -87,9 +87,9 @@ def test_qjit_device():
device = DummyDevice(wires=10, shots=2032)
# Create qjit device
- config = device_get_toml_config(device)
- backend_info = extract_backend_info(device, config)
- device_qjit = QJITDeviceNewAPI(device, backend_info)
+ capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None))
+ backend_info = extract_backend_info(device, capabilities)
+ device_qjit = QJITDeviceNewAPI(device, capabilities, backend_info)
# Check attributes of the new device
assert device_qjit.shots == qml.measurements.Shots(2032)
@@ -123,13 +123,13 @@ def test_qjit_device_no_wires():
device = DummyDeviceNoWires(shots=2032)
# Create qjit device
- config = device_get_toml_config(device)
- backend_info = extract_backend_info(device, config)
+ capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None))
+ backend_info = extract_backend_info(device, capabilities)
with pytest.raises(
AttributeError, match="Catalyst does not support devices without set wires."
):
- QJITDeviceNewAPI(device, backend_info)
+ QJITDeviceNewAPI(device, capabilities, backend_info)
def test_simple_circuit():
diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py
index 9fc5bace89..df7bd4b822 100644
--- a/frontend/test/pytest/test_preprocess.py
+++ b/frontend/test/pytest/test_preprocess.py
@@ -412,7 +412,6 @@ def test_decomposition_of_cond_circuit(self):
@qml.qjit
@qml.qnode(dev)
def circuit(phi: float):
-
OtherHadamard(wires=0)
# define a conditional ansatz
@@ -461,7 +460,6 @@ def test_decomposition_of_forloop_circuit(self, reps, angle):
@qml.qjit
@qml.qnode(dev)
def circuit(n: int, x: float):
-
OtherHadamard(wires=0)
def loop_rx(i, phi):
@@ -495,7 +493,6 @@ def test_decomposition_of_whileloop_circuit(self, phi):
@qml.qjit
@qml.qnode(dev)
def circuit(x: float):
-
@while_loop(lambda x: x < 2.0)
def loop_rx(x):
# perform some work and update (some of) the arguments