From 0ede63d1a8ec2145ee08df5efd9d214d051d0b35 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 3 May 2024 12:19:46 +0000 Subject: [PATCH 01/21] Make a test tomle-schema-2-compatible --- frontend/test/pytest/test_decomposition.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index 469c03d4a5..df7d36efa2 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -59,14 +59,13 @@ def __enter__(self, *args, **kwargs): with open(lightning_toml, mode="r", encoding="UTF-8") as f: toml_contents = f.readlines() - # TODO: update once schema 2 is merged updated_toml_contents = [] for line in toml_contents: - if '"MultiControlledX",' in line: + if '"MultiControlledX",' in line or line.startswith("MultiControlledX "): continue - if '"Rot",' in line: + if '"Rot",' in line or line.startswith('Rot '): continue - if '"S",' in line: + if '"S",' in line or line.startswith('S '): continue updated_toml_contents.append(line) From 1babc99e798278672b568028a644618420f0e9f1 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 3 May 2024 12:23:47 +0000 Subject: [PATCH 02/21] Address formatting issues --- frontend/test/pytest/test_decomposition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index df7d36efa2..a8a39218c5 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -63,9 +63,9 @@ def __enter__(self, *args, **kwargs): for line in toml_contents: if '"MultiControlledX",' in line or line.startswith("MultiControlledX "): continue - if '"Rot",' in line or line.startswith('Rot '): + if '"Rot",' in line or line.startswith("Rot "): continue - if '"S",' in line or line.startswith('S '): + if '"S",' in line or line.startswith("S "): continue updated_toml_contents.append(line) From 30de70cfbf60a7cdb7da2dfdc208b64fc00144bc Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 6 May 2024 12:13:44 +0000 Subject: [PATCH 03/21] Make test toml-schema-independant --- frontend/catalyst/compiler.py | 2 +- .../cuda/catalyst_to_cuda_interpreter.py | 2 +- frontend/catalyst/qfunc.py | 33 +- frontend/catalyst/qjit_device.py | 28 +- frontend/catalyst/utils/runtime.py | 63 +-- frontend/catalyst/utils/toml.py | 93 ++++- frontend/test/lit/test_decomposition.py | 364 +++++++++--------- frontend/test/lit/test_quantum_control.py | 181 +++++---- frontend/test/pytest/test_config_functions.py | 53 ++- frontend/test/pytest/test_custom_devices.py | 7 +- frontend/test/pytest/test_decomposition.py | 93 ++--- frontend/test/pytest/test_device_api.py | 17 +- 12 files changed, 480 insertions(+), 456 deletions(-) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 0ca97dbc8b..0639f37610 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -33,7 +33,7 @@ from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import Directory -from catalyst.utils.runtime import get_lib_path +from catalyst.utils.toml import get_lib_path package_root = os.path.dirname(__file__) diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py index 7201f1dfea..b84d8bf7dd 100644 --- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py @@ -820,7 +820,7 @@ def get_jaxpr(self, *args): an MLIR module """ - def cudaq_backend_info(device, _config) -> BackendInfo: + def cudaq_backend_info(device, _capabilities) -> BackendInfo: """The extract_backend_info should not be run by the cuda compiler as it is catalyst-specific. We need to make this API a bit nicer for third-party compilers. """ diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 813d06224e..9878a5c427 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -32,11 +32,14 @@ from catalyst.qjit_device import QJITDevice, QJITDeviceNewAPI from catalyst.utils.runtime import ( BackendInfo, - device_get_toml_config, extract_backend_info, - validate_config_with_device, + validate_device_capabilities, +) +from catalyst.utils.toml import ( + DeviceCapabilities, + ProgramFeatures, + get_device_capabilities, ) -from catalyst.utils.toml import TOMLDocument class QFunc: @@ -54,26 +57,34 @@ def __new__(cls): raise NotImplementedError() # pragma: no-cover @staticmethod - def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo: + def extract_backend_info( + device: qml.QubitDevice, capabilities: DeviceCapabilities + ) -> BackendInfo: """Wrapper around extract_backend_info in the runtime module.""" - return extract_backend_info(device, config) + return extract_backend_info(device, capabilities) # pylint: disable=no-member def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) - config = device_get_toml_config(self.device) - validate_config_with_device(self.device, config) - backend_info = QFunc.extract_backend_info(self.device, config) + device = self.device + program_features = ProgramFeatures(device.shots is not None) + device_capabilities = get_device_capabilities(device, program_features) + backend_info = QFunc.extract_backend_info(device, device_capabilities) + + # Validate decive operations against the declared capabilities + validate_device_capabilities(device, device_capabilities) if isinstance(self.device, qml.devices.Device): - device = QJITDeviceNewAPI(self.device, backend_info) + self.qjit_device = QJITDeviceNewAPI(device, device_capabilities, backend_info) else: - device = QJITDevice(config, self.device.shots, self.device.wires, backend_info) + self.qjit_device = QJITDevice( + device_capabilities, device.shots, device.wires, backend_info + ) def _eval_quantum(*args): closed_jaxpr, out_type, out_tree = trace_quantum_function( - self.func, device, args, kwargs, qnode=self + self.func, self.qjit_device, args, kwargs, qnode=self ) args_expanded = get_implicit_and_explicit_flat_args(None, *args) res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) diff --git a/frontend/catalyst/qjit_device.py b/frontend/catalyst/qjit_device.py index e52e8d487b..8c7c3da4f5 100644 --- a/frontend/catalyst/qjit_device.py +++ b/frontend/catalyst/qjit_device.py @@ -24,7 +24,7 @@ from catalyst.preprocess import catalyst_acceptance, decompose from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import BackendInfo, device_get_toml_config +from catalyst.utils.runtime import BackendInfo from catalyst.utils.toml import ( DeviceCapabilities, OperationProperties, @@ -154,7 +154,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S def __init__( self, - target_config: TOMLDocument, + original_device_capabilities: DeviceCapabilities, shots=None, wires=None, backend: Optional[BackendInfo] = None, @@ -164,23 +164,18 @@ def __init__( self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} - device_name = backend.device_name if backend else "default" - program_features = ProgramFeatures(shots is not None) - target_device_capabilities = get_device_capabilities( - target_config, program_features, device_name - ) - self.capabilities = get_qjit_device_capabilities(target_device_capabilities) + self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities) @property def operations(self) -> Set[str]: """Get the device operations using PennyLane's syntax""" - return pennylane_operation_set(self.capabilities.native_ops) + return pennylane_operation_set(self.qjit_capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return pennylane_operation_set(self.capabilities.native_obs) + return pennylane_operation_set(self.qjit_capabilities.native_obs) def apply(self, operations, **kwargs): """ @@ -259,6 +254,7 @@ class QJITDeviceNewAPI(qml.devices.Device): def __init__( self, original_device, + original_device_capabilities: DeviceCapabilities, backend: Optional[BackendInfo] = None, ): self.original_device = original_device @@ -274,24 +270,18 @@ def __init__( self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} - device_name = backend.device_name if backend else "default" - target_config = device_get_toml_config(original_device) - program_features = ProgramFeatures(original_device.shots is not None) - target_device_capabilities = get_device_capabilities( - target_config, program_features, device_name - ) - self.capabilities = get_qjit_device_capabilities(target_device_capabilities) + self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities) @property def operations(self) -> Set[str]: """Get the device operations""" - return pennylane_operation_set(self.capabilities.native_ops) + return pennylane_operation_set(self.qjit_capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return pennylane_operation_set(self.capabilities.native_obs) + return pennylane_operation_set(self.qjit_capabilities.native_obs) @property def measurement_processes(self) -> Set[str]: diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py index 2d1dc0e8e6..ac5e71f66c 100644 --- a/frontend/catalyst/utils/runtime.py +++ b/frontend/catalyst/utils/runtime.py @@ -26,28 +26,17 @@ import pennylane as qml -from catalyst._configuration import INSTALLED from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import ( + DeviceCapabilities, ProgramFeatures, TOMLDocument, get_device_capabilities, + get_lib_path, pennylane_operation_set, read_toml_file, ) -package_root = os.path.dirname(__file__) - - -# Default paths to dep libraries -DEFAULT_LIB_PATHS = { - "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), - "runtime": os.path.join(package_root, "../../../runtime/build/lib"), - "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), - "oqc_runtime": os.path.join(package_root, "../../catalyst/oqc/src/build"), -} - - # TODO: This should be removed after implementing `get_c_interface` # for the following backend devices: SUPPORTED_RT_DEVICES = { @@ -58,13 +47,6 @@ } -def get_lib_path(project, env_var): - """Get the library path.""" - if INSTALLED: - return os.path.join(package_root, "..", "lib") # pragma: no cover - return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) - - def check_no_overlap(*args, device_name): """Check items in *args are mutually exclusive. @@ -109,7 +91,9 @@ def is_not_adj(op): return set(operations_no_adj) -def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -> None: +def validate_device_capabilities( + device: qml.QubitDevice, device_capabilities: DeviceCapabilities +) -> None: """Validate configuration document against the device attributes. Raise CompileError in case of mismatch: * If device is not qjit-compatible. @@ -125,15 +109,13 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - Raises: CompileError """ - if not config["compilation"]["qjit_compatible"]: + if not device_capabilities.qjit_compatible_flag: raise CompileError( f"Attempting to compile program for incompatible device '{device.name}': " f"Config is not marked as qjit-compatible" ) device_name = device.short_name if isinstance(device, qml.Device) else device.name - program_features = ProgramFeatures(device.shots is not None) - device_capabilities = get_device_capabilities(config, program_features, device_name) native = pennylane_operation_set(device_capabilities.native_ops) decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) @@ -163,34 +145,6 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - ) -def device_get_toml_config(device) -> TOMLDocument: - """Get the contents of the device config file.""" - if hasattr(device, "config"): - # The expected case: device specifies its own config. - toml_file = device.config - else: - # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file - # field. - device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) - - name = device.short_name if isinstance(device, qml.Device) else device.name - # The toml files name convention we follow is to replace - # the dots with underscores in the device short name. - toml_file_name = name.replace(".", "_") + ".toml" - # And they are currently saved in the following directory. - toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name - - try: - config = read_toml_file(toml_file) - except FileNotFoundError as e: - raise CompileError( - "Attempting to compile program for incompatible device: " - f"Config file ({toml_file}) does not exist" - ) from e - - return config - - @dataclass class BackendInfo: """Backend information""" @@ -201,7 +155,7 @@ class BackendInfo: kwargs: Dict[str, Any] -def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo: +def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo: """Extract the backend info from a quantum device. The device is expected to carry a reference to a valid TOML config file.""" @@ -255,8 +209,7 @@ def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> Backe device._s3_folder # pylint: disable=protected-access ) - options = config.get("options", {}) - for k, v in options.items(): + for k, v in capabilities.options.items(): if hasattr(device, v): device_kwargs[k] = getattr(device, v) diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index e629e822dd..d15b3bbed7 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -16,11 +16,16 @@ """ import importlib.util +import os +import pathlib from dataclasses import dataclass from functools import reduce from itertools import repeat -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Optional, Set +import pennylane as qml + +from catalyst._configuration import INSTALLED from catalyst.utils.exceptions import CompileError # TODO: @@ -48,6 +53,25 @@ from tomlkit.exceptions import TOMLKitError as TOMLException +package_root = os.path.dirname(__file__) + + +# Default paths to dep libraries +DEFAULT_LIB_PATHS = { + "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), + "runtime": os.path.join(package_root, "../../../runtime/build/lib"), + "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), + "oqc_runtime": os.path.join(package_root, "../../catalyst/oqc/src/build"), +} + + +def get_lib_path(project, env_var): + """Get the library path.""" + if INSTALLED: + return os.path.join(package_root, "..", "lib") # pragma: no cover + return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) + + def read_toml_file(toml_file: str) -> TOMLDocument: """Helper function opening toml file properly and reading it into a document""" with open(toml_file, "rb") as f: @@ -82,9 +106,11 @@ class DeviceCapabilities: # pylint: disable=too-many-instance-attributes to_matrix_ops: Dict[str, OperationProperties] native_obs: Dict[str, OperationProperties] measurement_processes: Set[str] + qjit_compatible_flag: bool mid_circuit_measurement_flag: bool runtime_code_generation_flag: bool dynamic_qubit_management_flag: bool + options: Dict[str, bool] def intersect_operations( @@ -112,11 +138,16 @@ class ProgramFeatures: shots_present: bool -def check_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: - """Checks the flag in the toml document 'compilation' section.""" +def get_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: + """Get the flag in the toml document 'compilation' section.""" return bool(config.get("compilation", {}).get(flag_name, False)) +def get_options(config: TOMLDocument) -> Dict[str, str]: + """Get custom options sections""" + return {str(k): str(v) for k, v in config.get("options", {}).items()} + + def check_quantum_control_flag(config: TOMLDocument) -> bool: """Check the control flag. Only exists in toml config schema 1""" schema = int(config["schema"]) @@ -301,7 +332,7 @@ def patch_schema1_collections( for op, props in native_gate_props.items(): props.controllable = op not in gates_to_be_decomposed_if_controlled - supports_adjoint = check_compilation_flag(config, "quantum_adjoint") + supports_adjoint = get_compilation_flag(config, "quantum_adjoint") if supports_adjoint: # Makr all gates as invertibles for props in native_gate_props.values(): @@ -327,10 +358,54 @@ def patch_schema1_collections( decomp_props.pop("ControlledPhaseShift") +def get_device_toml_config(device) -> TOMLDocument: + """Get the contents of the device config file.""" + if hasattr(device, "config"): + # The expected case: device specifies its own config. + toml_file = device.config + else: + # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file + # field. + device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) + + name = device.short_name if isinstance(device, qml.Device) else device.name + # The toml files name convention we follow is to replace + # the dots with underscores in the device short name. + toml_file_name = name.replace(".", "_") + ".toml" + # And they are currently saved in the following directory. + toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name + + try: + config = read_toml_file(toml_file) + except FileNotFoundError as e: + raise CompileError( + "Attempting to compile program for incompatible device: " + f"Config file ({toml_file}) does not exist" + ) from e + + return config + + def get_device_capabilities( + device, program_features: Optional[ProgramFeatures] = None +) -> DeviceCapabilities: + """Get or load DeviceCapabilities structure from device""" + + if hasattr(device, "qjit_capabilities"): + return device.qjit_capabilities + else: + program_features = ( + program_features if program_features else ProgramFeatures(device.shots is not None) + ) + device_name = device.short_name if isinstance(device, qml.Device) else device.name + device_config = get_device_toml_config(device) + return load_device_capabilities(device_config, program_features, device_name) + + +def load_device_capabilities( config: TOMLDocument, program_features: ProgramFeatures, device_name: str ) -> DeviceCapabilities: - """Load TOML document into the DeviceCapabilities structure""" + """Load device capabilities from device config""" schema = int(config["schema"]) @@ -370,7 +445,9 @@ def get_device_capabilities( to_matrix_ops=matrix_decomp_props, native_obs=observable_props, measurement_processes=measurements_props, - mid_circuit_measurement_flag=check_compilation_flag(config, "mid_circuit_measurement"), - runtime_code_generation_flag=check_compilation_flag(config, "runtime_code_generation"), - dynamic_qubit_management_flag=check_compilation_flag(config, "dynamic_qubit_management"), + qjit_compatible_flag=get_compilation_flag(config, "qjit_compatible"), + mid_circuit_measurement_flag=get_compilation_flag(config, "mid_circuit_measurement"), + runtime_code_generation_flag=get_compilation_flag(config, "runtime_code_generation"), + dynamic_qubit_management_flag=get_compilation_flag(config, "dynamic_qubit_management"), + options=get_options(config), ) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index c0375a4356..36a34799ab 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -17,11 +17,18 @@ import os import tempfile +from copy import deepcopy import jax import pennylane as qml from catalyst import cond, for_loop, measure, qjit, while_loop +from catalyst.utils.runtime import pennylane_operation_set +from catalyst.utils.toml import ( + DeviceCapabilities, + ProgramFeatures, + get_device_capabilities, +) def get_custom_device_without(num_wires, discards): @@ -37,8 +44,6 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - discards - observables = lightning_device.observables.copy() config = None backend_name = "default" @@ -47,54 +52,55 @@ class CustomDevice(qml.QubitDevice): def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=self.shots is not None) + lightning_capabilities = get_device_capabilities( + self.lightning_device, program_features + ) + custom_capabilities = deepcopy(lightning_capabilities) + for gate in discards: + if gate in dummy_capabilities.native_ops: + custom_capabilities.native_ops.pop(gate) + if gate in dummy_capabilities.to_decomp_ops: + custom_capabilities.to_decomp_ops.pop(gate) + if gate in dummy_capabilities.to_matrix_ops: + custom_capabilities.to_matrix_ops.pop(gate) + self.qjit_capabilities = custom_capabilities def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() + @property + def operations(self): + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) - # TODO: update once schema 2 is merged - updated_toml_contents = [] - for line in toml_contents: - if any(f'"{gate}",' in line for gate in discards): - continue - updated_toml_contents.append(line) - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting - - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None + @property + def observables(self): + return pennylane_operation_set(self.qjit_capabilities.native_obs) return CustomDevice(wires=num_wires) def test_decompose_multicontrolledx(): """Test decomposition of MultiControlledX.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_multicontrolled_x1 - def decompose_multicontrolled_x1(theta: float): - qml.RX(theta, wires=[0]) - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - return qml.state() + dev = get_custom_device_without(5, {"MultiControlledX"}) - print(decompose_multicontrolled_x1.mlir) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_multicontrolled_x1 + def decompose_multicontrolled_x1(theta: float): + qml.RX(theta, wires=[0]) + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + return qml.state() + + print(decompose_multicontrolled_x1.mlir) test_decompose_multicontrolledx() @@ -102,25 +108,25 @@ def decompose_multicontrolled_x1(theta: float): def test_decompose_multicontrolledx_in_conditional(): """Test decomposition of MultiControlledX in conditional.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x2 - def decompose_multicontrolled_x2(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @cond(n > 1) - def cond_fn(): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x2 + def decompose_multicontrolled_x2(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @cond(n > 1) + def cond_fn(): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - cond_fn() - return qml.state() + cond_fn() + return qml.state() - print(decompose_multicontrolled_x2.mlir) + print(decompose_multicontrolled_x2.mlir) test_decompose_multicontrolledx_in_conditional() @@ -128,26 +134,26 @@ def cond_fn(): def test_decompose_multicontrolledx_in_while_loop(): """Test decomposition of MultiControlledX in while loop.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x3 - def decompose_multicontrolled_x3(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @while_loop(lambda v: v[0] < 10) - def loop(v): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - return v[0] + 1, v[1] + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x3 + def decompose_multicontrolled_x3(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @while_loop(lambda v: v[0] < 10) + def loop(v): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + return v[0] + 1, v[1] - loop((0, n)) - return qml.state() + loop((0, n)) + return qml.state() - print(decompose_multicontrolled_x3.mlir) + print(decompose_multicontrolled_x3.mlir) test_decompose_multicontrolledx_in_while_loop() @@ -155,25 +161,25 @@ def loop(v): def test_decompose_multicontrolledx_in_for_loop(): """Test decomposition of MultiControlledX in for loop.""" - with get_custom_device_without(5, {"MultiControlledX"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: @jit_decompose_multicontrolled_x4 - def decompose_multicontrolled_x4(theta: float, n: int): - qml.RX(theta, wires=[0]) - - # CHECK-NOT: name = "MultiControlledX" - # CHECK: quantum.unitary - # CHECK-NOT: name = "MultiControlledX" - @for_loop(0, n, 1) - def loop(_): - qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) + dev = get_custom_device_without(5, {"MultiControlledX"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: @jit_decompose_multicontrolled_x4 + def decompose_multicontrolled_x4(theta: float, n: int): + qml.RX(theta, wires=[0]) + + # CHECK-NOT: name = "MultiControlledX" + # CHECK: quantum.unitary + # CHECK-NOT: name = "MultiControlledX" + @for_loop(0, n, 1) + def loop(_): + qml.MultiControlledX(wires=[0, 1, 2, 3], work_wires=[4]) - loop() - return qml.state() + loop() + return qml.state() - print(decompose_multicontrolled_x4.mlir) + print(decompose_multicontrolled_x4.mlir) test_decompose_multicontrolledx_in_for_loop() @@ -181,29 +187,29 @@ def loop(_): def test_decompose_rot(): """Test decomposition of Rot gate.""" - with get_custom_device_without(1, {"Rot", "C(Rot)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_rot - def decompose_rot(phi: float, theta: float, omega: float): - # CHECK-NOT: name = "Rot" - # CHECK: [[phi:%.+]] = tensor.extract %arg0 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]]) - # CHECK-NOT: name = "Rot" - # CHECK: [[theta:%.+]] = tensor.extract %arg1 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RY"([[theta]]) - # CHECK-NOT: name = "Rot" - # CHECK: [[omega:%.+]] = tensor.extract %arg2 - # CHECK-NOT: name = "Rot" - # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]]) - # CHECK-NOT: name = "Rot" - qml.Rot(phi, theta, omega, wires=0) - return measure(wires=0) - - print(decompose_rot.mlir) + dev = get_custom_device_without(1, {"Rot", "C(Rot)"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_rot + def decompose_rot(phi: float, theta: float, omega: float): + # CHECK-NOT: name = "Rot" + # CHECK: [[phi:%.+]] = tensor.extract %arg0 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RZ"([[phi]]) + # CHECK-NOT: name = "Rot" + # CHECK: [[theta:%.+]] = tensor.extract %arg1 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RY"([[theta]]) + # CHECK-NOT: name = "Rot" + # CHECK: [[omega:%.+]] = tensor.extract %arg2 + # CHECK-NOT: name = "Rot" + # CHECK: {{%.+}} = quantum.custom "RZ"([[omega]]) + # CHECK-NOT: name = "Rot" + qml.Rot(phi, theta, omega, wires=0) + return measure(wires=0) + + print(decompose_rot.mlir) test_decompose_rot() @@ -211,21 +217,21 @@ def decompose_rot(phi: float, theta: float, omega: float): def test_decompose_s(): """Test decomposition of S gate.""" - with get_custom_device_without(1, {"S", "C(S)"}) as dev: + dev = get_custom_device_without(1, {"S", "C(S)"}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_s - def decompose_s(): - # CHECK-NOT: name="S" - # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 - # CHECK-NOT: name = "S" - # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) - # CHECK-NOT: name = "S" - qml.S(wires=0) - return measure(wires=0) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_s + def decompose_s(): + # CHECK-NOT: name="S" + # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 + # CHECK-NOT: name = "S" + # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) + # CHECK-NOT: name = "S" + qml.S(wires=0) + return measure(wires=0) - print(decompose_s.mlir) + print(decompose_s.mlir) test_decompose_s() @@ -233,21 +239,21 @@ def decompose_s(): def test_decompose_qubitunitary(): """Test decomposition of QubitUnitary""" - with get_custom_device_without(1, {"QubitUnitary"}) as dev: + dev = get_custom_device_without(1, {"QubitUnitary"}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_qubit_unitary - def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): - # CHECK-NOT: name = "QubitUnitary" - # CHECK: quantum.custom "RZ" - # CHECK: quantum.custom "RY" - # CHECK: quantum.custom "RZ" - # CHECK-NOT: name = "QubitUnitary" - qml.QubitUnitary(U, wires=0) - return measure(wires=0) + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_qubit_unitary + def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): + # CHECK-NOT: name = "QubitUnitary" + # CHECK: quantum.custom "RZ" + # CHECK: quantum.custom "RY" + # CHECK: quantum.custom "RZ" + # CHECK-NOT: name = "QubitUnitary" + qml.QubitUnitary(U, wires=0) + return measure(wires=0) - print(decompose_qubit_unitary.mlir) + print(decompose_qubit_unitary.mlir) test_decompose_qubitunitary() @@ -255,47 +261,47 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)): def test_decompose_singleexcitationplus(): """Test decomposition of single excitation plus.""" - with get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_decompose_singleexcitationplus - def decompose_singleexcitationplus(theta: float): - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00> - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX" - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX" - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]] - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0 - # CHECK-NOT: name = "SingleExcitationPlus" - qml.SingleExcitationPlus(theta, wires=[0, 1]) - return measure(wires=0) - - print(decompose_singleexcitationplus.mlir) + dev = get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_decompose_singleexcitationplus + def decompose_singleexcitationplus(theta: float): + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_scalar_tensor_float_2:%.+]] = stablehlo.constant dense<2.{{[0]+}}e+00> + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s0q1:%.+]] = quantum.custom "PauliX" + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s0q0:%.+]] = quantum.custom "PauliX" + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[a_theta_div_2_scalar:%.+]] = tensor.extract [[a_theta_div_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s1:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[a_theta_div_2_scalar]]) [[s0q1]], [[s0q0]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s2q1:%.+]] = quantum.custom "PauliX"() [[s1]]#1 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s2q0:%.+]] = quantum.custom "PauliX"() [[s1]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[b_theta_div_2:%.+]] = stablehlo.divide %arg0, [[a_scalar_tensor_float_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[b_theta_div_2_scalar:%.+]] = tensor.extract [[b_theta_div_2]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s3:%.+]]:2 = quantum.custom "ControlledPhaseShift"([[b_theta_div_2_scalar]]) [[s2q1]], [[s2q0]] + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s4:%.+]]:2 = quantum.custom "CNOT"() [[s3]]#0, [[s3]]#1 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[theta_scalar:%.+]] = tensor.extract %arg0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s5:%.+]]:2 = quantum.custom "CRY"([[theta_scalar]]) [[s4]]#1, [[s4]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + # CHECK: [[s6:%.+]]:2 = quantum.custom "CNOT"() [[s5]]#1, [[s5]]#0 + # CHECK-NOT: name = "SingleExcitationPlus" + qml.SingleExcitationPlus(theta, wires=[0, 1]) + return measure(wires=0) + + print(decompose_singleexcitationplus.mlir) test_decompose_singleexcitationplus() diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index e7125e7eb7..0ee8d26233 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -17,11 +17,19 @@ import os import tempfile +from copy import deepcopy import jax.numpy as jnp import pennylane as qml from catalyst import qjit +from catalyst.utils.runtime import pennylane_operation_set +from catalyst.utils.toml import ( + DeviceCapabilities, + OperationProperties, + ProgramFeatures, + get_device_capabilities, +) def get_custom_qjit_device(num_wires, discards, additions): @@ -37,70 +45,59 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - discards | additions - observables = lightning_device.observables.copy() - config = None backend_name = "default" backend_lib = "default" backend_kwargs = {} def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=shots is not None) + lightning_capabilities = get_device_capabilities( + self.lightning_device, program_features + ) + custom_capabilities = deepcopy(lightning_capabilities) + for gate in discards: + custom_capabilities.native_ops.pop(gate) + custom_capabilities.native_ops.update(additions) + self.qjit_capabilities = custom_capabilities + + @property + def operations(self): + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) + + @property + def observables(self): + return pennylane_operation_set(self.qjit_capabilities.native_obs) def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() - - # TODO: update once schema 2 is merged - updated_toml_contents = [] - for line in toml_contents: - if any(f'"{gate}",' in line for gate in discards): - continue - - updated_toml_contents.append(line) - if "native = [" in line: - for gate in additions: - if not gate.startswith("C("): - updated_toml_contents.append(f' "{gate}",\n') - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting - - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None - return CustomDevice(wires=num_wires) def test_named_controlled(): """Test that named-controlled operations are passed as-is.""" - with get_custom_qjit_device(2, set(), set()) as dev: + dev = get_custom_qjit_device(2, set(), set()) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_named_controlled - def named_controlled(): - # CHECK: quantum.custom "CNOT" - qml.CNOT(wires=[0, 1]) - # CHECK: quantum.custom "CY" - qml.CY(wires=[0, 1]) - # CHECK: quantum.custom "CZ" - qml.CZ(wires=[0, 1]) - return qml.state() + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_named_controlled + def named_controlled(): + # CHECK: quantum.custom "CNOT" + qml.CNOT(wires=[0, 1]) + # CHECK: quantum.custom "CY" + qml.CY(wires=[0, 1]) + # CHECK: quantum.custom "CZ" + qml.CZ(wires=[0, 1]) + return qml.state() - print(named_controlled.mlir) + print(named_controlled.mlir) test_named_controlled() @@ -108,19 +105,19 @@ def named_controlled(): def test_native_controlled_custom(): """Test native control of a custom operation.""" - with get_custom_qjit_device(3, {"CRot"}, {"Rot", "C(Rot)"}) as dev: + dev = get_custom_qjit_device(3, set(), {"Rot": OperationProperties(True, True, False)}) - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled - def native_controlled(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true, %true) - qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) - return qml.state() + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled + def native_controlled(): + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true, %true) + qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) + return qml.state() - print(native_controlled.mlir) + print(native_controlled.mlir) test_native_controlled_custom() @@ -128,31 +125,31 @@ def native_controlled(): def test_native_controlled_unitary(): """Test native control of the unitary operation.""" - with get_custom_qjit_device(4, set(), set()) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled_unitary - def native_controlled_unitary(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true, %true, %true) - qml.ctrl( - qml.QubitUnitary( - jnp.array( - [ - [0.70710678 + 0.0j, 0.70710678 + 0.0j], - [0.70710678 + 0.0j, -0.70710678 + 0.0j], - ], - dtype=jnp.complex128, - ), - wires=[0], + dev = get_custom_qjit_device(4, set(), set()) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled_unitary + def native_controlled_unitary(): + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:3 = quantum.unitary + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true, %true, %true) + qml.ctrl( + qml.QubitUnitary( + jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ], + dtype=jnp.complex128, ), - control=[1, 2, 3], - ) - return qml.state() + wires=[0], + ), + control=[1, 2, 3], + ) + return qml.state() - print(native_controlled_unitary.mlir) + print(native_controlled_unitary.mlir) test_native_controlled_unitary() @@ -160,19 +157,19 @@ def native_controlled_unitary(): def test_native_controlled_multirz(): """Test native control of the multirz operation.""" - with get_custom_qjit_device(3, set(), {"C(MultiRZ)"}) as dev: - - @qjit(target="mlir") - @qml.qnode(dev) - # CHECK-LABEL: public @jit_native_controlled_multirz - def native_controlled_multirz(): - # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz - # CHECK-SAME: ctrls - # CHECK-SAME: ctrlvals(%true) - qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1]) - return qml.state() - - print(native_controlled_multirz.mlir) + dev = get_custom_qjit_device(3, set(), {"MultiRZ": OperationProperties(True, True, True)}) + + @qjit(target="mlir") + @qml.qnode(dev) + # CHECK-LABEL: public @jit_native_controlled_multirz + def native_controlled_multirz(): + # CHECK: [[out:%.+]]:2, [[out_ctrl:%.+]] = quantum.multirz + # CHECK-SAME: ctrls + # CHECK-SAME: ctrlvals(%true) + qml.ctrl(qml.MultiRZ(0.6, wires=[0, 2]), control=[1]) + return qml.state() + + print(native_controlled_multirz.mlir) test_native_controlled_multirz() diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 15897c272e..f28546fe59 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -23,19 +23,17 @@ from catalyst.qjit_device import QJITDevice, QJITDeviceNewAPI from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import ( - check_no_overlap, - get_device_capabilities, - validate_config_with_device, -) +from catalyst.utils.runtime import check_no_overlap, validate_device_capabilities from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, TOMLDocument, check_quantum_control_flag, get_decomposable_gates, + get_device_capabilities, get_matrix_decomposable_gates, get_native_ops, + load_device_capabilities, pennylane_operation_set, read_toml_file, ) @@ -76,34 +74,30 @@ def get_test_device_capabilities( ) -> DeviceCapabilities: """Parse test config into the DeviceCapabilities structure""" config = get_test_config(config_text) - device_capabilities = get_device_capabilities(config, program_features, "dummy") + device_capabilities = load_device_capabilities(config, program_features, "dummy") return device_capabilities @pytest.mark.parametrize("schema", ALL_SCHEMAS) -def test_validate_config_with_device(schema): +def test_config_qjit_incompatible_device(schema): """Test error is raised if checking for qjit compatibility and field is false in toml file.""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - f""" - schema = {schema} - [compilation] - qjit_compatible = false - """ - ) - ) - - config = read_toml_file(toml_file) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + f""" + schema = {schema} + [compilation] + qjit_compatible = false + """ + ), + ) - device = DummyDevice() - with pytest.raises( - CompileError, - match=f"Attempting to compile program for incompatible device '{device.name}'", - ): - validate_config_with_device(device, config) + device = DummyDevice() + with pytest.raises( + CompileError, + match=f"Attempting to compile program for incompatible device '{device.name}'", + ): + validate_device_capabilities(device, device_capabilities) def test_get_observables_schema1(): @@ -385,7 +379,8 @@ def test_config_invalid_condition_duplicate(shots): def test_config_qjit_device_operations(): """Check the gate condition handling logic""" - config = get_test_config( + capabilities = get_test_device_capabilities( + ProgramFeatures(False), dedent( r""" schema = 2 @@ -396,7 +391,7 @@ def test_config_qjit_device_operations(): """ ), ) - qjit_device = QJITDevice(config, shots=1000, wires=2) + qjit_device = QJITDevice(capabilities, shots=1000, wires=2) assert "PauliX" in qjit_device.operations assert "PauliY" in qjit_device.observables diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index e984933b1e..50d2d36d7a 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -21,7 +21,8 @@ from catalyst import measure, qjit from catalyst.compiler import get_lib_path from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import device_get_toml_config, extract_backend_info +from catalyst.utils.runtime import extract_backend_info +from catalyst.utils.toml import get_device_capabilities, get_device_toml_config # These have to match the ones in the configuration file. OPERATIONS = [ @@ -166,8 +167,8 @@ def get_c_interface(): return "DummyDevice", get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so" device = DummyDevice(wires=1) - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) + capabilities = get_device_capabilities(device) + backend_info = extract_backend_info(device, capabilities) assert backend_info.kwargs["option1"] == 42 assert "option2" not in backend_info.kwargs diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index a8a39218c5..c17269f9f8 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -14,12 +14,19 @@ import os import tempfile +from copy import deepcopy import pennylane as qml import pytest from jax import numpy as jnp from catalyst import CompileError, ctrl, measure, qjit +from catalyst.utils.runtime import pennylane_operation_set +from catalyst.utils.toml import ( + DeviceCapabilities, + ProgramFeatures, + get_device_capabilities, +) class CustomDevice(qml.QubitDevice): @@ -32,74 +39,56 @@ class CustomDevice(qml.QubitDevice): author = "Tester" lightning_device = qml.device("lightning.qubit", wires=0) - operations = lightning_device.operations.copy() - { - "MultiControlledX", - "Rot", - "S", - "C(Rot)", - "C(S)", - } - observables = lightning_device.observables.copy() - - config = None + backend_name = "default" backend_lib = "default" backend_kwargs = {} def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - self.toml_file = None + program_features = ProgramFeatures(shots_present=self.shots is not None) + lightning_capabilities = get_device_capabilities(self.lightning_device, program_features) + custom_capabilities = deepcopy(lightning_capabilities) + custom_capabilities.native_ops.pop("Rot") + custom_capabilities.native_ops.pop("S") + custom_capabilities.to_decomp_ops.pop("MultiControlledX") + self.qjit_capabilities = custom_capabilities def apply(self, operations, **kwargs): """Unused""" raise RuntimeError("Only C/C++ interface is defined") - def __enter__(self, *args, **kwargs): - lightning_toml = self.lightning_device.config - with open(lightning_toml, mode="r", encoding="UTF-8") as f: - toml_contents = f.readlines() - - updated_toml_contents = [] - for line in toml_contents: - if '"MultiControlledX",' in line or line.startswith("MultiControlledX "): - continue - if '"Rot",' in line or line.startswith("Rot "): - continue - if '"S",' in line or line.startswith("S "): - continue - - updated_toml_contents.append(line) - - self.toml_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - self.toml_file.writelines(updated_toml_contents) - self.toml_file.close() # close for now without deleting + @property + def operations(self): + return ( + pennylane_operation_set(self.qjit_capabilities.native_ops) + | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) + | pennylane_operation_set(self.qjit_capabilities.to_matrix_ops) + ) - self.config = self.toml_file.name - return self - - def __exit__(self, *args, **kwargs): - os.unlink(self.toml_file.name) - self.config = None + @property + def observables(self): + return pennylane_operation_set(self.qjit_capabilities.native_obs) @pytest.mark.parametrize("param,expected", [(0.0, True), (jnp.pi, False)]) def test_decomposition(param, expected): - with CustomDevice(wires=2) as dev: - - @qjit - @qml.qnode(dev) - def mid_circuit(x: float): - qml.Hadamard(wires=0) - qml.Rot(0, 0, x, wires=0) - qml.Hadamard(wires=0) - m = measure(wires=0) - b = m ^ 0x1 - qml.Hadamard(wires=1) - qml.Rot(0, 0, b * jnp.pi, wires=1) - qml.Hadamard(wires=1) - return measure(wires=1) - - assert mid_circuit(param) == expected + dev = CustomDevice(wires=2) + + @qjit + @qml.qnode(dev) + def mid_circuit(x: float): + qml.Hadamard(wires=0) + qml.Rot(0, 0, x, wires=0) + qml.Hadamard(wires=0) + m = measure(wires=0) + b = m ^ 0x1 + qml.Hadamard(wires=1) + qml.Rot(0, 0, b * jnp.pi, wires=1) + qml.Hadamard(wires=1) + return measure(wires=1) + + assert mid_circuit(param) == expected class TestControlledDecomposition: diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index d363cb144a..589bebccd4 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -26,7 +26,12 @@ from catalyst.compiler import get_lib_path from catalyst.qjit_device import QJITDeviceNewAPI from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.utils.runtime import device_get_toml_config, extract_backend_info +from catalyst.utils.runtime import extract_backend_info +from catalyst.utils.toml import ( + ProgramFeatures, + get_device_capabilities, + get_device_toml_config, +) class DummyDevice(Device): @@ -91,9 +96,9 @@ def test_qjit_device(): device = DummyDevice(wires=10, shots=2032) # Create qjit device - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) - device_qjit = QJITDeviceNewAPI(device, backend_info) + capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None)) + backend_info = extract_backend_info(device, capabilities) + device_qjit = QJITDeviceNewAPI(device, capabilities, backend_info) # Check attributes of the new device assert device_qjit.shots == qml.measurements.Shots(2032) @@ -131,8 +136,8 @@ def test_qjit_device_no_wires(): device = DummyDeviceNoWires(shots=2032) # Create qjit device - config = device_get_toml_config(device) - backend_info = extract_backend_info(device, config) + capabilities = get_device_capabilities(device, ProgramFeatures(device.shots is not None)) + backend_info = extract_backend_info(device, capabilities) with pytest.raises( AttributeError, match="Catalyst does not support devices without set wires." From f97e7e92ca02cc70ab86a06dca8b805ea2504b42 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 6 May 2024 12:15:53 +0000 Subject: [PATCH 04/21] Address codecov errors --- frontend/test/lit/test_decomposition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 36a34799ab..86ca743ac1 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -58,11 +58,11 @@ def __init__(self, shots=None, wires=None): ) custom_capabilities = deepcopy(lightning_capabilities) for gate in discards: - if gate in dummy_capabilities.native_ops: + if gate in custom_capabilities.native_ops: custom_capabilities.native_ops.pop(gate) - if gate in dummy_capabilities.to_decomp_ops: + if gate in custom_capabilities.to_decomp_ops: custom_capabilities.to_decomp_ops.pop(gate) - if gate in dummy_capabilities.to_matrix_ops: + if gate in custom_capabilities.to_matrix_ops: custom_capabilities.to_matrix_ops.pop(gate) self.qjit_capabilities = custom_capabilities From 40499b4b133398abf906922b8cf30d91ed5aa19c Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 6 May 2024 12:26:45 +0000 Subject: [PATCH 05/21] Address codecov errors --- frontend/catalyst/qfunc.py | 2 +- frontend/catalyst/qjit_device.py | 3 --- frontend/catalyst/utils/runtime.py | 4 ---- frontend/test/lit/test_decomposition.py | 9 +++------ frontend/test/lit/test_quantum_control.py | 4 ++-- frontend/test/pytest/test_config_functions.py | 1 - frontend/test/pytest/test_custom_devices.py | 2 +- frontend/test/pytest/test_decomposition.py | 10 +++------- 8 files changed, 10 insertions(+), 25 deletions(-) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 9878a5c427..d8243e521e 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -63,7 +63,7 @@ def extract_backend_info( """Wrapper around extract_backend_info in the runtime module.""" return extract_backend_info(device, capabilities) - # pylint: disable=no-member + # pylint: disable=no-member, attribute-defined-outside-init def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) diff --git a/frontend/catalyst/qjit_device.py b/frontend/catalyst/qjit_device.py index 8c7c3da4f5..14001b45fb 100644 --- a/frontend/catalyst/qjit_device.py +++ b/frontend/catalyst/qjit_device.py @@ -28,9 +28,6 @@ from catalyst.utils.toml import ( DeviceCapabilities, OperationProperties, - ProgramFeatures, - TOMLDocument, - get_device_capabilities, intersect_operations, pennylane_operation_set, ) diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py index ac5e71f66c..84dab4c783 100644 --- a/frontend/catalyst/utils/runtime.py +++ b/frontend/catalyst/utils/runtime.py @@ -29,12 +29,8 @@ from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import ( DeviceCapabilities, - ProgramFeatures, - TOMLDocument, - get_device_capabilities, get_lib_path, pennylane_operation_set, - read_toml_file, ) # TODO: This should be removed after implementing `get_c_interface` diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 86ca743ac1..7f771da4e8 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -15,7 +15,6 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long -import os import tempfile from copy import deepcopy @@ -24,11 +23,7 @@ from catalyst import cond, for_loop, measure, qjit, while_loop from catalyst.utils.runtime import pennylane_operation_set -from catalyst.utils.toml import ( - DeviceCapabilities, - ProgramFeatures, - get_device_capabilities, -) +from catalyst.utils.toml import ProgramFeatures, get_device_capabilities def get_custom_device_without(num_wires, discards): @@ -72,6 +67,7 @@ def apply(self, operations, **kwargs): @property def operations(self): + """Return operations using PennyLane's C(.) syntax""" return ( pennylane_operation_set(self.qjit_capabilities.native_ops) | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) @@ -80,6 +76,7 @@ def operations(self): @property def observables(self): + """Return PennyLane observables""" return pennylane_operation_set(self.qjit_capabilities.native_obs) return CustomDevice(wires=num_wires) diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index 0ee8d26233..4268fc4a12 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -16,7 +16,6 @@ """ Test the lowering cases involving quantum control """ import os -import tempfile from copy import deepcopy import jax.numpy as jnp @@ -25,7 +24,6 @@ from catalyst import qjit from catalyst.utils.runtime import pennylane_operation_set from catalyst.utils.toml import ( - DeviceCapabilities, OperationProperties, ProgramFeatures, get_device_capabilities, @@ -64,6 +62,7 @@ def __init__(self, shots=None, wires=None): @property def operations(self): + """Get PennyLane operations.""" return ( pennylane_operation_set(self.qjit_capabilities.native_ops) | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) @@ -72,6 +71,7 @@ def operations(self): @property def observables(self): + """Get PennyLane observables.""" return pennylane_operation_set(self.qjit_capabilities.native_obs) def apply(self, operations, **kwargs): diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index f28546fe59..0bdffde782 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -30,7 +30,6 @@ TOMLDocument, check_quantum_control_flag, get_decomposable_gates, - get_device_capabilities, get_matrix_decomposable_gates, get_native_ops, load_device_capabilities, diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index 50d2d36d7a..e179c1053c 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -22,7 +22,7 @@ from catalyst.compiler import get_lib_path from catalyst.utils.exceptions import CompileError from catalyst.utils.runtime import extract_backend_info -from catalyst.utils.toml import get_device_capabilities, get_device_toml_config +from catalyst.utils.toml import get_device_capabilities # These have to match the ones in the configuration file. OPERATIONS = [ diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index c17269f9f8..f4163fe5f3 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile from copy import deepcopy import pennylane as qml @@ -22,11 +20,7 @@ from catalyst import CompileError, ctrl, measure, qjit from catalyst.utils.runtime import pennylane_operation_set -from catalyst.utils.toml import ( - DeviceCapabilities, - ProgramFeatures, - get_device_capabilities, -) +from catalyst.utils.toml import ProgramFeatures, get_device_capabilities class CustomDevice(qml.QubitDevice): @@ -60,6 +54,7 @@ def apply(self, operations, **kwargs): @property def operations(self): + """Get PennyLane operations.""" return ( pennylane_operation_set(self.qjit_capabilities.native_ops) | pennylane_operation_set(self.qjit_capabilities.to_decomp_ops) @@ -68,6 +63,7 @@ def operations(self): @property def observables(self): + """Get PennyLane observables.""" return pennylane_operation_set(self.qjit_capabilities.native_obs) From 0ed92b64ac63e0c93715a902fdeb41f74503c32b Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 6 May 2024 12:28:46 +0000 Subject: [PATCH 06/21] Address codecov errors --- frontend/test/lit/test_decomposition.py | 1 - frontend/test/lit/test_quantum_control.py | 1 - frontend/test/pytest/test_device_api.py | 6 +----- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 7f771da4e8..cf0477599a 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -15,7 +15,6 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long -import tempfile from copy import deepcopy import jax diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index 4268fc4a12..b8aef68302 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -15,7 +15,6 @@ # RUN: %PYTHON %s | FileCheck %s """ Test the lowering cases involving quantum control """ -import os from copy import deepcopy import jax.numpy as jnp diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index 589bebccd4..cdc1fdac2b 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -27,11 +27,7 @@ from catalyst.qjit_device import QJITDeviceNewAPI from catalyst.tracing.contexts import EvaluationContext, EvaluationMode from catalyst.utils.runtime import extract_backend_info -from catalyst.utils.toml import ( - ProgramFeatures, - get_device_capabilities, - get_device_toml_config, -) +from catalyst.utils.toml import ProgramFeatures, get_device_capabilities class DummyDevice(Device): From d00c7e46f06d6a488bdafe60f4dc021e47b342d8 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 9 May 2024 10:25:11 +0000 Subject: [PATCH 07/21] Fix wrong field name --- frontend/catalyst/device/qjit_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index 69c99e8f04..43c963c62f 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -293,7 +293,7 @@ def observables(self) -> Set[str]: @property def measurement_processes(self) -> Set[str]: """Get the device measurement processes""" - return self.capabilities.measurement_processes + return self.qjit_capabilities.measurement_processes def preprocess( self, From d3ffae59fd2e0f79edc99bd58e5a45381311452b Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 16 May 2024 11:48:21 +0000 Subject: [PATCH 08/21] Address review suggestions; Move code around --- frontend/catalyst/compiler.py | 2 +- .../cuda/catalyst_to_cuda_interpreter.py | 2 +- frontend/catalyst/device/__init__.py | 12 +- frontend/catalyst/device/qjit_device.py | 91 +++++++- frontend/catalyst/qfunc.py | 13 +- frontend/catalyst/utils/runtime.py | 212 ------------------ frontend/catalyst/utils/toml.py | 119 ++++++++-- frontend/test/lit/test_decomposition.py | 2 +- frontend/test/pytest/test_config_functions.py | 2 +- frontend/test/pytest/test_custom_devices.py | 2 +- frontend/test/pytest/test_debug.py | 2 +- frontend/test/pytest/test_decomposition.py | 2 +- frontend/test/pytest/test_device_api.py | 2 +- 13 files changed, 210 insertions(+), 253 deletions(-) delete mode 100644 frontend/catalyst/utils/runtime.py diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 06c8ad93db..8968cdd424 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -33,7 +33,7 @@ from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import Directory -from catalyst.utils.toml import get_lib_path +from catalyst.utils.paths import get_lib_path package_root = os.path.dirname(__file__) diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py index b84d8bf7dd..c53a6d0c55 100644 --- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py @@ -39,6 +39,7 @@ import pennylane as qml from jax.tree_util import tree_unflatten +from catalyst.device import BackendInfo from catalyst.jax_primitives import ( AbstractObs, adjoint_p, @@ -75,7 +76,6 @@ from catalyst.qfunc import QFunc from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import BackendInfo from .primitives import ( cuda_inst, diff --git a/frontend/catalyst/device/__init__.py b/frontend/catalyst/device/__init__.py index 6b9c1bd55c..9835db81ce 100644 --- a/frontend/catalyst/device/__init__.py +++ b/frontend/catalyst/device/__init__.py @@ -16,9 +16,11 @@ Internal API for the device module. """ -from catalyst.device.qjit_device import QJITDevice, QJITDeviceNewAPI - -__all__ = ( - "QJITDevice", - "QJITDeviceNewAPI", +from catalyst.device.qjit_device import ( + BackendInfo, + QJITDevice, + QJITDeviceNewAPI, + extract_backend_info, ) + +__all__ = ("QJITDevice", "QJITDeviceNewAPI", "BackendInfo", "extract_backend_info") diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index 43c963c62f..f325a03ba3 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -16,10 +16,13 @@ This module contains device stubs for the old and new PennyLane device API, which facilitate the application of decomposition and other device pre-processing routines. """ - +import os +import pathlib +import platform from copy import deepcopy +from dataclasses import dataclass from functools import partial -from typing import Optional, Set +from typing import Any, Dict, Optional, Set import pennylane as qml from pennylane.measurements import MidMeasureMP @@ -32,7 +35,7 @@ ) from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import BackendInfo +from catalyst.utils.paths import get_lib_path from catalyst.utils.toml import ( DeviceCapabilities, OperationProperties, @@ -80,6 +83,88 @@ for op in RUNTIME_OPERATIONS } +from catalyst.utils.paths import get_lib_path + +# TODO: This should be removed after implementing `get_c_interface` +# for the following backend devices: +SUPPORTED_RT_DEVICES = { + "lightning.qubit": ("LightningSimulator", "librtd_lightning"), + "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"), + "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"), + "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"), +} + + +@dataclass +class BackendInfo: + """Backend information""" + + device_name: str + c_interface_name: str + lpath: str + kwargs: Dict[str, Any] + + +def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo: + """Extract the backend info from a quantum device. The device is expected to carry a reference + to a valid TOML config file.""" + + dname = device.name + if isinstance(device, qml.Device): + dname = device.short_name + + device_name = "" + device_lpath = "" + device_kwargs = {} + + if dname in SUPPORTED_RT_DEVICES: + # Support backend devices without `get_c_interface` + device_name = SUPPORTED_RT_DEVICES[dname][0] + device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR") + sys_platform = platform.system() + + if sys_platform == "Linux": + device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so") + elif sys_platform == "Darwin": # pragma: no cover + device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib") + else: # pragma: no cover + raise NotImplementedError(f"Platform not supported: {sys_platform}") + elif hasattr(device, "get_c_interface"): + # Support third party devices with `get_c_interface` + device_name, device_lpath = device.get_c_interface() + else: + raise CompileError(f"The {dname} device does not provide C interface for compilation.") + + if not pathlib.Path(device_lpath).is_file(): + raise CompileError(f"Device at {device_lpath} cannot be found!") + + if hasattr(device, "shots"): + if isinstance(device, qml.Device): + device_kwargs["shots"] = device.shots if device.shots else 0 + else: + # TODO: support shot vectors + device_kwargs["shots"] = device.shots.total_shots if device.shots else 0 + + if dname == "braket.local.qubit": # pragma: no cover + device_kwargs["device_type"] = dname + device_kwargs["backend"] = ( + # pylint: disable=protected-access + device._device._delegate.DEVICE_ID + ) + elif dname == "braket.aws.qubit": # pragma: no cover + device_kwargs["device_type"] = dname + device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access + if device._s3_folder: # pylint: disable=protected-access + device_kwargs["s3_destination_folder"] = str( + device._s3_folder # pylint: disable=protected-access + ) + + for k, v in capabilities.options.items(): + if hasattr(device, v): + device_kwargs[k] = getattr(device, v) + + return BackendInfo(dname, device_name, device_lpath, device_kwargs) + def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]: """Calculate the set of supported quantum gates for the QJIT device from the gates diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index add222fa99..2ce2bd9481 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -22,7 +22,12 @@ from jax.core import eval_jaxpr from jax.tree_util import tree_flatten, tree_unflatten -from catalyst.device import QJITDevice, QJITDeviceNewAPI +from catalyst.device import ( + BackendInfo, + QJITDevice, + QJITDeviceNewAPI, + extract_backend_info, +) from catalyst.jax_extras import ( deduce_avals, get_implicit_and_explicit_flat_args, @@ -30,15 +35,11 @@ ) from catalyst.jax_primitives import func_p from catalyst.jax_tracer import trace_quantum_function -from catalyst.utils.runtime import ( - BackendInfo, - extract_backend_info, - validate_device_capabilities, -) from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, get_device_capabilities, + validate_device_capabilities, ) diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py deleted file mode 100644 index 84dab4c783..0000000000 --- a/frontend/catalyst/utils/runtime.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Runtime utility methods. -""" - -# pylint: disable=too-many-branches - -import os -import pathlib -import platform -import re -from dataclasses import dataclass -from typing import Any, Dict - -import pennylane as qml - -from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import ( - DeviceCapabilities, - get_lib_path, - pennylane_operation_set, -) - -# TODO: This should be removed after implementing `get_c_interface` -# for the following backend devices: -SUPPORTED_RT_DEVICES = { - "lightning.qubit": ("LightningSimulator", "librtd_lightning"), - "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"), - "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"), - "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"), -} - - -def check_no_overlap(*args, device_name): - """Check items in *args are mutually exclusive. - - Args: - *args (List[Str]): List of strings. - device_name (str): Device name for error reporting. - - Raises: - CompileError - """ - set_of_sets = [set(arg) for arg in args] - union = set.union(*set_of_sets) - len_of_sets = [len(arg) for arg in args] - if sum(len_of_sets) == len(union): - return - - overlaps = set() - for s in set_of_sets: - overlaps.update(s - union) - union = union - s - - msg = f"Device '{device_name}' has overlapping gates: {overlaps}" - raise CompileError(msg) - - -def filter_out_adjoint(operations): - """Remove Adjoint from operations. - - Args: - operations (List[Str]): List of strings with names of supported operations - - Returns: - List: A list of strings with names of supported operations with Adjoint and C gates - removed. - """ - adjoint = re.compile(r"^Adjoint\(.*\)$") - - def is_not_adj(op): - return not re.match(adjoint, op) - - operations_no_adj = filter(is_not_adj, operations) - return set(operations_no_adj) - - -def validate_device_capabilities( - device: qml.QubitDevice, device_capabilities: DeviceCapabilities -) -> None: - """Validate configuration document against the device attributes. - Raise CompileError in case of mismatch: - * If device is not qjit-compatible. - * If configuration file does not exists. - * If decomposable, matrix, and native gates have some overlap. - * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and - ``device.observables``. - - Args: - device (qml.Device): An instance of a quantum device. - config (TOMLDocument): A TOML document representation. - - Raises: CompileError - """ - - if not device_capabilities.qjit_compatible_flag: - raise CompileError( - f"Attempting to compile program for incompatible device '{device.name}': " - f"Config is not marked as qjit-compatible" - ) - - device_name = device.short_name if isinstance(device, qml.Device) else device.name - - native = pennylane_operation_set(device_capabilities.native_ops) - decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) - matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) - - check_no_overlap(native, decomposable, matrix, device_name=device_name) - - if hasattr(device, "operations") and hasattr(device, "observables"): - # For gates, we require strict match - device_gates = filter_out_adjoint(set(device.operations)) - spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) - if device_gates != spec_gates: - raise CompileError( - "Gates in qml.device.operations and specification file do not match.\n" - f"Gates that present only in the device: {device_gates - spec_gates}\n" - f"Gates that present only in spec: {spec_gates - device_gates}\n" - ) - - # For observables, we do not have `non-native` section in the config, so we check that - # device data supercedes the specification. - device_observables = set(device.observables) - spec_observables = pennylane_operation_set(device_capabilities.native_obs) - if (spec_observables - device_observables) != set(): - raise CompileError( - "Observables in qml.device.observables and specification file do not match.\n" - f"Observables that present only in spec: {spec_observables - device_observables}\n" - ) - - -@dataclass -class BackendInfo: - """Backend information""" - - device_name: str - c_interface_name: str - lpath: str - kwargs: Dict[str, Any] - - -def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo: - """Extract the backend info from a quantum device. The device is expected to carry a reference - to a valid TOML config file.""" - - dname = device.name - if isinstance(device, qml.Device): - dname = device.short_name - - device_name = "" - device_lpath = "" - device_kwargs = {} - - if dname in SUPPORTED_RT_DEVICES: - # Support backend devices without `get_c_interface` - device_name = SUPPORTED_RT_DEVICES[dname][0] - device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR") - sys_platform = platform.system() - - if sys_platform == "Linux": - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so") - elif sys_platform == "Darwin": # pragma: no cover - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib") - else: # pragma: no cover - raise NotImplementedError(f"Platform not supported: {sys_platform}") - elif hasattr(device, "get_c_interface"): - # Support third party devices with `get_c_interface` - device_name, device_lpath = device.get_c_interface() - else: - raise CompileError(f"The {dname} device does not provide C interface for compilation.") - - if not pathlib.Path(device_lpath).is_file(): - raise CompileError(f"Device at {device_lpath} cannot be found!") - - if hasattr(device, "shots"): - if isinstance(device, qml.Device): - device_kwargs["shots"] = device.shots if device.shots else 0 - else: - # TODO: support shot vectors - device_kwargs["shots"] = device.shots.total_shots if device.shots else 0 - - if dname == "braket.local.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["backend"] = ( - # pylint: disable=protected-access - device._device._delegate.DEVICE_ID - ) - elif dname == "braket.aws.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access - if device._s3_folder: # pylint: disable=protected-access - device_kwargs["s3_destination_folder"] = str( - device._s3_folder # pylint: disable=protected-access - ) - - for k, v in capabilities.options.items(): - if hasattr(device, v): - device_kwargs[k] = getattr(device, v) - - return BackendInfo(dname, device_name, device_lpath, device_kwargs) diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index 376fd4d4ad..df49a47cf7 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -18,6 +18,7 @@ import importlib.util import os import pathlib +import re from dataclasses import dataclass from functools import reduce from itertools import repeat @@ -27,6 +28,7 @@ from catalyst._configuration import INSTALLED from catalyst.utils.exceptions import CompileError +from catalyst.utils.paths import get_lib_path # TODO: # Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit @@ -53,25 +55,6 @@ from tomlkit.exceptions import TOMLKitError as TOMLException -package_root = os.path.dirname(__file__) - - -# Default paths to dep libraries -DEFAULT_LIB_PATHS = { - "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), - "runtime": os.path.join(package_root, "../../../runtime/build/lib"), - "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), - "oqc_runtime": os.path.join(package_root, "../../catalyst/oqc/src/build"), -} - - -def get_lib_path(project, env_var): - """Get the library path.""" - if INSTALLED: - return os.path.join(package_root, "..", "lib") # pragma: no cover - return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) - - def read_toml_file(toml_file: str) -> TOMLDocument: """Helper function opening toml file properly and reading it into a document""" with open(toml_file, "rb") as f: @@ -450,3 +433,101 @@ def load_device_capabilities( dynamic_qubit_management_flag=get_compilation_flag(config, "dynamic_qubit_management"), options=get_options(config), ) + + +def check_no_overlap(*args, device_name): + """Check items in *args are mutually exclusive. + + Args: + *args (List[Str]): List of strings. + device_name (str): Device name for error reporting. + + Raises: + CompileError + """ + set_of_sets = [set(arg) for arg in args] + union = set.union(*set_of_sets) + len_of_sets = [len(arg) for arg in args] + if sum(len_of_sets) == len(union): + return + + overlaps = set() + for s in set_of_sets: + overlaps.update(s - union) + union = union - s + + msg = f"Device '{device_name}' has overlapping gates: {overlaps}" + raise CompileError(msg) + + +def filter_out_adjoint(operations): + """Remove Adjoint from operations. + + Args: + operations (List[Str]): List of strings with names of supported operations + + Returns: + List: A list of strings with names of supported operations with Adjoint and C gates + removed. + """ + adjoint = re.compile(r"^Adjoint\(.*\)$") + + def is_not_adj(op): + return not re.match(adjoint, op) + + operations_no_adj = filter(is_not_adj, operations) + return set(operations_no_adj) + + +def validate_device_capabilities( + device: qml.QubitDevice, device_capabilities: DeviceCapabilities +) -> None: + """Validate configuration document against the device attributes. + Raise CompileError in case of mismatch: + * If device is not qjit-compatible. + * If configuration file does not exists. + * If decomposable, matrix, and native gates have some overlap. + * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and + ``device.observables``. + + Args: + device (qml.Device): An instance of a quantum device. + config (TOMLDocument): A TOML document representation. + + Raises: CompileError + """ + + if not device_capabilities.qjit_compatible_flag: + raise CompileError( + f"Attempting to compile program for incompatible device '{device.name}': " + f"Config is not marked as qjit-compatible" + ) + + device_name = device.short_name if isinstance(device, qml.Device) else device.name + + native = pennylane_operation_set(device_capabilities.native_ops) + decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) + matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) + + check_no_overlap(native, decomposable, matrix, device_name=device_name) + + if hasattr(device, "operations") and hasattr(device, "observables"): + # For gates, we require strict match + device_gates = filter_out_adjoint(set(device.operations)) + spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) + if device_gates != spec_gates: + raise CompileError( + "Gates in qml.device.operations and specification file do not match.\n" + f"Gates that present only in the device: {device_gates - spec_gates}\n" + f"Gates that present only in spec: {spec_gates - device_gates}\n" + ) + + # For observables, we do not have `non-native` section in the config, so we check that + # device data supercedes the specification. + device_observables = set(device.observables) + spec_observables = pennylane_operation_set(device_capabilities.native_obs) + if (spec_observables - device_observables) != set(): + raise CompileError( + "Observables in qml.device.observables and specification file do not match.\n" + f"Observables that present only in spec: {spec_observables - device_observables}\n" + ) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index cf0477599a..c098665c08 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -21,7 +21,7 @@ import pennylane as qml from catalyst import cond, for_loop, measure, qjit, while_loop -from catalyst.utils.runtime import pennylane_operation_set +from catalyst.utils.toml import pennylane_operation_set from catalyst.utils.toml import ProgramFeatures, get_device_capabilities diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 9b222c0982..3282508a5d 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -23,7 +23,7 @@ from catalyst.device import QJITDevice from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import check_no_overlap, validate_device_capabilities +from catalyst.utils.toml import check_no_overlap, validate_device_capabilities from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index 533711907d..4e2dea2987 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -21,7 +21,7 @@ from catalyst import measure, qjit from catalyst.compiler import get_lib_path from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import extract_backend_info +from catalyst.device import extract_backend_info from catalyst.utils.toml import get_device_capabilities # These have to match the ones in the configuration file. diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 371d9e4455..69c113ac00 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -22,7 +22,7 @@ from catalyst.compiler import CompileOptions, Compiler from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage from catalyst.utils.exceptions import CompileError -from catalyst.utils.runtime import get_lib_path +from catalyst.utils.paths import get_lib_path class TestDebugPrint: diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index f4163fe5f3..f15f9094f4 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -19,7 +19,7 @@ from jax import numpy as jnp from catalyst import CompileError, ctrl, measure, qjit -from catalyst.utils.runtime import pennylane_operation_set +from catalyst.utils.toml import pennylane_operation_set from catalyst.utils.toml import ProgramFeatures, get_device_capabilities diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index 4d136de319..bd1cfe0cbc 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -27,7 +27,7 @@ from catalyst.compiler import get_lib_path from catalyst.device import QJITDeviceNewAPI from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.utils.runtime import extract_backend_info +from catalyst.device import extract_backend_info from catalyst.utils.toml import ProgramFeatures, get_device_capabilities From 6537b2a8d86837bb9722e4faa8d47dda05175f19 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 16 May 2024 12:16:56 +0000 Subject: [PATCH 09/21] Address formatting issues --- frontend/catalyst/autograph/ag_primitives.py | 3 ++- frontend/test/lit/test_decomposition.py | 7 +++++-- frontend/test/pytest/test_config_functions.py | 3 ++- frontend/test/pytest/test_custom_devices.py | 2 +- frontend/test/pytest/test_decomposition.py | 7 +++++-- frontend/test/pytest/test_device_api.py | 3 +-- frontend/test/pytest/test_preprocess.py | 3 --- 7 files changed, 16 insertions(+), 12 deletions(-) diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index d132641ce4..31ae3eec33 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -272,7 +272,8 @@ def for_stmt( # to succeed, for example because they forgot to use a list instead of an array # pylint: disable=multiple-statements,missing-class-docstring - class EmptyResult: ... + class EmptyResult: + ... results = EmptyResult() fallback = False diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index c098665c08..2c1c0ce57b 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -21,8 +21,11 @@ import pennylane as qml from catalyst import cond, for_loop, measure, qjit, while_loop -from catalyst.utils.toml import pennylane_operation_set -from catalyst.utils.toml import ProgramFeatures, get_device_capabilities +from catalyst.utils.toml import ( + ProgramFeatures, + get_device_capabilities, + pennylane_operation_set, +) def get_custom_device_without(num_wires, discards): diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 3282508a5d..7a7bc0f8c1 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -23,11 +23,11 @@ from catalyst.device import QJITDevice from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import check_no_overlap, validate_device_capabilities from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, TOMLDocument, + check_no_overlap, check_quantum_control_flag, get_decomposable_gates, get_matrix_decomposable_gates, @@ -35,6 +35,7 @@ load_device_capabilities, pennylane_operation_set, read_toml_file, + validate_device_capabilities, ) diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index 4e2dea2987..74e47bde0a 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -20,8 +20,8 @@ from catalyst import measure, qjit from catalyst.compiler import get_lib_path -from catalyst.utils.exceptions import CompileError from catalyst.device import extract_backend_info +from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import get_device_capabilities # These have to match the ones in the configuration file. diff --git a/frontend/test/pytest/test_decomposition.py b/frontend/test/pytest/test_decomposition.py index f15f9094f4..688e0bfc33 100644 --- a/frontend/test/pytest/test_decomposition.py +++ b/frontend/test/pytest/test_decomposition.py @@ -19,8 +19,11 @@ from jax import numpy as jnp from catalyst import CompileError, ctrl, measure, qjit -from catalyst.utils.toml import pennylane_operation_set -from catalyst.utils.toml import ProgramFeatures, get_device_capabilities +from catalyst.utils.toml import ( + ProgramFeatures, + get_device_capabilities, + pennylane_operation_set, +) class CustomDevice(qml.QubitDevice): diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index 81c38a44f4..d31bae2bcf 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -25,9 +25,8 @@ from catalyst import qjit from catalyst.compiler import get_lib_path -from catalyst.device import QJITDeviceNewAPI +from catalyst.device import QJITDeviceNewAPI, extract_backend_info from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.device import extract_backend_info from catalyst.utils.toml import ProgramFeatures, get_device_capabilities diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index 9fc5bace89..df7bd4b822 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -412,7 +412,6 @@ def test_decomposition_of_cond_circuit(self): @qml.qjit @qml.qnode(dev) def circuit(phi: float): - OtherHadamard(wires=0) # define a conditional ansatz @@ -461,7 +460,6 @@ def test_decomposition_of_forloop_circuit(self, reps, angle): @qml.qjit @qml.qnode(dev) def circuit(n: int, x: float): - OtherHadamard(wires=0) def loop_rx(i, phi): @@ -495,7 +493,6 @@ def test_decomposition_of_whileloop_circuit(self, phi): @qml.qjit @qml.qnode(dev) def circuit(x: float): - @while_loop(lambda x: x < 2.0) def loop_rx(x): # perform some work and update (some of) the arguments From 468c3caf6fb3815d3dad7348403236b8387f63e2 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 16 May 2024 12:21:44 +0000 Subject: [PATCH 10/21] Address formatting issues --- frontend/catalyst/autograph/ag_primitives.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index 31ae3eec33..d132641ce4 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -272,8 +272,7 @@ def for_stmt( # to succeed, for example because they forgot to use a list instead of an array # pylint: disable=multiple-statements,missing-class-docstring - class EmptyResult: - ... + class EmptyResult: ... results = EmptyResult() fallback = False From 72c9f24b09b60d4e3f0f0de749e72b468098ebd6 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 16 May 2024 12:30:17 +0000 Subject: [PATCH 11/21] Address pylint issues --- frontend/catalyst/device/qjit_device.py | 3 +-- frontend/catalyst/jax_primitives.py | 2 +- frontend/catalyst/third_party/cuda/primitives/__init__.py | 2 +- frontend/catalyst/utils/runtime.py | 1 + frontend/catalyst/utils/toml.py | 5 +---- frontend/test/lit/test_quantum_control.py | 2 +- 6 files changed, 6 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index f3a96892f4..a5d0b1a203 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -84,8 +84,6 @@ for op in RUNTIME_OPERATIONS } -from catalyst.utils.paths import get_lib_path - # TODO: This should be removed after implementing `get_c_interface` # for the following backend devices: SUPPORTED_RT_DEVICES = { @@ -109,6 +107,7 @@ class BackendInfo: def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo: """Extract the backend info from a quantum device. The device is expected to carry a reference to a valid TOML config file.""" + # pylint: disable=too-many-branches dname = device.name if isinstance(device, qml.Device): diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 7d50fea3d3..2044a8fcd6 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -76,7 +76,7 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors -# pylint: disable=unused-argument,abstract-method,too-many-lines +# pylint: disable=unused-argument,too-many-lines ######### # Types # diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index ead6053b29..e1679d3bcb 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -24,7 +24,7 @@ # We disable protected access in particular to avoid warnings with cudaq._pycuda. # And we disable unused-argument to avoid unused arguments in abstract_eval, particularly kwargs. -# pylint: disable=protected-access,unused-argument,abstract-method,line-too-long +# pylint: disable=protected-access,unused-argument,line-too-long class AbsCudaQState(jax.core.AbstractValue): diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py index 1fb3aed313..8547ad3ccf 100644 --- a/frontend/catalyst/utils/runtime.py +++ b/frontend/catalyst/utils/runtime.py @@ -124,6 +124,7 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - Raises: CompileError """ + # pylint: disable=too-many-function-args if not config["compilation"]["qjit_compatible"]: raise CompileError( diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index df49a47cf7..899eb7b7c7 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -16,7 +16,6 @@ """ import importlib.util -import os import pathlib import re from dataclasses import dataclass @@ -26,7 +25,6 @@ import pennylane as qml -from catalyst._configuration import INSTALLED from catalyst.utils.exceptions import CompileError from catalyst.utils.paths import get_lib_path @@ -40,8 +38,7 @@ tomlkit = importlib.util.find_spec("tomlkit") # We need at least one of these to make sure we can read toml files. if tomllib is None and tomlkit is None: # pragma: nocover - msg = "Either tomllib or tomlkit need to be installed." - raise ImportError(msg) + raise ImportError("Either tomllib or tomlkit need to be installed.") # Give preference to tomllib if tomllib: # pragma: nocover diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index b8aef68302..15cc876f93 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -21,11 +21,11 @@ import pennylane as qml from catalyst import qjit -from catalyst.utils.runtime import pennylane_operation_set from catalyst.utils.toml import ( OperationProperties, ProgramFeatures, get_device_capabilities, + pennylane_operation_set, ) From 611e5fce8cf07a5c3b6648f9b72c94e35a378aa0 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 16 May 2024 12:31:19 +0000 Subject: [PATCH 12/21] Add missing paths module --- frontend/catalyst/utils/paths.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 frontend/catalyst/utils/paths.py diff --git a/frontend/catalyst/utils/paths.py b/frontend/catalyst/utils/paths.py new file mode 100644 index 0000000000..a75b466809 --- /dev/null +++ b/frontend/catalyst/utils/paths.py @@ -0,0 +1,38 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility code for keeping paths +""" +import os +import os.path + +from catalyst._configuration import INSTALLED + +package_root = os.path.dirname(__file__) + +# Default paths to dep libraries +DEFAULT_LIB_PATHS = { + "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), + "runtime": os.path.join(package_root, "../../../runtime/build/lib"), + "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), + "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"), +} + + +def get_lib_path(project, env_var): + """Get the library path.""" + if INSTALLED: + return os.path.join(package_root, "..", "lib") # pragma: no cover + return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) From 5bd08cbd97915e19a1493fdebf5e896e06647c35 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Wed, 22 May 2024 09:41:22 +0000 Subject: [PATCH 13/21] Fix re-added runtime.py --- frontend/catalyst/utils/runtime.py | 264 ----------------------------- 1 file changed, 264 deletions(-) delete mode 100644 frontend/catalyst/utils/runtime.py diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py deleted file mode 100644 index 8547ad3ccf..0000000000 --- a/frontend/catalyst/utils/runtime.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Runtime utility methods. -""" - -# pylint: disable=too-many-branches - -import os -import pathlib -import platform -import re -from dataclasses import dataclass -from typing import Any, Dict - -import pennylane as qml - -from catalyst._configuration import INSTALLED -from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import ( - ProgramFeatures, - TOMLDocument, - get_device_capabilities, - pennylane_operation_set, - read_toml_file, -) - -package_root = os.path.dirname(__file__) - - -# Default paths to dep libraries -DEFAULT_LIB_PATHS = { - "llvm": os.path.join(package_root, "../../../mlir/llvm-project/build/lib"), - "runtime": os.path.join(package_root, "../../../runtime/build/lib"), - "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), - "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"), -} - - -# TODO: This should be removed after implementing `get_c_interface` -# for the following backend devices: -SUPPORTED_RT_DEVICES = { - "lightning.qubit": ("LightningSimulator", "librtd_lightning"), - "lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"), - "braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"), - "braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"), -} - - -def get_lib_path(project, env_var): - """Get the library path.""" - if INSTALLED: - return os.path.join(package_root, "..", "lib") # pragma: no cover - return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) - - -def check_no_overlap(*args, device_name): - """Check items in *args are mutually exclusive. - - Args: - *args (List[Str]): List of strings. - device_name (str): Device name for error reporting. - - Raises: - CompileError - """ - set_of_sets = [set(arg) for arg in args] - union = set.union(*set_of_sets) - len_of_sets = [len(arg) for arg in args] - if sum(len_of_sets) == len(union): - return - - overlaps = set() - for s in set_of_sets: - overlaps.update(s - union) - union = union - s - - msg = f"Device '{device_name}' has overlapping gates: {overlaps}" - raise CompileError(msg) - - -def filter_out_adjoint(operations): - """Remove Adjoint from operations. - - Args: - operations (List[Str]): List of strings with names of supported operations - - Returns: - List: A list of strings with names of supported operations with Adjoint and C gates - removed. - """ - adjoint = re.compile(r"^Adjoint\(.*\)$") - - def is_not_adj(op): - return not re.match(adjoint, op) - - operations_no_adj = filter(is_not_adj, operations) - return set(operations_no_adj) - - -def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -> None: - """Validate configuration document against the device attributes. - Raise CompileError in case of mismatch: - * If device is not qjit-compatible. - * If configuration file does not exists. - * If decomposable, matrix, and native gates have some overlap. - * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and - ``device.observables``. - - Args: - device (qml.Device): An instance of a quantum device. - config (TOMLDocument): A TOML document representation. - - Raises: CompileError - """ - # pylint: disable=too-many-function-args - - if not config["compilation"]["qjit_compatible"]: - raise CompileError( - f"Attempting to compile program for incompatible device '{device.name}': " - f"Config is not marked as qjit-compatible" - ) - - device_name = device.short_name if isinstance(device, qml.Device) else device.name - program_features = ProgramFeatures(device.shots is not None) - device_capabilities = get_device_capabilities(config, program_features, device_name) - - native = pennylane_operation_set(device_capabilities.native_ops) - decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) - matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) - - check_no_overlap(native, decomposable, matrix, device_name=device_name) - - if hasattr(device, "operations") and hasattr(device, "observables"): - # For gates, we require strict match - device_gates = filter_out_adjoint(set(device.operations)) - spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) - if device_gates != spec_gates: - raise CompileError( - "Gates in qml.device.operations and specification file do not match.\n" - f"Gates that present only in the device: {device_gates - spec_gates}\n" - f"Gates that present only in spec: {spec_gates - device_gates}\n" - ) - - # For observables, we do not have `non-native` section in the config, so we check that - # device data supercedes the specification. - device_observables = set(device.observables) - spec_observables = pennylane_operation_set(device_capabilities.native_obs) - if (spec_observables - device_observables) != set(): - raise CompileError( - "Observables in qml.device.observables and specification file do not match.\n" - f"Observables that present only in spec: {spec_observables - device_observables}\n" - ) - - -def device_get_toml_config(device) -> TOMLDocument: - """Get the contents of the device config file.""" - if hasattr(device, "config"): - # The expected case: device specifies its own config. - toml_file = device.config - else: - # TODO: Remove this section when `qml.Device`s are guaranteed to have their own config file - # field. - device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) - - name = device.short_name if isinstance(device, qml.Device) else device.name - # The toml files name convention we follow is to replace - # the dots with underscores in the device short name. - toml_file_name = name.replace(".", "_") + ".toml" - # And they are currently saved in the following directory. - toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name - - try: - config = read_toml_file(toml_file) - except FileNotFoundError as e: - raise CompileError( - "Attempting to compile program for incompatible device: " - f"Config file ({toml_file}) does not exist" - ) from e - - return config - - -@dataclass -class BackendInfo: - """Backend information""" - - device_name: str - c_interface_name: str - lpath: str - kwargs: Dict[str, Any] - - -def extract_backend_info(device: qml.QubitDevice, config: TOMLDocument) -> BackendInfo: - """Extract the backend info from a quantum device. The device is expected to carry a reference - to a valid TOML config file.""" - - dname = device.name - if isinstance(device, qml.Device): - dname = device.short_name - - device_name = "" - device_lpath = "" - device_kwargs = {} - - if dname in SUPPORTED_RT_DEVICES: - # Support backend devices without `get_c_interface` - device_name = SUPPORTED_RT_DEVICES[dname][0] - device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR") - sys_platform = platform.system() - - if sys_platform == "Linux": - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so") - elif sys_platform == "Darwin": # pragma: no cover - device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib") - else: # pragma: no cover - raise NotImplementedError(f"Platform not supported: {sys_platform}") - elif hasattr(device, "get_c_interface"): - # Support third party devices with `get_c_interface` - device_name, device_lpath = device.get_c_interface() - else: - raise CompileError(f"The {dname} device does not provide C interface for compilation.") - - if not pathlib.Path(device_lpath).is_file(): - raise CompileError(f"Device at {device_lpath} cannot be found!") - - if hasattr(device, "shots"): - if isinstance(device, qml.Device): - device_kwargs["shots"] = device.shots if device.shots else 0 - else: - # TODO: support shot vectors - device_kwargs["shots"] = device.shots.total_shots if device.shots else 0 - - if dname == "braket.local.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["backend"] = ( - # pylint: disable=protected-access - device._device._delegate.DEVICE_ID - ) - elif dname == "braket.aws.qubit": # pragma: no cover - device_kwargs["device_type"] = dname - device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access - if device._s3_folder: # pylint: disable=protected-access - device_kwargs["s3_destination_folder"] = str( - device._s3_folder # pylint: disable=protected-access - ) - - options = config.get("options", {}) - for k, v in options.items(): - if hasattr(device, v): - device_kwargs[k] = getattr(device, v) - - return BackendInfo(dname, device_name, device_lpath, device_kwargs) From f59e6bcdb07c18d918f74f0cbbf2790b7749aa32 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Wed, 22 May 2024 10:20:14 +0000 Subject: [PATCH 14/21] Fix a test --- frontend/test/pytest/test_device_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index d31bae2bcf..12b8a43010 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -129,7 +129,7 @@ def test_qjit_device_no_wires(): with pytest.raises( AttributeError, match="Catalyst does not support devices without set wires." ): - QJITDeviceNewAPI(device, backend_info) + QJITDeviceNewAPI(device, capabilities, backend_info) def test_simple_circuit(): From 1ba91492039bb439e6d93b10f4d43b28cc55855a Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 13:59:53 +0400 Subject: [PATCH 15/21] Update frontend/catalyst/jax_primitives.py Co-authored-by: David Ittah --- frontend/catalyst/jax_primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 2044a8fcd6..7d50fea3d3 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -76,7 +76,7 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors -# pylint: disable=unused-argument,too-many-lines +# pylint: disable=unused-argument,abstract-method,too-many-lines ######### # Types # From 4735b857c27f68f8e5eeff9e7bfe44c1da18a49f Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 14:00:06 +0400 Subject: [PATCH 16/21] Update frontend/catalyst/third_party/cuda/primitives/__init__.py Co-authored-by: David Ittah --- frontend/catalyst/third_party/cuda/primitives/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index e1679d3bcb..ead6053b29 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -24,7 +24,7 @@ # We disable protected access in particular to avoid warnings with cudaq._pycuda. # And we disable unused-argument to avoid unused arguments in abstract_eval, particularly kwargs. -# pylint: disable=protected-access,unused-argument,line-too-long +# pylint: disable=protected-access,unused-argument,abstract-method,line-too-long class AbsCudaQState(jax.core.AbstractValue): From 74ec4ffdf04be4eb77077767fe05f48c385f3494 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 11:13:18 +0000 Subject: [PATCH 17/21] Rename paths -> runtime_environment --- frontend/catalyst/compiler.py | 2 +- frontend/catalyst/device/qjit_device.py | 2 +- frontend/catalyst/utils/{paths.py => runtime_environment.py} | 0 frontend/catalyst/utils/toml.py | 2 +- frontend/test/pytest/test_debug.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename frontend/catalyst/utils/{paths.py => runtime_environment.py} (100%) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 798b0b6cd9..b983ecb388 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -33,7 +33,7 @@ from catalyst.utils.exceptions import CompileError from catalyst.utils.filesystem import Directory -from catalyst.utils.paths import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path package_root = os.path.dirname(__file__) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index a5d0b1a203..4bdbb65ee9 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -35,7 +35,7 @@ ) from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.paths import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path from catalyst.utils.toml import ( DeviceCapabilities, OperationProperties, diff --git a/frontend/catalyst/utils/paths.py b/frontend/catalyst/utils/runtime_environment.py similarity index 100% rename from frontend/catalyst/utils/paths.py rename to frontend/catalyst/utils/runtime_environment.py diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index 899eb7b7c7..51a84d95ce 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -26,7 +26,7 @@ import pennylane as qml from catalyst.utils.exceptions import CompileError -from catalyst.utils.paths import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path # TODO: # Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 69c113ac00..d9d70eb67e 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -22,7 +22,7 @@ from catalyst.compiler import CompileOptions, Compiler from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage from catalyst.utils.exceptions import CompileError -from catalyst.utils.paths import get_lib_path +from catalyst.utils.runtime_environment import get_lib_path class TestDebugPrint: From 8793214ec48ce3153c5dfc5ddd2d831477df848e Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 11:20:34 +0000 Subject: [PATCH 18/21] Address review suggestions: move validate_device_requirements -> qjit_device.py --- frontend/catalyst/device/__init__.py | 9 +- frontend/catalyst/device/qjit_device.py | 99 +++++++++++++++++++ frontend/catalyst/qfunc.py | 2 +- frontend/catalyst/utils/toml.py | 99 ------------------- frontend/test/pytest/test_config_functions.py | 3 +- 5 files changed, 109 insertions(+), 103 deletions(-) diff --git a/frontend/catalyst/device/__init__.py b/frontend/catalyst/device/__init__.py index 9835db81ce..a830cf7ec4 100644 --- a/frontend/catalyst/device/__init__.py +++ b/frontend/catalyst/device/__init__.py @@ -21,6 +21,13 @@ QJITDevice, QJITDeviceNewAPI, extract_backend_info, + validate_device_capabilities, ) -__all__ = ("QJITDevice", "QJITDeviceNewAPI", "BackendInfo", "extract_backend_info") +__all__ = ( + "QJITDevice", + "QJITDeviceNewAPI", + "BackendInfo", + "extract_backend_info", + "validate_device_capabilities", +) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index 4bdbb65ee9..24be391cdc 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -19,6 +19,7 @@ import os import pathlib import platform +import re from copy import deepcopy from dataclasses import dataclass from functools import partial @@ -405,3 +406,101 @@ def execute(self, circuits, execution_config): Raises: RuntimeError """ raise RuntimeError("QJIT devices cannot execute tapes.") + + +def filter_out_adjoint(operations): + """Remove Adjoint from operations. + + Args: + operations (List[Str]): List of strings with names of supported operations + + Returns: + List: A list of strings with names of supported operations with Adjoint and C gates + removed. + """ + adjoint = re.compile(r"^Adjoint\(.*\)$") + + def is_not_adj(op): + return not re.match(adjoint, op) + + operations_no_adj = filter(is_not_adj, operations) + return set(operations_no_adj) + + +def check_no_overlap(*args, device_name): + """Check items in *args are mutually exclusive. + + Args: + *args (List[Str]): List of strings. + device_name (str): Device name for error reporting. + + Raises: + CompileError + """ + set_of_sets = [set(arg) for arg in args] + union = set.union(*set_of_sets) + len_of_sets = [len(arg) for arg in args] + if sum(len_of_sets) == len(union): + return + + overlaps = set() + for s in set_of_sets: + overlaps.update(s - union) + union = union - s + + msg = f"Device '{device_name}' has overlapping gates: {overlaps}" + raise CompileError(msg) + + +def validate_device_capabilities( + device: qml.QubitDevice, device_capabilities: DeviceCapabilities +) -> None: + """Validate configuration document against the device attributes. + Raise CompileError in case of mismatch: + * If device is not qjit-compatible. + * If configuration file does not exists. + * If decomposable, matrix, and native gates have some overlap. + * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and + ``device.observables``. + + Args: + device (qml.Device): An instance of a quantum device. + config (TOMLDocument): A TOML document representation. + + Raises: CompileError + """ + + if not device_capabilities.qjit_compatible_flag: + raise CompileError( + f"Attempting to compile program for incompatible device '{device.name}': " + f"Config is not marked as qjit-compatible" + ) + + device_name = device.short_name if isinstance(device, qml.Device) else device.name + + native = pennylane_operation_set(device_capabilities.native_ops) + decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) + matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) + + check_no_overlap(native, decomposable, matrix, device_name=device_name) + + if hasattr(device, "operations") and hasattr(device, "observables"): + # For gates, we require strict match + device_gates = filter_out_adjoint(set(device.operations)) + spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) + if device_gates != spec_gates: + raise CompileError( + "Gates in qml.device.operations and specification file do not match.\n" + f"Gates that present only in the device: {device_gates - spec_gates}\n" + f"Gates that present only in spec: {spec_gates - device_gates}\n" + ) + + # For observables, we do not have `non-native` section in the config, so we check that + # device data supercedes the specification. + device_observables = set(device.observables) + spec_observables = pennylane_operation_set(device_capabilities.native_obs) + if (spec_observables - device_observables) != set(): + raise CompileError( + "Observables in qml.device.observables and specification file do not match.\n" + f"Observables that present only in spec: {spec_observables - device_observables}\n" + ) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 2ce2bd9481..67b49a3940 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -27,6 +27,7 @@ QJITDevice, QJITDeviceNewAPI, extract_backend_info, + validate_device_capabilities, ) from catalyst.jax_extras import ( deduce_avals, @@ -39,7 +40,6 @@ DeviceCapabilities, ProgramFeatures, get_device_capabilities, - validate_device_capabilities, ) diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index 51a84d95ce..1806d23c2c 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -17,7 +17,6 @@ import importlib.util import pathlib -import re from dataclasses import dataclass from functools import reduce from itertools import repeat @@ -430,101 +429,3 @@ def load_device_capabilities( dynamic_qubit_management_flag=get_compilation_flag(config, "dynamic_qubit_management"), options=get_options(config), ) - - -def check_no_overlap(*args, device_name): - """Check items in *args are mutually exclusive. - - Args: - *args (List[Str]): List of strings. - device_name (str): Device name for error reporting. - - Raises: - CompileError - """ - set_of_sets = [set(arg) for arg in args] - union = set.union(*set_of_sets) - len_of_sets = [len(arg) for arg in args] - if sum(len_of_sets) == len(union): - return - - overlaps = set() - for s in set_of_sets: - overlaps.update(s - union) - union = union - s - - msg = f"Device '{device_name}' has overlapping gates: {overlaps}" - raise CompileError(msg) - - -def filter_out_adjoint(operations): - """Remove Adjoint from operations. - - Args: - operations (List[Str]): List of strings with names of supported operations - - Returns: - List: A list of strings with names of supported operations with Adjoint and C gates - removed. - """ - adjoint = re.compile(r"^Adjoint\(.*\)$") - - def is_not_adj(op): - return not re.match(adjoint, op) - - operations_no_adj = filter(is_not_adj, operations) - return set(operations_no_adj) - - -def validate_device_capabilities( - device: qml.QubitDevice, device_capabilities: DeviceCapabilities -) -> None: - """Validate configuration document against the device attributes. - Raise CompileError in case of mismatch: - * If device is not qjit-compatible. - * If configuration file does not exists. - * If decomposable, matrix, and native gates have some overlap. - * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and - ``device.observables``. - - Args: - device (qml.Device): An instance of a quantum device. - config (TOMLDocument): A TOML document representation. - - Raises: CompileError - """ - - if not device_capabilities.qjit_compatible_flag: - raise CompileError( - f"Attempting to compile program for incompatible device '{device.name}': " - f"Config is not marked as qjit-compatible" - ) - - device_name = device.short_name if isinstance(device, qml.Device) else device.name - - native = pennylane_operation_set(device_capabilities.native_ops) - decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) - matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) - - check_no_overlap(native, decomposable, matrix, device_name=device_name) - - if hasattr(device, "operations") and hasattr(device, "observables"): - # For gates, we require strict match - device_gates = filter_out_adjoint(set(device.operations)) - spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable)) - if device_gates != spec_gates: - raise CompileError( - "Gates in qml.device.operations and specification file do not match.\n" - f"Gates that present only in the device: {device_gates - spec_gates}\n" - f"Gates that present only in spec: {spec_gates - device_gates}\n" - ) - - # For observables, we do not have `non-native` section in the config, so we check that - # device data supercedes the specification. - device_observables = set(device.observables) - spec_observables = pennylane_operation_set(device_capabilities.native_obs) - if (spec_observables - device_observables) != set(): - raise CompileError( - "Observables in qml.device.observables and specification file do not match.\n" - f"Observables that present only in spec: {spec_observables - device_observables}\n" - ) diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 7a7bc0f8c1..992b42314c 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -21,7 +21,7 @@ import pennylane as qml import pytest -from catalyst.device import QJITDevice +from catalyst.device import QJITDevice, validate_device_capabilities from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import ( DeviceCapabilities, @@ -35,7 +35,6 @@ load_device_capabilities, pennylane_operation_set, read_toml_file, - validate_device_capabilities, ) From 04a5b41a04a90b496d53b50a698a56c13e3045b9 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 11:33:30 +0000 Subject: [PATCH 19/21] Address review suggestions: remove self.qjit_device attribute from qnode; Add a todo notice --- frontend/catalyst/qfunc.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 67b49a3940..045d947c44 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -68,24 +68,25 @@ def extract_backend_info( def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) - device = self.device - program_features = ProgramFeatures(device.shots is not None) - device_capabilities = get_device_capabilities(device, program_features) - backend_info = QFunc.extract_backend_info(device, device_capabilities) + # TODO: Move the capability loading and validation to the device constructor when the + # support for old device api is dropped. + program_features = ProgramFeatures(self.device.shots is not None) + device_capabilities = get_device_capabilities(self.device, program_features) + backend_info = QFunc.extract_backend_info(self.device, device_capabilities) # Validate decive operations against the declared capabilities - validate_device_capabilities(device, device_capabilities) + validate_device_capabilities(self.device, device_capabilities) if isinstance(self.device, qml.devices.Device): - self.qjit_device = QJITDeviceNewAPI(device, device_capabilities, backend_info) + qjit_device = QJITDeviceNewAPI(self.device, device_capabilities, backend_info) else: - self.qjit_device = QJITDevice( - device_capabilities, device.shots, device.wires, backend_info + qjit_device = QJITDevice( + device_capabilities, self.device.shots, self.device.wires, backend_info ) def _eval_quantum(*args): closed_jaxpr, out_type, out_tree = trace_quantum_function( - self.func, self.qjit_device, args, kwargs, qnode=self + self.func, qjit_device, args, kwargs, qnode=self ) args_expanded = get_implicit_and_explicit_flat_args(None, *args) res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded) From 2bb3ede8507574ec026b178421513035a95d0c58 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 11:46:27 +0000 Subject: [PATCH 20/21] Fix a test import --- frontend/test/pytest/test_config_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 992b42314c..50488da984 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -22,12 +22,12 @@ import pytest from catalyst.device import QJITDevice, validate_device_capabilities +from catalyst.device.qjit_device import check_no_overlap from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import ( DeviceCapabilities, ProgramFeatures, TOMLDocument, - check_no_overlap, check_quantum_control_flag, get_decomposable_gates, get_matrix_decomposable_gates, From eadeb91f0b29223a1cae49094c957a2df547032c Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 24 May 2024 13:56:37 +0000 Subject: [PATCH 21/21] Update changelog --- doc/changelog.md | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/doc/changelog.md b/doc/changelog.md index db3069494f..af5c2dcde7 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -92,11 +92,14 @@ annotations. [(#751)](https://github.com/PennyLaneAI/catalyst/pull/751) -* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable - class that implements the decorator's logic. This prevents having to excessively define +* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable + class that implements the decorator's logic. This prevents having to excessively define functions in a nested fashion. [(#758)](https://github.com/PennyLaneAI/catalyst/pull/758) +* Catalyst tests now manipulate device capabilities rather than text configurations files. + [(#712)](https://github.com/PennyLaneAI/catalyst/pull/712) +

Breaking changes

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`. @@ -198,11 +201,12 @@ This release contains contributions from (in alphabetical order): David Ittah, -Mehrdad Malekmohammadi, Erick Ochoa, +Haochen Paul Wang, Lee James O'Riordan, +Mehrdad Malekmohammadi, Raul Torres, -Haochen Paul Wang. +Sergei Mironov. # Release 0.6.0