diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 09a1a921c3..813d06224e 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -67,7 +67,7 @@ def __call__(self, *args, **kwargs): backend_info = QFunc.extract_backend_info(self.device, config) if isinstance(self.device, qml.devices.Device): - device = QJITDeviceNewAPI(self.device, config, backend_info) + device = QJITDeviceNewAPI(self.device, backend_info) else: device = QJITDevice(config, self.device.shots, self.device.wires, backend_info) diff --git a/frontend/catalyst/qjit_device.py b/frontend/catalyst/qjit_device.py index b798a03031..f025ffe9c4 100644 --- a/frontend/catalyst/qjit_device.py +++ b/frontend/catalyst/qjit_device.py @@ -13,6 +13,7 @@ # limitations under the License. """This module contains the qjit device classes. """ +from copy import deepcopy from functools import partial from typing import Optional, Set @@ -23,95 +24,101 @@ from catalyst.preprocess import catalyst_acceptance, decompose from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher -from catalyst.utils.runtime import ( - BackendInfo, - get_pennylane_observables, - get_pennylane_operations, -) +from catalyst.utils.runtime import BackendInfo, device_get_toml_config from catalyst.utils.toml import ( + DeviceCapabilities, + OperationProperties, + ProgramFeatures, TOMLDocument, - check_adjoint_flag, - check_mid_circuit_measurement_flag, + get_device_capabilities, + intersect_operations, + pennylane_operation_set, ) +# fmt:off RUNTIME_OPERATIONS = { - "CNOT", - "ControlledPhaseShift", - "CRot", - "CRX", - "CRY", - "CRZ", - "CSWAP", - "CY", - "CZ", - "Hadamard", - "Identity", - "IsingXX", - "IsingXY", - "IsingYY", - "ISWAP", - "MultiRZ", - "PauliX", - "PauliY", - "PauliZ", - "PhaseShift", - "PSWAP", - "QubitUnitary", - "Rot", - "RX", - "RY", - "RZ", - "S", - "SWAP", - "T", - "Toffoli", - "GlobalPhase", - "C(GlobalPhase)", - "C(Hadamard)", - "C(IsingXX)", - "C(IsingXY)", - "C(IsingYY)", - "C(ISWAP)", - "C(MultiRZ)", - "ControlledQubitUnitary", - "C(PauliX)", - "C(PauliY)", - "C(PauliZ)", - "C(PhaseShift)", - "C(PSWAP)", - "C(Rot)", - "C(RX)", - "C(RY)", - "C(RZ)", - "C(S)", - "C(SWAP)", - "C(T)", + 'CNOT': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'ControlledPhaseShift': + OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CRot': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CRX': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CRY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CRZ': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'CZ': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'Hadamard': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'Identity': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'IsingXX': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'IsingXY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'IsingYY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'ISWAP': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'MultiRZ': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'PauliX': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'PauliY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'PauliZ': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'PhaseShift': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'PSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'QubitUnitary': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'ControlledQubitUnitary': + OperationProperties(invertible=True, controllable=True, differentiable=True), + 'Rot': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'RX': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'RY': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'RZ': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'S': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'SWAP': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'T': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'Toffoli': OperationProperties(invertible=True, controllable=True, differentiable=True), + 'GlobalPhase': OperationProperties(invertible=True, controllable=True, differentiable=True), } +# fmt:on -def get_qjit_pennylane_operations( - config: TOMLDocument, shots_present: bool, device_name: str -) -> Set[str]: +def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]: """Calculate the set of supported quantum gates for the QJIT device from the gates allowed on the target quantum device.""" # Supported gates of the target PennyLane's device - native_gates = get_pennylane_operations(config, shots_present, device_name) + qjit_config = deepcopy(target_capabilities) + # Gates that Catalyst runtime supports qir_gates = RUNTIME_OPERATIONS - supported_gates = set.intersection(native_gates, qir_gates) + + # Intersection of the above + qjit_config.native_ops = intersect_operations(target_capabilities.native_ops, qir_gates) # Control-flow gates to be lowered down to the LLVM control-flow instructions - supported_gates.update({"Cond", "WhileLoop", "ForLoop"}) + qjit_config.native_ops.update( + { + "Cond": OperationProperties(invertible=True, controllable=True, differentiable=True), + "WhileLoop": OperationProperties( + invertible=True, controllable=True, differentiable=True + ), + "ForLoop": OperationProperties(invertible=True, controllable=True, differentiable=True), + } + ) # Optionally enable runtime-powered mid-circuit measurments - if check_mid_circuit_measurement_flag(config): # pragma: no branch - supported_gates.update({"MidCircuitMeasure"}) + if target_capabilities.mid_circuit_measurement_flag: # pragma: no branch + qjit_config.native_ops.update( + { + "MidCircuitMeasure": OperationProperties( + invertible=True, controllable=True, differentiable=True + ) + } + ) # Optionally enable runtime-powered quantum gate adjointing (inversions) - if check_adjoint_flag(config, shots_present): - supported_gates.update({"Adjoint"}) + if all(ng.invertible for ng in target_capabilities.native_ops.values()): + qjit_config.native_ops.update( + { + "Adjoint": OperationProperties( + invertible=True, controllable=True, differentiable=True + ) + } + ) - return supported_gates + return qjit_config class QJITDevice(qml.QubitDevice): @@ -137,7 +144,7 @@ class QJITDevice(qml.QubitDevice): author = "" @staticmethod - def _get_operations_to_convert_to_matrix(_config: TOMLDocument) -> Set[str]: + def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> Set[str]: # We currently override and only set a few gates to preserve existing behaviour. # We could choose to read from config and use the "matrix" gates. # However, that affects differentiability. @@ -154,25 +161,26 @@ def __init__( ): super().__init__(wires=wires, shots=shots) - self.target_config = target_config self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} device_name = backend.device_name if backend else "default" - shots_present = shots is not None - self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name) - self._observables = get_pennylane_observables(target_config, shots_present, device_name) + program_features = ProgramFeatures(shots is not None) + target_device_capabilities = get_device_capabilities( + target_config, program_features, device_name + ) + self.capabilities = get_qjit_device_capabilities(target_device_capabilities) @property def operations(self) -> Set[str]: - """Get the device operations""" - return self._operations + """Get the device operations using PennyLane's syntax""" + return pennylane_operation_set(self.capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return self._observables + return pennylane_operation_set(self.capabilities.native_obs) def apply(self, operations, **kwargs): """ @@ -202,7 +210,7 @@ def default_expand_fn(self, circuit, max_expansion=10): raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.") decompose_to_qubit_unitary = QJITDevice._get_operations_to_convert_to_matrix( - self.target_config + self.capabilities ) def _decomp_to_unitary(self, *_args, **_kwargs): @@ -251,7 +259,6 @@ class QJITDeviceNewAPI(qml.devices.Device): def __init__( self, original_device, - target_config: TOMLDocument, backend: Optional[BackendInfo] = None, ): self.original_device = original_device @@ -264,25 +271,27 @@ def __init__( super().__init__(wires=original_device.wires, shots=original_device.shots) - self.target_config = target_config self.backend_name = backend.c_interface_name if backend else "default" self.backend_lib = backend.lpath if backend else "" self.backend_kwargs = backend.kwargs if backend else {} device_name = backend.device_name if backend else "default" - shots_present = original_device.shots is not None - self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name) - self._observables = get_pennylane_observables(target_config, shots_present, device_name) + target_config = device_get_toml_config(original_device) + program_features = ProgramFeatures(original_device.shots is not None) + target_device_capabilities = get_device_capabilities( + target_config, program_features, device_name + ) + self.capabilities = get_qjit_device_capabilities(target_device_capabilities) @property def operations(self) -> Set[str]: """Get the device operations""" - return self._operations + return pennylane_operation_set(self.capabilities.native_ops) @property def observables(self) -> Set[str]: """Get the device observables""" - return self._observables + return pennylane_operation_set(self.capabilities.native_obs) def preprocess( self, diff --git a/frontend/catalyst/utils/runtime.py b/frontend/catalyst/utils/runtime.py index 8169249ffc..2d1dc0e8e6 100644 --- a/frontend/catalyst/utils/runtime.py +++ b/frontend/catalyst/utils/runtime.py @@ -22,20 +22,17 @@ import platform import re from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, Set +from typing import Any, Dict import pennylane as qml from catalyst._configuration import INSTALLED from catalyst.utils.exceptions import CompileError from catalyst.utils.toml import ( + ProgramFeatures, TOMLDocument, - check_quantum_control_flag, - get_decomposable_gates, - get_matrix_decomposable_gates, - get_native_gates, - get_observables, + get_device_capabilities, + pennylane_operation_set, read_toml_file, ) @@ -68,95 +65,12 @@ def get_lib_path(project, env_var): return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, "")) -def deduce_schema1_native_controlled_gates(native_gates: Set[str]) -> Set[str]: - """Calculate the set of controlled gates given the set of natively supported gates. This - function is used with the toml config schema 1 which did not support per-gate control - specifications. Later schemas provide the required information directly. - """ - # The deduction logic is the following: - # * Most of the gates have their `C(Gate)` controlled counterparts. - # * Some gates have to be decomposed if controlled version is used. Typically these are gates - # which are already controlled but have well-known names. - # * Few gates, like `QubitUnitary`, have separate classes for their controlled versions. - gates_to_be_decomposed_if_controlled = [ - "Identity", - "CNOT", - "CY", - "CZ", - "CSWAP", - "CRX", - "CRY", - "CRZ", - "CRot", - "ControlledPhaseShift", - "QubitUnitary", - "Toffoli", - ] - native_controlled_gates = set( - [f"C({gate})" for gate in native_gates if gate not in gates_to_be_decomposed_if_controlled] - + [f"Controlled{gate}" for gate in native_gates if gate in ["QubitUnitary"]] - ) - return native_controlled_gates - - -def get_pennylane_operations( - config: TOMLDocument, shots_present: bool, device_name: str -) -> Set[str]: - """Get gates that are natively supported by the device and therefore do not need to be - decomposed. - - Args: - config (Dict[Str, Any]): Configuration dictionary - shots_present (bool): True is exact shots is specified in the current top-level program - device_name (str): Name of quantum device. Used for ad-hoc patching. - - Returns: - Set[str]: List of gate names in the PennyLane format. - """ - gates_PL = set() - schema = int(config["schema"]) - - if schema == 1: - native_gates_attrs = get_native_gates(config, shots_present) - assert all(len(v) == 0 for v in native_gates_attrs.values()) - native_gates = set(native_gates_attrs) - supports_controlled = check_quantum_control_flag(config) - native_controlled_gates = ( - deduce_schema1_native_controlled_gates(native_gates) if supports_controlled else set() - ) - - # TODO: remove after PR #642 is merged in lightning - if device_name == "lightning.kokkos": # pragma: nocover - native_gates.update({"C(GlobalPhase)"}) - - gates_PL = set.union(native_gates, native_controlled_gates) - - elif schema == 2: - native_gates = get_native_gates(config, shots_present) - for gate, attrs in native_gates.items(): - gates_PL.add(f"{gate}") - if "controllable" in attrs.get("properties", {}): - gates_PL.add(f"C({gate})") - - else: - raise CompileError("Device configuration schema {schema} is not supported") - - return gates_PL - - -def get_pennylane_observables( - config: TOMLDocument, shots_present: bool, _device_name: str -) -> Set[str]: - """Get observables in PennyLane format. Apply ad-hoc patching""" - - return set(get_observables(config, shots_present)) - - -def check_no_overlap(*args): +def check_no_overlap(*args, device_name): """Check items in *args are mutually exclusive. Args: *args (List[Str]): List of strings. + device_name (str): Device name for error reporting. Raises: CompileError @@ -172,7 +86,7 @@ def check_no_overlap(*args): overlaps.update(s - union) union = union - s - msg = f"Device has overlapping gates: {overlaps}" + msg = f"Device '{device_name}' has overlapping gates: {overlaps}" raise CompileError(msg) @@ -218,23 +132,14 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - ) device_name = device.short_name if isinstance(device, qml.Device) else device.name + program_features = ProgramFeatures(device.shots is not None) + device_capabilities = get_device_capabilities(config, program_features, device_name) - shots_present = device.shots is not None - native = get_pennylane_operations(config, shots_present, device_name) - decomposable = set(get_decomposable_gates(config, shots_present)) - matrix = set(get_matrix_decomposable_gates(config, shots_present)) - - # For toml schema 1 configs, the following condition is possible: (1) `QubitUnitary` gate is - # supported, (2) native quantum control flag is enabled and (3) `ControlledQubitUnitary` is - # listed in either matrix or decomposable sections. This is a contradiction, because condition - # (1) means that `ControlledQubitUnitary` is also in the native set. We solve it here by - # applying a fixup. - # TODO: remove after PR #642 is merged in lightning - if "ControlledQubitUnitary" in native: - matrix = matrix - {"ControlledQubitUnitary"} - decomposable = decomposable - {"ControlledQubitUnitary"} + native = pennylane_operation_set(device_capabilities.native_ops) + decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops) + matrix = pennylane_operation_set(device_capabilities.to_matrix_ops) - check_no_overlap(native, decomposable, matrix) + check_no_overlap(native, decomposable, matrix, device_name=device_name) if hasattr(device, "operations") and hasattr(device, "observables"): # For gates, we require strict match @@ -250,7 +155,7 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - # For observables, we do not have `non-native` section in the config, so we check that # device data supercedes the specification. device_observables = set(device.observables) - spec_observables = get_pennylane_observables(config, shots_present, device_name) + spec_observables = pennylane_operation_set(device_capabilities.native_obs) if (spec_observables - device_observables) != set(): raise CompileError( "Observables in qml.device.observables and specification file do not match.\n" @@ -258,8 +163,8 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) - ) -def device_get_toml_config(device) -> Path: - """Get the path of the device config file.""" +def device_get_toml_config(device) -> TOMLDocument: + """Get the contents of the device config file.""" if hasattr(device, "config"): # The expected case: device specifies its own config. toml_file = device.config diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index ce353465f8..9c47547c8a 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -16,9 +16,10 @@ """ import importlib.util +from dataclasses import dataclass from functools import reduce from itertools import repeat -from typing import Any, Dict, List +from typing import Any, Dict, List, Set from catalyst.utils.exceptions import CompileError @@ -40,9 +41,11 @@ from tomllib import load as toml_load TOMLDocument = Any + TOMLException = Exception else: # pragma: nocover from tomlkit import TOMLDocument from tomlkit import load as toml_load + from tomlkit.exceptions import TOMLKitError as TOMLException def read_toml_file(toml_file: str) -> TOMLDocument: @@ -52,25 +55,65 @@ def read_toml_file(toml_file: str) -> TOMLDocument: return config -def check_mid_circuit_measurement_flag(config: TOMLDocument) -> bool: - """Check the global mid-circuit measurement flag""" - return bool(config.get("compilation", {}).get("mid_circuit_measurement", False)) +@dataclass +class OperationProperties: + """Capabilities of a single operation""" + invertible: bool + controllable: bool + differentiable: bool -def check_adjoint_flag(config: TOMLDocument, shots_present: bool) -> bool: - """Check the global adjoint flag for toml schema 1. For newer schemas the adjoint flag is - defined to be set if all native gates are inverible""" - schema = int(config["schema"]) - if schema == 1: - return bool(config.get("compilation", {}).get("quantum_adjoint", False)) - elif schema == 2: - return all( - "invertible" in v.get("properties", {}) - for g, v in get_native_gates(config, shots_present).items() - ) +def intersect_properties(a: OperationProperties, b: OperationProperties) -> OperationProperties: + """Calculate the intersection of OperationProperties""" + return OperationProperties( + invertible=a.invertible and b.invertible, + controllable=a.controllable and b.controllable, + differentiable=a.differentiable and b.differentiable, + ) + + +@dataclass +class DeviceCapabilities: + """Quantum device capabilities""" + + native_ops: Dict[str, OperationProperties] + to_decomp_ops: Dict[str, OperationProperties] + to_matrix_ops: Dict[str, OperationProperties] + native_obs: Dict[str, OperationProperties] + mid_circuit_measurement_flag: bool + runtime_code_generation_flag: bool + dynamic_qubit_management_flag: bool + + +def intersect_operations( + a: Dict[str, OperationProperties], b: Dict[str, OperationProperties] +) -> Dict[str, OperationProperties]: + """Intersects two sets of oepration properties""" + return {k: intersect_properties(a[k], b[k]) for k in (a.keys() & b.keys())} - raise CompileError("quantum_adjoint flag is not supported in TOMLs schema >= 3") + +def pennylane_operation_set(config_ops: Dict[str, OperationProperties]) -> Set[str]: + """Returns a config section into a set of strings using PennyLane syntax""" + ops = set() + # Back-mapping from class names to string names + for g, props in config_ops.items(): + ops.update({g}) + if props.controllable: + ops.update({f"C({g})"}) + return ops + + +@dataclass +class ProgramFeatures: + """Program features, obtained from the user""" + + shots_present: bool + + +def check_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: + """Checks the flag in the toml document 'compilation' section.""" + return bool(config.get("compilation", {}).get(flag_name, False)) def check_quantum_control_flag(config: TOMLDocument) -> bool: @@ -82,13 +125,18 @@ def check_quantum_control_flag(config: TOMLDocument) -> bool: raise CompileError("quantum_control flag is not supported in TOMLs schema >= 2") -def get_gates(config: TOMLDocument, path: List[str], shots_present: bool) -> Dict[str, dict]: - """Read the toml config section specified by `path`. Filters-out gates which don't match - condition. For now the only condition we support is `shots_present`.""" +def parse_toml_section( + config: TOMLDocument, path: List[str], program_features: ProgramFeatures +) -> Dict[str, dict]: + """Parses the section of toml config file specified by `path`. Filters-out gates which don't + match condition. For now the only condition we support is `shots_present`.""" gates = {} analytic = "analytic" finiteshots = "finiteshots" - iterable = reduce(lambda x, y: x[y], path, config) + try: + iterable = reduce(lambda x, y: x[y], path, config) + except TOMLException as _: # pylint: disable=broad-exception-caught + return {} gen = iterable.items() if hasattr(iterable, "items") else zip(iterable, repeat({})) for g, values in gen: unknown_attrs = set(values) - {"condition", "properties"} @@ -103,6 +151,7 @@ def get_gates(config: TOMLDocument, path: List[str], shots_present: bool) -> Dic f"Configuration for gate '{str(g)}' has unknown properties: {list(unknown_props)}" ) if "condition" in values: + # TODO: do not filter here. Parse the condition and then filter on demand instead. conditions = values["condition"] unknown_conditions = set(conditions) - {analytic, finiteshots} if len(unknown_conditions) > 0: @@ -115,33 +164,35 @@ def get_gates(config: TOMLDocument, path: List[str], shots_present: bool) -> Dic f"Configuration for gate '{g}' can not contain both " f"`{finiteshots}` and `{analytic}` conditions simultaniosly" ) - if analytic in conditions and not shots_present: + if analytic in conditions and not program_features.shots_present: gates[g] = values - elif finiteshots in conditions and shots_present: + elif finiteshots in conditions and program_features.shots_present: gates[g] = values else: gates[g] = values return gates -def get_observables(config: TOMLDocument, shots_present: bool) -> Dict[str, dict]: +def get_observables(config: TOMLDocument, program_features: ProgramFeatures) -> Dict[str, dict]: """Override the set of supported observables.""" - return get_gates(config, ["operators", "observables"], shots_present) + return parse_toml_section(config, ["operators", "observables"], program_features) -def get_native_gates(config: TOMLDocument, shots_present: bool) -> Dict[str, dict]: +def get_native_ops(config: TOMLDocument, program_features: ProgramFeatures) -> Dict[str, dict]: """Get the gates from the `native` section of the config.""" schema = int(config["schema"]) if schema == 1: - return get_gates(config, ["operators", "gates", 0, "native"], shots_present) + return parse_toml_section(config, ["operators", "gates", 0, "native"], program_features) elif schema == 2: - return get_gates(config, ["operators", "gates", "native"], shots_present) + return parse_toml_section(config, ["operators", "gates", "native"], program_features) raise CompileError(f"Unsupported config schema {schema}") -def get_decomposable_gates(config: TOMLDocument, shots_present: bool) -> Dict[str, dict]: +def get_decomposable_gates( + config: TOMLDocument, program_features: ProgramFeatures +) -> Dict[str, dict]: """Get gates that will be decomposed according to PL's decomposition rules. Args: @@ -149,14 +200,16 @@ def get_decomposable_gates(config: TOMLDocument, shots_present: bool) -> Dict[st """ schema = int(config["schema"]) if schema == 1: - return get_gates(config, ["operators", "gates", 0, "decomp"], shots_present) + return parse_toml_section(config, ["operators", "gates", 0, "decomp"], program_features) elif schema == 2: - return get_gates(config, ["operators", "gates", "decomp"], shots_present) + return parse_toml_section(config, ["operators", "gates", "decomp"], program_features) raise CompileError(f"Unsupported config schema {schema}") -def get_matrix_decomposable_gates(config: TOMLDocument, shots_present: bool) -> Dict[str, dict]: +def get_matrix_decomposable_gates( + config: TOMLDocument, program_features: ProgramFeatures +) -> Dict[str, dict]: """Get gates that will be decomposed to QubitUnitary. Args: @@ -164,8 +217,139 @@ def get_matrix_decomposable_gates(config: TOMLDocument, shots_present: bool) -> """ schema = int(config["schema"]) if schema == 1: - return get_gates(config, ["operators", "gates", 0, "matrix"], shots_present) + return parse_toml_section(config, ["operators", "gates", 0, "matrix"], program_features) elif schema == 2: - return get_gates(config, ["operators", "gates", "matrix"], shots_present) + return parse_toml_section(config, ["operators", "gates", "matrix"], program_features) raise CompileError(f"Unsupported config schema {schema}") + + +def get_operation_properties(config_props: dict) -> OperationProperties: + """Load operation properties from config""" + properties = config_props.get("properties", {}) + return OperationProperties( + invertible="invertible" in properties, + controllable="controllable" in properties, + differentiable="differentiable" in properties, + ) + + +def patch_schema1_collections( + config, device_name, native_gate_props, matrix_decomp_props, decomp_props, observable_props +): # pylint: disable=too-many-arguments, too-many-branches + """For old schema1 config files we deduce some information which was not explicitly encoded.""" + + # TODO: remove after PR #642 is merged in lightning + # NOTE: we mark GlobalPhase as controllables even if `quantum_control` flag is False. This + # is what actual device reports. + if device_name == "lightning.kokkos": # pragma: nocover + native_gate_props["GlobalPhase"] = OperationProperties( + invertible=False, controllable=True, differentiable=True + ) + + # TODO: remove after PR #642 is merged in lightning + if device_name == "lightning.kokkos": # pragma: nocover + observable_props["Projector"] = OperationProperties( + invertible=False, controllable=False, differentiable=False + ) + + # The deduction logic is the following: + # * Most of the gates have their `C(Gate)` controlled counterparts. + # * Some gates have to be decomposed if controlled version is used. Typically these are + # gates which are already controlled but have well-known names. + # * Few gates, like `QubitUnitary`, have separate classes for their controlled versions. + gates_to_be_decomposed_if_controlled = [ + "Identity", + "CNOT", + "CY", + "CZ", + "CSWAP", + "CRX", + "CRY", + "CRZ", + "CRot", + "ControlledPhaseShift", + "QubitUnitary", + "ControlledQubitUnitary", + "Toffoli", + ] + + supports_controlled = check_quantum_control_flag(config) + if supports_controlled: + # Add ControlledQubitUnitary as a controlled version of QubitUnitary + if "QubitUnitary" in native_gate_props: + native_gate_props["ControlledQubitUnitary"] = OperationProperties( + invertible=False, controllable=False, differentiable=True + ) + # By default, enable the `C(gate)` version for most `gates`. + for op, props in native_gate_props.items(): + props.controllable = op not in gates_to_be_decomposed_if_controlled + + supports_adjoint = check_compilation_flag(config, "quantum_adjoint") + if supports_adjoint: + # Makr all gates as invertibles + for props in native_gate_props.values(): + props.invertible = True + + # For toml schema 1 configs, the following condition is possible: (1) `QubitUnitary` gate is + # supported, (2) native quantum control flag is enabled and (3) `ControlledQubitUnitary` is + # listed in either matrix or decomposable sections. This is a contradiction, because + # condition (1) means that `ControlledQubitUnitary` is also in the native set. We solve it + # here by applying a fixup. + # TODO: remove after PR #642 is merged in lightning + if "ControlledQubitUnitary" in native_gate_props: # pragma: nocover + if "ControlledQubitUnitary" in matrix_decomp_props: + matrix_decomp_props.pop("ControlledQubitUnitary") + if "ControlledQubitUnitary" in decomp_props: + decomp_props.pop("ControlledQubitUnitary") + + # Fix a bug in device toml schema 1 + if "ControlledPhaseShift" in native_gate_props: # pragma: nocover + if "ControlledPhaseShift" in matrix_decomp_props: + matrix_decomp_props.pop("ControlledPhaseShift") + if "ControlledPhaseShift" in decomp_props: + decomp_props.pop("ControlledPhaseShift") + + +def get_device_capabilities( + config: TOMLDocument, program_features: ProgramFeatures, device_name: str +) -> DeviceCapabilities: + """Load TOML document into the DeviceCapabilities structure""" + + schema = int(config["schema"]) + + native_gate_props = {} + for g, props in get_native_ops(config, program_features).items(): + native_gate_props[g] = get_operation_properties(props) + + matrix_decomp_props = {} + for g, props in get_matrix_decomposable_gates(config, program_features).items(): + matrix_decomp_props[g] = get_operation_properties(props) + + decomp_props = {} + for g, props in get_decomposable_gates(config, program_features).items(): + decomp_props[g] = get_operation_properties(props) + + observable_props = {} + for g, props in get_observables(config, program_features).items(): + observable_props[g] = get_operation_properties(props) + + if schema == 1: + patch_schema1_collections( + config, + device_name, + native_gate_props, + matrix_decomp_props, + decomp_props, + observable_props, + ) + + return DeviceCapabilities( + native_ops=native_gate_props, + to_decomp_ops=decomp_props, + to_matrix_ops=matrix_decomp_props, + native_obs=observable_props, + mid_circuit_measurement_flag=check_compilation_flag(config, "mid_circuit_measurement"), + runtime_code_generation_flag=check_compilation_flag(config, "runtime_code_generation"), + dynamic_qubit_management_flag=check_compilation_flag(config, "dynamic_qubit_management"), + ) diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py index 8629400bff..15897c272e 100644 --- a/frontend/test/pytest/test_config_functions.py +++ b/frontend/test/pytest/test_config_functions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for functions to check config validity.""" +"""Unit tests for device toml config parsing and validation.""" from os.path import join from tempfile import TemporaryDirectory @@ -21,18 +21,24 @@ import pennylane as qml import pytest +from catalyst.qjit_device import QJITDevice, QJITDeviceNewAPI from catalyst.utils.exceptions import CompileError from catalyst.utils.runtime import ( check_no_overlap, + get_device_capabilities, + validate_config_with_device, +) +from catalyst.utils.toml import ( + DeviceCapabilities, + ProgramFeatures, + TOMLDocument, check_quantum_control_flag, get_decomposable_gates, get_matrix_decomposable_gates, - get_native_gates, - get_pennylane_observables, - get_pennylane_operations, - validate_config_with_device, + get_native_ops, + pennylane_operation_set, + read_toml_file, ) -from catalyst.utils.toml import check_adjoint_flag, read_toml_file class DummyDevice(qml.QubitDevice): @@ -55,6 +61,25 @@ def apply(self, operations, **kwargs): ALL_SCHEMAS = [1, 2] +def get_test_config(config_text: str) -> TOMLDocument: + """Parse test config into the TOMLDocument structure""" + with TemporaryDirectory() as d: + toml_file = join(d, "test.toml") + with open(toml_file, "w", encoding="utf-8") as f: + f.write(config_text) + config = read_toml_file(toml_file) + return config + + +def get_test_device_capabilities( + program_features: ProgramFeatures, config_text: str +) -> DeviceCapabilities: + """Parse test config into the DeviceCapabilities structure""" + config = get_test_config(config_text) + device_capabilities = get_device_capabilities(config, program_features, "dummy") + return device_capabilities + + @pytest.mark.parametrize("schema", ALL_SCHEMAS) def test_validate_config_with_device(schema): """Test error is raised if checking for qjit compatibility and field is false in toml file.""" @@ -83,379 +108,323 @@ def test_validate_config_with_device(schema): def test_get_observables_schema1(): """Test observables are properly obtained from the toml schema 1.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"TestNativeGate"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 1 - [operators] - observables = [ "TestNativeGate" ] - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_observables(config, False, "device_name") + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 1 + [operators] + observables = [ "PauliX" ] + """ + ), + ) + assert {"PauliX"} == pennylane_operation_set(device_capabilities.native_obs) def test_get_observables_schema2(): """Test observables are properly obtained from the toml schema 2.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"TestNativeGate1"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.observables] - TestNativeGate1 = { } - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_observables(config, False, "device_name") - - -def test_get_native_gates_schema1_no_qcontrol(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 2 + [operators.observables] + PauliX = { } + """ + ), + ) + assert {"PauliX"} == pennylane_operation_set(device_capabilities.native_obs) + + +def test_get_native_ops_schema1_no_qcontrol(): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"TestNativeGate"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 1 - [[operators.gates]] - native = [ "TestNativeGate" ] - [compilation] - quantum_control = false - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_operations(config, False, "device_name") - - -def test_get_native_gates_schema1_qcontrol(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 1 + [[operators.gates]] + native = [ "PauliX" ] + [compilation] + quantum_control = false + """ + ), + ) + assert {"PauliX"} == pennylane_operation_set(device_capabilities.native_ops) + + +def test_get_native_ops_schema1_qcontrol(): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"C(TestNativeGate)", "TestNativeGate"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 1 - [[operators.gates]] - native = [ "TestNativeGate" ] - [compilation] - quantum_control = true - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_operations(config, False, "device_name") - - -def test_get_adjoint_schema2(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 1 + [[operators.gates]] + native = [ "PauliZ" ] + [compilation] + quantum_control = true + """ + ), + ) + assert {"PauliZ", "C(PauliZ)"} == pennylane_operation_set(device_capabilities.native_ops) + + +@pytest.mark.parametrize("qadjoint", [True, False]) +def test_get_native_ops_schema1_qadjoint(qadjoint): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestNativeGate1 = { properties = [ 'invertible' ] } - TestNativeGate2 = { properties = [ 'invertible' ] } - """ - ) - ) - config = read_toml_file(toml_file) - assert check_adjoint_flag(config, False) - - -def test_get_native_gates_schema2(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + rf""" + schema = 1 + [[operators.gates]] + native = [ "PauliZ" ] + [compilation] + quantum_adjoint = {str(qadjoint).lower()} + """ + ), + ) + assert device_capabilities.native_ops["PauliZ"].invertible is qadjoint + + +def test_get_native_ops_schema2(): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"C(TestNativeGate1)", "TestNativeGate1", "TestNativeGate2"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestNativeGate1 = { properties = [ 'controllable' ] } - TestNativeGate2 = { } - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_operations(config, False, "device_name") - - -def test_get_native_gates_schema2_optional_shots(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 2 + [operators.gates.native] + PauliX = { properties = [ 'controllable' ] } + PauliY = { } + """ + ), + ) + + assert {"PauliX", "C(PauliX)", "PauliY"} == pennylane_operation_set( + device_capabilities.native_ops + ) + + +def test_get_native_ops_schema2_optional_shots(): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"TestNativeGate1"} - - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestNativeGate1 = { condition = ['finiteshots'] } - TestNativeGate2 = { condition = ['analytic'] } - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_operations(config, True, "device_name") - - -def test_get_native_gates_schema2_optional_noshots(): + device_capabilities = get_test_device_capabilities( + ProgramFeatures(True), + dedent( + r""" + schema = 2 + [operators.gates.native] + PauliX = { condition = ['finiteshots'] } + PauliY = { condition = ['analytic'] } + """ + ), + ) + assert "PauliX" in device_capabilities.native_ops + assert "PauliY" not in device_capabilities.native_ops + + +def test_get_native_ops_schema2_optional_noshots(): """Test native gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_deduced_gates = {"TestNativeGate2"} - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestNativeGate1 = { condition = ['finiteshots'] } - TestNativeGate2 = { condition = ['analytic'] } - """ - ) - ) - config = read_toml_file(toml_file) - assert test_deduced_gates == get_pennylane_operations(config, False, "device") + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 2 + [operators.gates.native] + PauliX = { condition = ['finiteshots'] } + PauliY = { condition = ['analytic'] } + """ + ), + ) + assert "PauliX" not in device_capabilities.native_ops + assert "PauliY" in device_capabilities.native_ops def test_get_decomp_gates_schema1(): """Test native decomposition gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_gates = {"TestDecompGate": {}} - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - f""" - schema = 1 - [[operators.gates]] - decomp = {str(list(test_gates.keys()))} - """ - ) - ) - - config = read_toml_file(toml_file) - - assert test_gates == get_decomposable_gates(config, False) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + """ + schema = 1 + [[operators.gates]] + decomp = ["PauliX", "PauliY"] + """ + ), + ) + + assert "PauliX" in device_capabilities.to_decomp_ops + assert "PauliY" in device_capabilities.to_decomp_ops + assert "PauliZ" not in device_capabilities.to_decomp_ops def test_get_decomp_gates_schema2(): """Test native decomposition gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_gates = {"TestDecompGate": {}} - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - f""" - schema = 2 - [operators.gates] - decomp = {str(list(test_gates.keys()))} - """ - ) - ) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + """ + schema = 2 + [operators.gates] + decomp = ["PauliX", "PauliY"] + """ + ), + ) - config = read_toml_file(toml_file) - - assert test_gates == get_decomposable_gates(config, False) + assert "PauliX" in device_capabilities.to_decomp_ops + assert "PauliY" in device_capabilities.to_decomp_ops def test_get_matrix_decomposable_gates_schema1(): """Test native matrix gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - test_gates = {"TestMatrixGate": {}} - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - f""" - schema = 1 - [[operators.gates]] - matrix = {str(list(test_gates.keys()))} - """ - ) - ) - - config = read_toml_file(toml_file) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + """ + schema = 1 + [[operators.gates]] + matrix = ["PauliX", "PauliY"] + """ + ), + ) - assert test_gates == get_matrix_decomposable_gates(config, False) + assert "PauliX" in device_capabilities.to_matrix_ops + assert "PauliY" in device_capabilities.to_matrix_ops def test_get_matrix_decomposable_gates_schema2(): """Test native matrix gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.matrix] - TestMatrixGate = {} - """ - ) - ) - - config = read_toml_file(toml_file) + device_capabilities = get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 2 + [operators.gates.matrix] + PauliZ = {} + """ + ), + ) - assert {"TestMatrixGate": {}} == get_matrix_decomposable_gates(config, False) + assert "PauliZ" in device_capabilities.to_matrix_ops def test_check_overlap_msg(): """Test error is raised if there is an overlap in sets.""" - msg = "Device has overlapping gates." + msg = "Device 'test' has overlapping gates." with pytest.raises(CompileError, match=msg): - check_no_overlap(["A"], ["A"], ["A"]) + check_no_overlap(["A"], ["A"], ["A"], device_name="test") def test_config_invalid_attr(): """Check the gate condition handling logic""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestGate = { unknown_attribute = 33 } - """ - ) - ) - - config = read_toml_file(toml_file) - - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown attributes" - ): - get_native_gates(config, True) + with pytest.raises( + CompileError, match="Configuration for gate 'TestGate' has unknown attributes" + ): + get_test_device_capabilities( + ProgramFeatures(False), + dedent( + r""" + schema = 2 + [operators.gates.native] + TestGate = { unknown_attribute = 33 } + """ + ), + ) def test_config_invalid_condition_unknown(): """Check the gate condition handling logic""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestGate = { condition = ["unknown", "analytic"] } - """ - ) - ) - - config = read_toml_file(toml_file) - - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown conditions" - ): - get_native_gates(config, True) + with pytest.raises( + CompileError, match="Configuration for gate 'TestGate' has unknown conditions" + ): + get_test_device_capabilities( + ProgramFeatures(True), + dedent( + r""" + schema = 2 + [operators.gates.native] + TestGate = { condition = ["unknown", "analytic"] } + """ + ), + ) def test_config_invalid_property_unknown(): """Check the gate condition handling logic""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestGate = { properties = ["unknown", "invertible"] } - """ - ) - ) - - config = read_toml_file(toml_file) - - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown properties" - ): - get_native_gates(config, True) - - -def test_config_invalid_condition_duplicate(): + with pytest.raises( + CompileError, match="Configuration for gate 'TestGate' has unknown properties" + ): + get_test_device_capabilities( + ProgramFeatures(True), + dedent( + r""" + schema = 2 + [operators.gates.native] + TestGate = { properties = ["unknown", "invertible"] } + """ + ), + ) + + +@pytest.mark.parametrize("shots", [True, False]) +def test_config_invalid_condition_duplicate(shots): """Check the gate condition handling logic""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 2 - [operators.gates.native] - TestGate = { condition = ["finiteshots", "analytic"] } - """ - ) - ) + with pytest.raises(CompileError, match="Configuration for gate 'TestGate'"): + get_test_device_capabilities( + ProgramFeatures(shots), + dedent( + r""" + schema = 2 + [operators.gates.native] + TestGate = { condition = ["finiteshots", "analytic"] } + """ + ), + ) + + +def test_config_qjit_device_operations(): + """Check the gate condition handling logic""" + config = get_test_config( + dedent( + r""" + schema = 2 + [operators.gates.native] + PauliX = {} + [operators.observables] + PauliY = {} + """ + ), + ) + qjit_device = QJITDevice(config, shots=1000, wires=2) + assert "PauliX" in qjit_device.operations + assert "PauliY" in qjit_device.observables - config = read_toml_file(toml_file) - with pytest.raises(CompileError, match="Configuration for gate 'TestGate'"): - get_native_gates(config, True) +def test_config_unsupported_schema(): + """Test native matrix gates are properly obtained from the toml.""" + program_features = ProgramFeatures(False) + config_text = dedent( + r""" + schema = 999 + """ + ) + config = get_test_config(config_text) - with pytest.raises(CompileError, match="Configuration for gate 'TestGate'"): - get_native_gates(config, False) + with pytest.raises(CompileError): + get_test_device_capabilities(program_features, config_text) + with pytest.raises(CompileError): + get_matrix_decomposable_gates(config, program_features) -def test_config_unsupported_schema(): - """Test native matrix gates are properly obtained from the toml.""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write( - dedent( - r""" - schema = 999 - """ - ) - ) + with pytest.raises(CompileError): + get_decomposable_gates(config, program_features) - config = read_toml_file(toml_file) + with pytest.raises(CompileError): + get_native_ops(config, program_features) - with pytest.raises(CompileError): - check_quantum_control_flag(config) - with pytest.raises(CompileError): - get_native_gates(config, False) - with pytest.raises(CompileError): - get_decomposable_gates(config, False) - with pytest.raises(CompileError): - get_matrix_decomposable_gates(config, False) - with pytest.raises(CompileError): - get_pennylane_operations(config, False, "device_name") - with pytest.raises(CompileError): - check_adjoint_flag(config, False) + with pytest.raises(CompileError): + check_quantum_control_flag(config) if __name__ == "__main__": diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index 367ef2a233..e984933b1e 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -82,6 +82,8 @@ "Adjoint(ISWAP)", "Adjoint(SISWAP)", "MultiControlledX", + "SISWAP", + "ControlledPhaseShift", "C(PauliY)", "C(RY)", "C(PauliX)", diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index 1c9b21b52a..d363cb144a 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -93,10 +93,9 @@ def test_qjit_device(): # Create qjit device config = device_get_toml_config(device) backend_info = extract_backend_info(device, config) - device_qjit = QJITDeviceNewAPI(device, config, backend_info) + device_qjit = QJITDeviceNewAPI(device, backend_info) # Check attributes of the new device - assert isinstance(device_qjit.target_config, dict) assert device_qjit.shots == qml.measurements.Shots(2032) assert device_qjit.wires == qml.wires.Wires(range(0, 10)) @@ -138,7 +137,7 @@ def test_qjit_device_no_wires(): with pytest.raises( AttributeError, match="Catalyst does not support devices without set wires." ): - QJITDeviceNewAPI(device, config, backend_info) + QJITDeviceNewAPI(device, backend_info) @pytest.mark.skipif( diff --git a/runtime/tests/third_party/dummy_device.toml b/runtime/tests/third_party/dummy_device.toml index e4e6ab94be..b52b2b8566 100644 --- a/runtime/tests/third_party/dummy_device.toml +++ b/runtime/tests/third_party/dummy_device.toml @@ -19,6 +19,7 @@ CY = { properties = [ "invertible", "differentiable" ] } CZ = { properties = [ "invertible", "differentiable" ] } PhaseShift = { properties = [ "controllable", "invertible", "differentiable" ] } ControlledPhaseShift = { properties = [ "invertible", "differentiable" ] } +CPhase = { properties = [ "invertible", "differentiable" ] } RX = { properties = [ "controllable", "invertible", "differentiable" ] } RY = { properties = [ "controllable", "invertible", "differentiable" ] } RZ = { properties = [ "controllable", "invertible", "differentiable" ] } @@ -53,11 +54,9 @@ ISWAP = {} PSWAP = {} SISWAP = {} SQISW = {} -CPhase = {} BasisState = {} QubitStateVector = {} StatePrep = {} -ControlledQubitUnitary = {} DiagonalQubitUnitary = {} QubitCarry = {} QubitSum = {}