diff --git a/doc/changelog.md b/doc/changelog.md index db3069494f..af5c2dcde7 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -92,11 +92,14 @@ annotations. [(#751)](https://github.com/PennyLaneAI/catalyst/pull/751) -* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable - class that implements the decorator's logic. This prevents having to excessively define +* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable + class that implements the decorator's logic. This prevents having to excessively define functions in a nested fashion. [(#758)](https://github.com/PennyLaneAI/catalyst/pull/758) +* Catalyst tests now manipulate device capabilities rather than text configurations files. + [(#712)](https://github.com/PennyLaneAI/catalyst/pull/712) +

Breaking changes

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`. @@ -198,11 +201,12 @@ This release contains contributions from (in alphabetical order): David Ittah, -Mehrdad Malekmohammadi, Erick Ochoa, +Haochen Paul Wang, Lee James O'Riordan, +Mehrdad Malekmohammadi, Raul Torres, -Haochen Paul Wang. +Sergei Mironov. # Release 0.6.0 diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index afcbb6429a..89792c0af7 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -33,7 +33,7 @@ from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import Directory -from catalyst.utils.runtime import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path package_root = os.path.dirname(__file__) diff --git a/frontend/catalyst/device/__init__.py b/frontend/catalyst/device/__init__.py index 6b9c1bd55c..a830cf7ec4 100644 --- a/frontend/catalyst/device/__init__.py +++ b/frontend/catalyst/device/__init__.py @@ -16,9 +16,18 @@ Internal API for the device module. """ -from catalyst.device.qjit_device import QJITDevice, QJITDeviceNewAPI +from catalyst.device.qjit_device import ( + BackendInfo, + QJITDevice, + QJITDeviceNewAPI, + extract_backend_info, + validate_device_capabilities, +) __all__ = ( "QJITDevice", "QJITDeviceNewAPI", + "BackendInfo", + "extract_backend_info", + "validate_device_capabilities", ) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index c99f5db9c0..24be391cdc 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -16,10 +16,14 @@ This module contains device stubs for the old and new PennyLane device API, which facilitate the application of decomposition and other device pre-processing routines. """ - +import os +import pathlib +import platform +import re from copy import deepcopy +from dataclasses import dataclass from functools import partial -from typing import Optional, Set +from typing import Any, Dict, Optional, Set import pennylane as qml from pennylane.measurements import MidMeasureMP @@ -32,13 +36,10 @@ ) from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import BackendInfo, device_get_toml_config +from catalyst.utils.runtime_environment import get_lib_path from catalyst.utils.toml import ( DeviceCapabilities, OperationProperties, - ProgramFeatures, - TOMLDocument, - get_device_capabilities, intersect_operations, pennylane_operation_set, ) @@ -84,6 +85,87 @@ for op in RUNTIME_OPERATIONS } +# TODO: This should be removed after implementing `get_c_interface` +# for the following backend devices: +SUPPORTED_RT_DEVICES = { + "lightning.qubit": ("LightningSimulator", "librtd_lightning"), + "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"), + "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"), + "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"), +} + + +@dataclass +class BackendInfo: + """Backend information""" + + device_name: str + c_interface_name: str + lpath: str + kwargs: Dict[str, Any] + + +def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo: + """Extract the backend info from a quantum device. The device is expected to carry a reference + to a valid TOML config file.""" + # pylint: disable=too-many-branches + + dname = device.name + if isinstance(device, qml.Device): + dname = device.short_name + + device_name = "" + device_lpath = "" + device_kwargs = {} + + if dname in SUPPORTED_RT_DEVICES: + # Support backend devices without `get_c_interface` + device_name = SUPPORTED_RT_DEVICES[dname][0] + device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR") + sys_platform = platform.system() + + if sys_platform == "Linux": + device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so") + elif sys_platform == "Darwin": # pragma: no cover + device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib") + else: # pragma: no cover + raise NotImplementedError(f"Platform not supported: {sys_platform}") + elif hasattr(device, "get_c_interface"): + # Support third party devices with `get_c_interface` + device_name, device_lpath = device.get_c_interface() + else: + raise CompileError(f"The {dname} device does not provide C interface for compilation.") + + if not pathlib.Path(device_lpath).is_file(): + raise CompileError(f"Device at {device_lpath} cannot be found!") + + if hasattr(device, "shots"): + if isinstance(device, qml.Device): + device_kwargs["shots"] = device.shots if device.shots else 0 + else: + # TODO: support shot vectors + device_kwargs["shots"] = device.shots.total_shots if device.shots else 0 + + if dname == "braket.local.qubit": # pragma: no cover + device_kwargs["device_type"] = dname + device_kwargs["backend"] = ( + # pylint: disable=protected-access + device._device._delegate.DEVICE_ID + ) + elif dname == "braket.aws.qubit": # pragma: no cover + device_kwargs["device_type"] = dname + device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access + if device._s3_folder: # pylint: disable=protected-access + device_kwargs["s3_destination_folder"] = str( + device._s3_folder # pylint: disable=protected-access + ) + + for k, v in capabilities.options.items(): + if hasattr(device, v): + device_kwargs[k] = getattr(device, v) + + return BackendInfo(dname, device_name, device_lpath, device_kwargs) + def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]: """Calculate the set of supported quantum gates for the QJIT device from the gates @@ -165,7 +247,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S def __init__( self, - target_config: TOMLDocument, + original_device_capabilities: DeviceCapabilities, shots=None, wires=None, backend: Optional[BackendInfo] = None, @@ -175,23 +257,18 @@ def __init__( self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} - device_name = backend.device_name if backend else "default" - program_features = ProgramFeatures(shots is not None) - target_device_capabilities = get_device_capabilities( - target_config, program_features, device_name - ) - self.capabilities = get_qjit_device_capabilities(target_device_capabilities) + self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities) @property def operations(self) -> Set[str]: """Get the device operations using PennyLane's syntax""" - return pennylane_operation_set(self.capabilities.native_ops) + return pennylane_operation_set(self.qjit_capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return pennylane_operation_set(self.capabilities.native_obs) + return pennylane_operation_set(self.qjit_capabilities.native_obs) def apply(self, operations, **kwargs): """ @@ -270,6 +347,7 @@ class QJITDeviceNewAPI(qml.devices.Device): def __init__( self, original_device, + original_device_capabilities: DeviceCapabilities, backend: Optional[BackendInfo] = None, ): self.original_device = original_device @@ -285,29 +363,23 @@ def __init__( self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} - device_name = backend.device_name if backend else "default" - target_config = device_get_toml_config(original_device) - program_features = ProgramFeatures(original_device.shots is not None) - target_device_capabilities = get_device_capabilities( - target_config, program_features, device_name - ) - self.capabilities = get_qjit_device_capabilities(target_device_capabilities) + self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities) @property def operations(self) -> Set[str]: """Get the device operations""" - return pennylane_operation_set(self.capabilities.native_ops) + return pennylane_operation_set(self.qjit_capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return pennylane_operation_set(self.capabilities.native_obs) + return pennylane_operation_set(self.qjit_capabilities.native_obs) @property def measurement_processes(self) -> Set[str]: """Get the device measurement processes""" - return self.capabilities.measurement_processes + return self.qjit_capabilities.measurement_processes def preprocess( self, @@ -334,3 +406,101 @@ def execute(self, circuits, execution_config): Raises: RuntimeError """ raise RuntimeError("QJIT devices cannot execute tapes.") + + +def filter_out_adjoint(operations): + """Remove Adjoint from operations. + + Args: + operations (List[Str]): List of strings with names of supported operations + + Returns: + List: A list of strings with names of supported operations with Adjoint and C gates + removed. + """ + adjoint = re.compile(r"^Adjoint\(.*\)$") + + def is_not_adj(op): + return not re.match(adjoint, op) + + operations_no_adj = filter(is_not_adj, operations) + return set(operations_no_adj) + + +def check_no_overlap(*args, device_name): + """Check items in *args are mutually exclusive. + + Args: + *args (List[Str]): List of strings. + device_name (str): Device name for error reporting. + + Raises: + CompileError + """ + set_of_sets = [set(arg) for arg in args] + union = set.union(*set_of_sets) + len_of_sets = [len(arg) for arg in args] + if sum(len_of_sets) == len(union): + return + + overlaps = set() + for s in set_of_sets: + overlaps.update(s - union) + union = union - s + + msg = f"Device '{device_name}' has overlapping gates: {overlaps}" + raise CompileError(msg) + + +def validate_device_capabilities( + device: qml.QubitDevice, device_capabilities: DeviceCapabilities +) -> None: + """Validate configuration document against the device attributes. + Raise CompileError in case of mismatch: + * If device is not qjit-compatible. + * If configuration file does not exists. + * If decomposable, matrix, and native gates have some overlap. + * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and + ``device.observables``. + + Args: + device (qml.Device): An instance of a quantum device. + config (TOMLDocument): A TOML document representation. + + Raises: CompileError + """ + + if not device_capabilities.qjit_compatible_flag: + raise CompileError( + f"Attempting to compile program for incompatible device '{device.name}': " + f"Config is not marked as qjit-compatible" + ) + + device_name = device.short_name if isinstance(device, qml.Device) else device.name + + native = pennylane_operation_set(device_capabilities.native_ops) + decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) + matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) + + check_no_overlap(native, decomposable, matrix, device_name=device_name) + + if hasattr(device, "operations") and hasattr(device, "observables"): + # For gates, we require strict match + device_gates = filter_out_adjoint(set(device.operations)) + spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) + if device_gates != spec_gates: + raise CompileError( + "Gates in qml.device.operations and specification file do not match.\n" + f"Gates that present only in the device: {device_gates - spec_gates}\n" + f"Gates that present only in spec: {spec_gates - device_gates}\n" + ) + + # For observables, we do not have `non-native` section in the config, so we check that + # device data supercedes the specification. + device_observables = set(device.observables) + spec_observables = pennylane_operation_set(device_capabilities.native_obs) + if (spec_observables - device_observables) != set(): + raise CompileError( + "Observables in qml.device.observables and specification file do not match.\n" + f"Observables that present only in spec: {spec_observables - device_observables}\n" + ) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index b09bc9f317..045d947c44 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -22,7 +22,13 @@ from jax.core import eval_jaxpr from jax.tree_util import tree_flatten, tree_unflatten -from catalyst.device import QJITDevice, QJITDeviceNewAPI +from catalyst.device import ( + BackendInfo, + QJITDevice, + QJITDeviceNewAPI, + extract_backend_info, + validate_device_capabilities, +) from catalyst.jax_extras import ( deduce_avals, get_implicit_and_explicit_flat_args, @@ -30,13 +36,11 @@ ) from catalyst.jax_primitives import func_p from catalyst.jax_tracer import trace_quantum_function -from catalyst.utils.runtime import ( - BackendInfo, - device_get_toml_config, - extract_backend_info, - validate_config_with_device, +from catalyst.utils.toml import ( + DeviceCapabilities, + ProgramFeatures, + get_device_capabilities, ) -from catalyst.utils.toml import TOMLDocument class QFunc: @@ -54,26 +58,35 @@ def __new__(cls): raise NotImplementedError() # pragma: no-cover @staticmethod - def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo: + def extract_backend_info( + device: qml.QubitDevice, capabilities: DeviceCapabilities + ) -> BackendInfo: """Wrapper around extract_backend_info in the runtime module.""" - return extract_backend_info(device, config) + return extract_backend_info(device, capabilities) - # pylint: disable=no-member + # pylint: disable=no-member, attribute-defined-outside-init def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) - config = device_get_toml_config(self.device) - validate_config_with_device(self.device, config) - backend_info = QFunc.extract_backend_info(self.device, config) + # TODO: Move the capability loading and validation to the device constructor when the + # support for old device api is dropped. + program_features = ProgramFeatures(self.device.shots is not None) + device_capabilities = get_device_capabilities(self.device, program_features) + backend_info = QFunc.extract_backend_info(self.device, device_capabilities) + + # Validate decive operations against the declared capabilities + validate_device_capabilities(self.device, device_capabilities) if isinstance(self.device, qml.devices.Device): - device = QJITDeviceNewAPI(self.device, backend_info) + qjit_device = QJITDeviceNewAPI(self.device, device_capabilities, backend_info) else: - device = QJITDevice(config, self.device.shots, self.device.wires, backend_info) + qjit_device = QJITDevice( + device_capabilities, self.device.shots, self.device.wires, backend_info + ) def _eval_quantum(*args): closed_jaxpr, out_type, out_tree = trace_quantum_function( - self.func, device, args, kwargs, qnode=self + self.func, qjit_device, args, kwargs, qnode=self ) args_expanded = get_implicit_and_explicit_flat_args(None, *args) res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py index 7201f1dfea..c53a6d0c55 100644 --- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py @@ -39,6 +39,7 @@ import pennylane as qml from jax.tree_util import tree_unflatten +from catalyst.device import BackendInfo from catalyst.jax_primitives import ( AbstractObs, adjoint_p, @@ -75,7 +76,6 @@ from catalyst.qfunc import QFunc from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import BackendInfo from .primitives import ( cuda_inst, @@ -820,7 +820,7 @@ def get_jaxpr(self, *args): an MLIR module """ - def cudaq_backend_info(device, _config) -> BackendInfo: + def cudaq_backend_info(device, _capabilities) -> BackendInfo: """The extract_backend_info should not be run by the cuda compiler as it is catalyst-specific. We need to make this API a bit nicer for third-party compilers. """ diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py deleted file mode 100644 index 1fb3aed313..0000000000 --- a/frontend/catalyst/utils/runtime.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Runtime utility methods. -""" - -# pylint: disable=too-many-branches - -import os -import pathlib -import platform -import re -from dataclasses import dataclass -from typing import Any, Dict - -import pennylane as qml - -from catalyst._configuration import INSTALLED -from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import ( - ProgramFeatures, - TOMLDocument, - get_device_capabilities, - pennylane_operation_set, - read_toml_file, -) - -package_root = os.path.dirname(__file__) - - -# Default paths to dep libraries -DEFAULT_LIB_PATHS = { - "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), - "runtime": os.path.join(package_root, "../../../runtime/build/lib"), - "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), - "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"), -} - - -# TODO: This should be removed after implementing `get_c_interface` -# for the following backend devices: -SUPPORTED_RT_DEVICES = { - "lightning.qubit": ("LightningSimulator", "librtd_lightning"), - "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"), - "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"), - "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"), -} - - -def get_lib_path(project, env_var): - """Get the library path.""" - if INSTALLED: - return os.path.join(package_root, "..", "lib") # pragma: no cover - return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) - - -def check_no_overlap(*args, device_name): - """Check items in *args are mutually exclusive. - - Args: - *args (List[Str]): List of strings. - device_name (str): Device name for error reporting. - - Raises: - CompileError - """ - set_of_sets = [set(arg) for arg in args] - union = set.union(*set_of_sets) - len_of_sets = [len(arg) for arg in args] - if sum(len_of_sets) == len(union): - return - - overlaps = set() - for s in set_of_sets: - overlaps.update(s - union) - union = union - s - - msg = f"Device '{device_name}' has overlapping gates: {overlaps}" - raise CompileError(msg) - - -def filter_out_adjoint(operations): - """Remove Adjoint from operations. - - Args: - operations (List[Str]): List of strings with names of supported operations - - Returns: - List: A list of strings with names of supported operations with Adjoint and C gates - removed. - """ - adjoint = re.compile(r"^Adjoint\(.*\)$") - - def is_not_adj(op): - return not re.match(adjoint, op) - - operations_no_adj = filter(is_not_adj, operations) - return set(operations_no_adj) - - -def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -> None: - """Validate configuration document against the device attributes. - Raise CompileError in case of mismatch: - * If device is not qjit-compatible. - * If configuration file does not exists. - * If decomposable, matrix, and native gates have some overlap. - * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and - ``device.observables``. - - Args: - device (qml.Device): An instance of a quantum device. - config (TOMLDocument): A TOML document representation. - - Raises: CompileError - """ - - if not config["compilation"]["qjit_compatible"]: - raise CompileError( - f"Attempting to compile program for incompatible device '{device.name}': " - f"Config is not marked as qjit-compatible" - ) - - device_name = device.short_name if isinstance(device, qml.Device) else device.name - program_features = ProgramFeatures(device.shots is not None) - device_capabilities = get_device_capabilities(config, program_features, device_name) - - native = pennylane_operation_set(device_capabilities.native_ops) - decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) - matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) - - check_no_overlap(native, decomposable, matrix, device_name=device_name) - - if hasattr(device, "operations") and hasattr(device, "observables"): - # For gates, we require strict match - device_gates = filter_out_adjoint(set(device.operations)) - spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) - if device_gates != spec_gates: - raise CompileError( - "Gates in qml.device.operations and specification file do not match.\n" - f"Gates that present only in the device: {device_gates - spec_gates}\n" - f"Gates that present only in spec: {spec_gates - device_gates}\n" - ) - - # For observables, we do not have `non-native` section in the config, so we check that - # device data supercedes the specification. - device_observables = set(device.observables) - spec_observables = pennylane_operation_set(device_capabilities.native_obs) - if (spec_observables - device_observables) != set(): - raise CompileError( - "Observables in qml.device.observables and specification file do not match.\n" - f"Observables that present only in spec: {spec_observables - device_observables}\n" - ) - - -def device_get_toml_config(device) -> TOMLDocument: - """Get the contents of the device config file.""" - if hasattr(device, "config"): - # The expected case: device specifies its own config. - toml_file = device.config - else: - # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file - # field. - device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) - - name = device.short_name if isinstance(device, qml.Device) else device.name - # The toml files name convention we follow is to replace - # the dots with underscores in the device short name. - toml_file_name = name.replace(".", "_") + ".toml" - # And they are currently saved in the following directory. - toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name - - try: - config = read_toml_file(toml_file) - except FileNotFoundError as e: - raise CompileError( - "Attempting to compile program for incompatible device: " - f"Config file ({toml_file}) does not exist" - ) from e - - return config - - -@dataclass -class BackendInfo: - """Backend information""" - - device_name: str - c_interface_name: str - lpath: str - kwargs: Dict[str, Any] - - -def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo: - """Extract the backend info from a quantum device. The device is expected to carry a reference - to a valid TOML config file.""" - - dname = device.name - if isinstance(device, qml.Device): - dname = device.short_name - - device_name = "" - device_lpath = "" - device_kwargs = {} - - if dname in SUPPORTED_RT_DEVICES: - # Support backend devices without `get_c_interface` - device_name = SUPPORTED_RT_DEVICES[dname][0] - device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR") - sys_platform = platform.system() - - if sys_platform == "Linux": - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so") - elif sys_platform == "Darwin": # pragma: no cover - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib") - else: # pragma: no cover - raise NotImplementedError(f"Platform not supported: {sys_platform}") - elif hasattr(device, "get_c_interface"): - # Support third party devices with `get_c_interface` - device_name, device_lpath = device.get_c_interface() - else: - raise CompileError(f"The {dname} device does not provide C interface for compilation.") - - if not pathlib.Path(device_lpath).is_file(): - raise CompileError(f"Device at {device_lpath} cannot be found!") - - if hasattr(device, "shots"): - if isinstance(device, qml.Device): - device_kwargs["shots"] = device.shots if device.shots else 0 - else: - # TODO: support shot vectors - device_kwargs["shots"] = device.shots.total_shots if device.shots else 0 - - if dname == "braket.local.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["backend"] = ( - # pylint: disable=protected-access - device._device._delegate.DEVICE_ID - ) - elif dname == "braket.aws.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access - if device._s3_folder: # pylint: disable=protected-access - device_kwargs["s3_destination_folder"] = str( - device._s3_folder # pylint: disable=protected-access - ) - - options = config.get("options", {}) - for k, v in options.items(): - if hasattr(device, v): - device_kwargs[k] = getattr(device, v) - - return BackendInfo(dname, device_name, device_lpath, device_kwargs) diff --git a/frontend/catalyst/utils/runtime_environment.py b/frontend/catalyst/utils/runtime_environment.py new file mode 100644 index 0000000000..a75b466809 --- /dev/null +++ b/frontend/catalyst/utils/runtime_environment.py @@ -0,0 +1,38 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility code for keeping paths +""" +import os +import os.path + +from catalyst._configuration import INSTALLED + +package_root = os.path.dirname(__file__) + +# Default paths to dep libraries +DEFAULT_LIB_PATHS = { + "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), + "runtime": os.path.join(package_root, "../../../runtime/build/lib"), + "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), + "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"), +} + + +def get_lib_path(project, env_var): + """Get the library path.""" + if INSTALLED: + return os.path.join(package_root, "..", "lib") # pragma: no cover + return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index 8044962889..1806d23c2c 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -16,12 +16,16 @@ """ import importlib.util +import pathlib from dataclasses import dataclass from functools import reduce from itertools import repeat -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Optional, Set + +import pennylane as qml from catalyst.utils.exceptions import CompileError +from catalyst.utils.runtime_environment import get_lib_path # TODO: # Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit @@ -33,8 +37,7 @@ tomlkit = importlib.util.find_spec("tomlkit") # We need at least one of these to make sure we can read toml files. if tomllib is None and tomlkit is None: # pragma: nocover - msg = "Either tomllib or tomlkit need to be installed." - raise ImportError(msg) + raise ImportError("Either tomllib or tomlkit need to be installed.") # Give preference to tomllib if tomllib: # pragma: nocover @@ -82,9 +85,11 @@ class DeviceCapabilities: # pylint: disable=too-many-instance-attributes to_matrix_ops: Dict[str, OperationProperties] native_obs: Dict[str, OperationProperties] measurement_processes: Set[str] + qjit_compatible_flag: bool mid_circuit_measurement_flag: bool runtime_code_generation_flag: bool dynamic_qubit_management_flag: bool + options: Dict[str, bool] def intersect_operations( @@ -112,11 +117,16 @@ class ProgramFeatures: shots_present: bool -def check_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: - """Checks the flag in the toml document 'compilation' section.""" +def get_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: + """Get the flag in the toml document 'compilation' section.""" return bool(config.get("compilation", {}).get(flag_name, False)) +def get_options(config: TOMLDocument) -> Dict[str, str]: + """Get custom options sections""" + return {str(k): str(v) for k, v in config.get("options", {}).items()} + + def check_quantum_control_flag(config: TOMLDocument) -> bool: """Check the control flag. Only exists in toml config schema 1""" schema = int(config["schema"]) @@ -300,7 +310,7 @@ def patch_schema1_collections( for op, props in native_gate_props.items(): props.controllable = op not in gates_to_be_decomposed_if_controlled - supports_adjoint = check_compilation_flag(config, "quantum_adjoint") + supports_adjoint = get_compilation_flag(config, "quantum_adjoint") if supports_adjoint: # Makr all gates as invertibles for props in native_gate_props.values(): @@ -326,10 +336,54 @@ def patch_schema1_collections( decomp_props.pop("ControlledPhaseShift") +def get_device_toml_config(device) -> TOMLDocument: + """Get the contents of the device config file.""" + if hasattr(device, "config"): + # The expected case: device specifies its own config. + toml_file = device.config + else: + # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file + # field. + device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) + + name = device.short_name if isinstance(device, qml.Device) else device.name + # The toml files name convention we follow is to replace + # the dots with underscores in the device short name. + toml_file_name = name.replace(".", "_") + ".toml" + # And they are currently saved in the following directory. + toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name + + try: + config = read_toml_file(toml_file) + except FileNotFoundError as e: + raise CompileError( + "Attempting to compile program for incompatible device: " + f"Config file ({toml_file}) does not exist" + ) from e + + return config + + def get_device_capabilities( + device, program_features: Optional[ProgramFeatures] = None +) -> DeviceCapabilities: + """Get or load DeviceCapabilities structure from device""" + + if hasattr(device, "qjit_capabilities"): + return device.qjit_capabilities + else: + program_features = ( + program_features if program_features else ProgramFeatures(device.shots is not None) + ) + device_name = device.short_name if isinstance(device, qml.Device) else device.name + device_config = get_device_toml_config(device) + return load_device_capabilities(device_config, program_features, device_name) + + +def load_device_capabilities( config: TOMLDocument, program_features: ProgramFeatures, device_name: str ) -> DeviceCapabilities: - """Load TOML document into the DeviceCapabilities structure""" + """Load device capabilities from device config""" schema = int(config["schema"]) @@ -369,7 +423,9 @@ def get_device_capabilities( to_matrix_ops=matrix_decomp_props, native_obs=observable_props, measurement_processes=measurements_props, - mid_circuit_measurement_flag=check_compilation_flag(config, "mid_circuit_measurement"), - runtime_code_generation_flag=check_compilation_flag(config, "runtime_code_generation"), - dynamic_qubit_management_flag=check_compilation_flag(config, "dynamic_qubit_management"), + qjit_compatible_flag=get_compilation_flag(config, "qjit_compatible"), + mid_circuit_measurement_flag=get_compilation_flag(config, "mid_circuit_measurement"), + runtime_code_generation_flag=get_compilation_flag(config, "runtime_code_generation"), + dynamic_qubit_management_flag=get_compilation_flag(config, "dynamic_qubit_management"), + options=get_options(config), ) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index c0375a4356..2c1c0ce57b 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -15,13 +15,17 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long -import os -import tempfile +from copy import deepcopy import jax import pennylane as qml from catalyst import cond, for_loop, measure, qjit, while_loop +from catalyst.utils.toml import ( + ProgramFeatures, + get_device_capabilities, + pennylane_operation_set, +) def get_custom_device_without(num_wires, discards): @@ -37,8 +41,6 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - discards - observables = lightning_device.observables.copy() config = None backend_name = "default" @@ -47,54 +49,57 @@ class CustomDevice(qml.QubitDevice): def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=self.shots is not None) + lightning_capabilities = get_device_capabilities( + self.lightning_device, program_features + ) + custom_capabilities = deepcopy(lightning_capabilities) + for gate in discards: + if gate in custom_capabilities.native_ops: + custom_capabilities.native_ops.pop(gate) + if gate in custom_capabilities.to_decomp_ops: + custom_capabilities.to_decomp_ops.pop(gate) + if gate in custom_capabilities.to_matrix_ops: + custom_capabilities.to_matrix_ops.pop(gate) + self.qjit_capabilities = custom_capabilities def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() + @property + def operations(self): + """Return operations using PennyLane's C(.) syntax""" + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) - # TODO: update once schema 2 is merged - updated_toml_contents = [] - for line in toml_contents: - if any(f'"{gate}",' in line for gate in discards): - continue - updated_toml_contents.append(line) - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting - - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None + @property + def observables(self): + """Return PennyLane observables""" + return pennylane_operation_set(self.qjit_capabilities.native_obs) return CustomDevice(wires=num_wires) def test_decompose_multicontrolledx(): """Test decomposition of MultiControlledX.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_multicontrolled_x1 - def decompose_multicontrolled_x1(theta: float): - qml.RX(theta, wires=[0]) - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - return qml.state() + dev = get_custom_device_without(5, {"MultiControlledX"}) - print(decompose_multicontrolled_x1.mlir) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_multicontrolled_x1 + def decompose_multicontrolled_x1(theta: float): + qml.RX(theta, wires=[0]) + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + return qml.state() + + print(decompose_multicontrolled_x1.mlir) test_decompose_multicontrolledx() @@ -102,25 +107,25 @@ def decompose_multicontrolled_x1(theta: float): def test_decompose_multicontrolledx_in_conditional(): """Test decomposition of MultiControlledX in conditional.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x2 - def decompose_multicontrolled_x2(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @cond(n > 1) - def cond_fn(): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x2 + def decompose_multicontrolled_x2(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @cond(n > 1) + def cond_fn(): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - cond_fn() - return qml.state() + cond_fn() + return qml.state() - print(decompose_multicontrolled_x2.mlir) + print(decompose_multicontrolled_x2.mlir) test_decompose_multicontrolledx_in_conditional() @@ -128,26 +133,26 @@ def cond_fn(): def test_decompose_multicontrolledx_in_while_loop(): """Test decomposition of MultiControlledX in while loop.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x3 - def decompose_multicontrolled_x3(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @while_loop(lambda v: v[0] < 10) - def loop(v): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - return v[0] + 1, v[1] + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x3 + def decompose_multicontrolled_x3(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @while_loop(lambda v: v[0] < 10) + def loop(v): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + return v[0] + 1, v[1] - loop((0, n)) - return qml.state() + loop((0, n)) + return qml.state() - print(decompose_multicontrolled_x3.mlir) + print(decompose_multicontrolled_x3.mlir) test_decompose_multicontrolledx_in_while_loop() @@ -155,25 +160,25 @@ def loop(v): def test_decompose_multicontrolledx_in_for_loop(): """Test decomposition of MultiControlledX in for loop.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x4 - def decompose_multicontrolled_x4(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @for_loop(0, n, 1) - def loop(_): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x4 + def decompose_multicontrolled_x4(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @for_loop(0, n, 1) + def loop(_): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - loop() - return qml.state() + loop() + return qml.state() - print(decompose_multicontrolled_x4.mlir) + print(decompose_multicontrolled_x4.mlir) test_decompose_multicontrolledx_in_for_loop() @@ -181,29 +186,29 @@ def loop(_): def test_decompose_rot(): """Test decomposition of Rot gate.""" - with get_custom_device_without(1, {"Rot", "C(Rot)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_rot - def decompose_rot(phi: float, theta: float, omega: float): - # CHECK-NOT: name = "Rot" - # CHECK: [[phi:%.+]] = tensor.extract %arg0 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]]) - # CHECK-NOT: name = "Rot" - # CHECK: [[theta:%.+]] = tensor.extract %arg1 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RY"([[theta]]) - # CHECK-NOT: name = "Rot" - # CHECK: [[omega:%.+]] = tensor.extract %arg2 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]]) - # CHECK-NOT: name = "Rot" - qml.Rot(phi, theta, omega, wires=0) - return measure(wires=0) - - print(decompose_rot.mlir) + dev = get_custom_device_without(1, {"Rot", "C(Rot)"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_rot + def decompose_rot(phi: float, theta: float, omega: float): + # CHECK-NOT: name = "Rot" + # CHECK: [[phi:%.+]] = tensor.extract %arg0 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]]) + # CHECK-NOT: name = "Rot" + # CHECK: [[theta:%.+]] = tensor.extract %arg1 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RY"([[theta]]) + # CHECK-NOT: name = "Rot" + # CHECK: [[omega:%.+]] = tensor.extract %arg2 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]]) + # CHECK-NOT: name = "Rot" + qml.Rot(phi, theta, omega, wires=0) + return measure(wires=0) + + print(decompose_rot.mlir) test_decompose_rot() @@ -211,21 +216,21 @@ def decompose_rot(phi: float, theta: float, omega: float): def test_decompose_s(): """Test decomposition of S gate.""" - with get_custom_device_without(1, {"S", "C(S)"}) as dev: + dev = get_custom_device_without(1, {"S", "C(S)"}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_s - def decompose_s(): - # CHECK-NOT: name="S" - # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 - # CHECK-NOT: name = "S" - # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) - # CHECK-NOT: name = "S" - qml.S(wires=0) - return measure(wires=0) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_s + def decompose_s(): + # CHECK-NOT: name="S" + # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 + # CHECK-NOT: name = "S" + # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) + # CHECK-NOT: name = "S" + qml.S(wires=0) + return measure(wires=0) - print(decompose_s.mlir) + print(decompose_s.mlir) test_decompose_s() @@ -233,21 +238,21 @@ def decompose_s(): def test_decompose_qubitunitary(): """Test decomposition of QubitUnitary""" - with get_custom_device_without(1, {"QubitUnitary"}) as dev: + dev = get_custom_device_without(1, {"QubitUnitary"}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_qubit_unitary - def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): - # CHECK-NOT: name = "QubitUnitary" - # CHECK: quantum.custom "RZ" - # CHECK: quantum.custom "RY" - # CHECK: quantum.custom "RZ" - # CHECK-NOT: name = "QubitUnitary" - qml.QubitUnitary(U, wires=0) - return measure(wires=0) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_qubit_unitary + def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): + # CHECK-NOT: name = "QubitUnitary" + # CHECK: quantum.custom "RZ" + # CHECK: quantum.custom "RY" + # CHECK: quantum.custom "RZ" + # CHECK-NOT: name = "QubitUnitary" + qml.QubitUnitary(U, wires=0) + return measure(wires=0) - print(decompose_qubit_unitary.mlir) + print(decompose_qubit_unitary.mlir) test_decompose_qubitunitary() @@ -255,47 +260,47 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): def test_decompose_singleexcitationplus(): """Test decomposition of single excitation plus.""" - with get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_singleexcitationplus - def decompose_singleexcitationplus(theta: float): - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00> - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX" - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX" - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - qml.SingleExcitationPlus(theta, wires=[0, 1]) - return measure(wires=0) - - print(decompose_singleexcitationplus.mlir) + dev = get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_singleexcitationplus + def decompose_singleexcitationplus(theta: float): + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00> + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX" + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX" + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + qml.SingleExcitationPlus(theta, wires=[0, 1]) + return measure(wires=0) + + print(decompose_singleexcitationplus.mlir) test_decompose_singleexcitationplus() diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index e7125e7eb7..15cc876f93 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -15,13 +15,18 @@ # RUN: %PYTHON %s | FileCheck %s """ Test the lowering cases involving quantum control """ -import os -import tempfile +from copy import deepcopy import jax.numpy as jnp import pennylane as qml from catalyst import qjit +from catalyst.utils.toml import ( + OperationProperties, + ProgramFeatures, + get_device_capabilities, + pennylane_operation_set, +) def get_custom_qjit_device(num_wires, discards, additions): @@ -37,70 +42,61 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - discards | additions - observables = lightning_device.observables.copy() - config = None backend_name = "default" backend_lib = "default" backend_kwargs = {} def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=shots is not None) + lightning_capabilities = get_device_capabilities( + self.lightning_device, program_features + ) + custom_capabilities = deepcopy(lightning_capabilities) + for gate in discards: + custom_capabilities.native_ops.pop(gate) + custom_capabilities.native_ops.update(additions) + self.qjit_capabilities = custom_capabilities + + @property + def operations(self): + """Get PennyLane operations.""" + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) + + @property + def observables(self): + """Get PennyLane observables.""" + return pennylane_operation_set(self.qjit_capabilities.native_obs) def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() - - # TODO: update once schema 2 is merged - updated_toml_contents = [] - for line in toml_contents: - if any(f'"{gate}",' in line for gate in discards): - continue - - updated_toml_contents.append(line) - if "native = [" in line: - for gate in additions: - if not gate.startswith("C("): - updated_toml_contents.append(f' "{gate}",\n') - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting - - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None - return CustomDevice(wires=num_wires) def test_named_controlled(): """Test that named-controlled operations are passed as-is.""" - with get_custom_qjit_device(2, set(), set()) as dev: + dev = get_custom_qjit_device(2, set(), set()) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_named_controlled - def named_controlled(): - # CHECK: quantum.custom "CNOT" - qml.CNOT(wires=[0, 1]) - # CHECK: quantum.custom "CY" - qml.CY(wires=[0, 1]) - # CHECK: quantum.custom "CZ" - qml.CZ(wires=[0, 1]) - return qml.state() + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_named_controlled + def named_controlled(): + # CHECK: quantum.custom "CNOT" + qml.CNOT(wires=[0, 1]) + # CHECK: quantum.custom "CY" + qml.CY(wires=[0, 1]) + # CHECK: quantum.custom "CZ" + qml.CZ(wires=[0, 1]) + return qml.state() - print(named_controlled.mlir) + print(named_controlled.mlir) test_named_controlled() @@ -108,19 +104,19 @@ def named_controlled(): def test_native_controlled_custom(): """Test native control of a custom operation.""" - with get_custom_qjit_device(3, {"CRot"}, {"Rot", "C(Rot)"}) as dev: + dev = get_custom_qjit_device(3, set(), {"Rot": OperationProperties(True, True, False)}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled - def native_controlled(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true, %true) - qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) - return qml.state() + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled + def native_controlled(): + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true, %true) + qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) + return qml.state() - print(native_controlled.mlir) + print(native_controlled.mlir) test_native_controlled_custom() @@ -128,31 +124,31 @@ def native_controlled(): def test_native_controlled_unitary(): """Test native control of the unitary operation.""" - with get_custom_qjit_device(4, set(), set()) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled_unitary - def native_controlled_unitary(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true, %true, %true) - qml.ctrl( - qml.QubitUnitary( - jnp.array( - [ - [0.70710678 + 0.0j, 0.70710678 + 0.0j], - [0.70710678 + 0.0j, -0.70710678 + 0.0j], - ], - dtype=jnp.complex128, - ), - wires=[0], + dev = get_custom_qjit_device(4, set(), set()) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled_unitary + def native_controlled_unitary(): + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true, %true, %true) + qml.ctrl( + qml.QubitUnitary( + jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ], + dtype=jnp.complex128, ), - control=[1, 2, 3], - ) - return qml.state() + wires=[0], + ), + control=[1, 2, 3], + ) + return qml.state() - print(native_controlled_unitary.mlir) + print(native_controlled_unitary.mlir) test_native_controlled_unitary() @@ -160,19 +156,19 @@ def native_controlled_unitary(): def test_native_controlled_multirz(): """Test native control of the multirz operation.""" - with get_custom_qjit_device(3, set(), {"C(MultiRZ)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled_multirz - def native_controlled_multirz(): - # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true) - qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1]) - return qml.state() - - print(native_controlled_multirz.mlir) + dev = get_custom_qjit_device(3, set(), {"MultiRZ": OperationProperties(True, True, True)}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled_multirz + def native_controlled_multirz(): + # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true) + qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1]) + return qml.state() + + print(native_controlled_multirz.mlir) test_native_controlled_multirz() diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 1bc79225bc..50488da984 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -21,13 +21,9 @@ import pennylane as qml import pytest -from catalyst.device import QJITDevice +from catalyst.device import QJITDevice, validate_device_capabilities +from catalyst.device.qjit_device import check_no_overlap from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import ( - check_no_overlap, - get_device_capabilities, - validate_config_with_device, -) from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, @@ -36,6 +32,7 @@ get_decomposable_gates, get_matrix_decomposable_gates, get_native_ops, + load_device_capabilities, pennylane_operation_set, read_toml_file, ) @@ -76,33 +73,30 @@ def get_test_device_capabilities( ) -> DeviceCapabilities: """Parse test config into the DeviceCapabilities structure""" config = get_test_config(config_text) - device_capabilities = get_device_capabilities(config, program_features, "dummy") + device_capabilities = load_device_capabilities(config, program_features, "dummy") return device_capabilities @pytest.mark.parametrize("schema", ALL_SCHEMAS) -def test_validate_config_with_device(schema): +def test_config_qjit_incompatible_device(schema): """Test error is raised if checking for qjit compatibility and field is false in toml file.""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - f""" - schema = {schema} - [compilation] - qjit_compatible = false - """ - ) - ) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + f""" + schema = {schema} + [compilation] + qjit_compatible = false + """ + ), + ) - config = read_toml_file(toml_file) - name = DeviceToBeTested.name - with pytest.raises( - CompileError, - match=f"Attempting to compile program for incompatible device '{name}'", - ): - validate_config_with_device(DeviceToBeTested(), config) + name = DeviceToBeTested.name + with pytest.raises( + CompileError, + match=f"Attempting to compile program for incompatible device '{name}'", + ): + validate_device_capabilities(DeviceToBeTested(), device_capabilities) def test_get_observables_schema1(): @@ -384,7 +378,8 @@ def test_config_invalid_condition_duplicate(shots): def test_config_qjit_device_operations(): """Check the gate condition handling logic""" - config = get_test_config( + capabilities = get_test_device_capabilities( + ProgramFeatures(False), dedent( r""" schema = 2 @@ -395,7 +390,7 @@ def test_config_qjit_device_operations(): """ ), ) - qjit_device = QJITDevice(config, shots=1000, wires=2) + qjit_device = QJITDevice(capabilities, shots=1000, wires=2) assert "PauliX" in qjit_device.operations assert "PauliY" in qjit_device.observables diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index ddb385f83f..74e47bde0a 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -20,8 +20,9 @@ from catalyst import measure, qjit from catalyst.compiler import get_lib_path +from catalyst.device import extract_backend_info from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import device_get_toml_config, extract_backend_info +from catalyst.utils.toml import get_device_capabilities # These have to match the ones in the configuration file. OPERATIONS = [ @@ -165,8 +166,8 @@ def get_c_interface(): return "DummyDevice", lib_path device = DummyDevice(wires=1) - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) + capabilities = get_device_capabilities(device) + backend_info = extract_backend_info(device, capabilities) assert backend_info.kwargs["option1"] == 42 assert "option2" not in backend_info.kwargs diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 371d9e4455..d9d70eb67e 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -22,7 +22,7 @@ from catalyst.compiler import CompileOptions, Compiler from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path class TestDebugPrint: diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index 469c03d4a5..688e0bfc33 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile +from copy import deepcopy import pennylane as qml import pytest from jax import numpy as jnp from catalyst import CompileError, ctrl, measure, qjit +from catalyst.utils.toml import ( + ProgramFeatures, + get_device_capabilities, + pennylane_operation_set, +) class CustomDevice(qml.QubitDevice): @@ -32,75 +36,58 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - { - "MultiControlledX", - "Rot", - "S", - "C(Rot)", - "C(S)", - } - observables = lightning_device.observables.copy() - - config = None + backend_name = "default" backend_lib = "default" backend_kwargs = {} def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=self.shots is not None) + lightning_capabilities = get_device_capabilities(self.lightning_device, program_features) + custom_capabilities = deepcopy(lightning_capabilities) + custom_capabilities.native_ops.pop("Rot") + custom_capabilities.native_ops.pop("S") + custom_capabilities.to_decomp_ops.pop("MultiControlledX") + self.qjit_capabilities = custom_capabilities def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() - - # TODO: update once schema 2 is merged - updated_toml_contents = [] - for line in toml_contents: - if '"MultiControlledX",' in line: - continue - if '"Rot",' in line: - continue - if '"S",' in line: - continue - - updated_toml_contents.append(line) - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting + @property + def operations(self): + """Get PennyLane operations.""" + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None + @property + def observables(self): + """Get PennyLane observables.""" + return pennylane_operation_set(self.qjit_capabilities.native_obs) @pytest.mark.parametrize("param,expected", [(0.0, True), (jnp.pi, False)]) def test_decomposition(param, expected): - with CustomDevice(wires=2) as dev: - - @qjit - @qml.qnode(dev) - def mid_circuit(x: float): - qml.Hadamard(wires=0) - qml.Rot(0, 0, x, wires=0) - qml.Hadamard(wires=0) - m = measure(wires=0) - b = m ^ 0x1 - qml.Hadamard(wires=1) - qml.Rot(0, 0, b * jnp.pi, wires=1) - qml.Hadamard(wires=1) - return measure(wires=1) - - assert mid_circuit(param) == expected + dev = CustomDevice(wires=2) + + @qjit + @qml.qnode(dev) + def mid_circuit(x: float): + qml.Hadamard(wires=0) + qml.Rot(0, 0, x, wires=0) + qml.Hadamard(wires=0) + m = measure(wires=0) + b = m ^ 0x1 + qml.Hadamard(wires=1) + qml.Rot(0, 0, b * jnp.pi, wires=1) + qml.Hadamard(wires=1) + return measure(wires=1) + + assert mid_circuit(param) == expected class TestControlledDecomposition: diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index 11eee25a7a..12b8a43010 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -25,9 +25,9 @@ from catalyst import qjit from catalyst.compiler import get_lib_path -from catalyst.device import QJITDeviceNewAPI +from catalyst.device import QJITDeviceNewAPI, extract_backend_info from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.utils.runtime import device_get_toml_config, extract_backend_info +from catalyst.utils.toml import ProgramFeatures, get_device_capabilities class DummyDevice(Device): @@ -87,9 +87,9 @@ def test_qjit_device(): device = DummyDevice(wires=10, shots=2032) # Create qjit device - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) - device_qjit = QJITDeviceNewAPI(device, backend_info) + capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None)) + backend_info = extract_backend_info(device, capabilities) + device_qjit = QJITDeviceNewAPI(device, capabilities, backend_info) # Check attributes of the new device assert device_qjit.shots == qml.measurements.Shots(2032) @@ -123,13 +123,13 @@ def test_qjit_device_no_wires(): device = DummyDeviceNoWires(shots=2032) # Create qjit device - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) + capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None)) + backend_info = extract_backend_info(device, capabilities) with pytest.raises( AttributeError, match="Catalyst does not support devices without set wires." ): - QJITDeviceNewAPI(device, backend_info) + QJITDeviceNewAPI(device, capabilities, backend_info) def test_simple_circuit(): diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index 9fc5bace89..df7bd4b822 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -412,7 +412,6 @@ def test_decomposition_of_cond_circuit(self): @qml.qjit @qml.qnode(dev) def circuit(phi: float): - OtherHadamard(wires=0) # define a conditional ansatz @@ -461,7 +460,6 @@ def test_decomposition_of_forloop_circuit(self, reps, angle): @qml.qjit @qml.qnode(dev) def circuit(n: int, x: float): - OtherHadamard(wires=0) def loop_rx(i, phi): @@ -495,7 +493,6 @@ def test_decomposition_of_whileloop_circuit(self, phi): @qml.qjit @qml.qnode(dev) def circuit(x: float): - @while_loop(lambda x: x < 2.0) def loop_rx(x): # perform some work and update (some of) the arguments