Skip to content

Commit

Permalink
[Frontend] Implement quantum device capabilities as a data structure (#…
Browse files Browse the repository at this point in the history
…609)

In this PR we implement a dedicated data structure representing quantum
device capabilities. Previously we used toml document IR for this
purpose. We see the following benefits in the new approach:

* Data structure simplifies capability calculations we do for our
QJITDevice. Now the capabilities of QJITDevice and of the third-party
devices we support are represented in the same format.
* We reduce the internal usage of `C(Gate)/Adjoint(Gate)` syntax to
minimum. We still use it only to communicate with PennyLane's API


In subsequent PRs we will use it for program verification.

[sc-59478]

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>
  • Loading branch information
Sergei Mironov and rmoyard authored Apr 29, 2024
1 parent 0a641cc commit 08e4917
Show file tree
Hide file tree
Showing 8 changed files with 627 additions and 560 deletions.
2 changes: 1 addition & 1 deletion frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, *args, **kwargs):
backend_info = QFunc.extract_backend_info(self.device, config)

if isinstance(self.device, qml.devices.Device):
device = QJITDeviceNewAPI(self.device, config, backend_info)
device = QJITDeviceNewAPI(self.device, backend_info)
else:
device = QJITDevice(config, self.device.shots, self.device.wires, backend_info)

Expand Down
179 changes: 94 additions & 85 deletions frontend/catalyst/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""This module contains the qjit device classes.
"""
from copy import deepcopy
from functools import partial
from typing import Optional, Set

Expand All @@ -23,95 +24,101 @@
from catalyst.preprocess import catalyst_acceptance, decompose
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import (
BackendInfo,
get_pennylane_observables,
get_pennylane_operations,
)
from catalyst.utils.runtime import BackendInfo, device_get_toml_config
from catalyst.utils.toml import (
DeviceCapabilities,
OperationProperties,
ProgramFeatures,
TOMLDocument,
check_adjoint_flag,
check_mid_circuit_measurement_flag,
get_device_capabilities,
intersect_operations,
pennylane_operation_set,
)

# fmt:off
RUNTIME_OPERATIONS = {
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"S",
"SWAP",
"T",
"Toffoli",
"GlobalPhase",
"C(GlobalPhase)",
"C(Hadamard)",
"C(IsingXX)",
"C(IsingXY)",
"C(IsingYY)",
"C(ISWAP)",
"C(MultiRZ)",
"ControlledQubitUnitary",
"C(PauliX)",
"C(PauliY)",
"C(PauliZ)",
"C(PhaseShift)",
"C(PSWAP)",
"C(Rot)",
"C(RX)",
"C(RY)",
"C(RZ)",
"C(S)",
"C(SWAP)",
"C(T)",
'CNOT': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ControlledPhaseShift':
OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRot': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CRZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'CZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Hadamard': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Identity': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingXX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingXY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'IsingYY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ISWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'MultiRZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PauliZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PhaseShift': OperationProperties(invertible=True, controllable=True, differentiable=True),
'PSWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'QubitUnitary': OperationProperties(invertible=True, controllable=True, differentiable=True),
'ControlledQubitUnitary':
OperationProperties(invertible=True, controllable=True, differentiable=True),
'Rot': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RX': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RY': OperationProperties(invertible=True, controllable=True, differentiable=True),
'RZ': OperationProperties(invertible=True, controllable=True, differentiable=True),
'S': OperationProperties(invertible=True, controllable=True, differentiable=True),
'SWAP': OperationProperties(invertible=True, controllable=True, differentiable=True),
'T': OperationProperties(invertible=True, controllable=True, differentiable=True),
'Toffoli': OperationProperties(invertible=True, controllable=True, differentiable=True),
'GlobalPhase': OperationProperties(invertible=True, controllable=True, differentiable=True),
}
# fmt:on


def get_qjit_pennylane_operations(
config: TOMLDocument, shots_present: bool, device_name: str
) -> Set[str]:
def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]:
"""Calculate the set of supported quantum gates for the QJIT device from the gates
allowed on the target quantum device."""
# Supported gates of the target PennyLane's device
native_gates = get_pennylane_operations(config, shots_present, device_name)
qjit_config = deepcopy(target_capabilities)

# Gates that Catalyst runtime supports
qir_gates = RUNTIME_OPERATIONS
supported_gates = set.intersection(native_gates, qir_gates)

# Intersection of the above
qjit_config.native_ops = intersect_operations(target_capabilities.native_ops, qir_gates)

# Control-flow gates to be lowered down to the LLVM control-flow instructions
supported_gates.update({"Cond", "WhileLoop", "ForLoop"})
qjit_config.native_ops.update(
{
"Cond": OperationProperties(invertible=True, controllable=True, differentiable=True),
"WhileLoop": OperationProperties(
invertible=True, controllable=True, differentiable=True
),
"ForLoop": OperationProperties(invertible=True, controllable=True, differentiable=True),
}
)

# Optionally enable runtime-powered mid-circuit measurments
if check_mid_circuit_measurement_flag(config): # pragma: no branch
supported_gates.update({"MidCircuitMeasure"})
if target_capabilities.mid_circuit_measurement_flag: # pragma: no branch
qjit_config.native_ops.update(
{
"MidCircuitMeasure": OperationProperties(
invertible=True, controllable=True, differentiable=True
)
}
)

# Optionally enable runtime-powered quantum gate adjointing (inversions)
if check_adjoint_flag(config, shots_present):
supported_gates.update({"Adjoint"})
if all(ng.invertible for ng in target_capabilities.native_ops.values()):
qjit_config.native_ops.update(
{
"Adjoint": OperationProperties(
invertible=True, controllable=True, differentiable=True
)
}
)

return supported_gates
return qjit_config


class QJITDevice(qml.QubitDevice):
Expand All @@ -137,7 +144,7 @@ class QJITDevice(qml.QubitDevice):
author = ""

@staticmethod
def _get_operations_to_convert_to_matrix(_config: TOMLDocument) -> Set[str]:
def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> Set[str]:
# We currently override and only set a few gates to preserve existing behaviour.
# We could choose to read from config and use the "matrix" gates.
# However, that affects differentiability.
Expand All @@ -154,25 +161,26 @@ def __init__(
):
super().__init__(wires=wires, shots=shots)

self.target_config = target_config
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

shots_present = shots is not None
self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name)
self._observables = get_pennylane_observables(target_config, shots_present, device_name)
program_features = ProgramFeatures(shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return self._operations
"""Get the device operations using PennyLane's syntax"""
return pennylane_operation_set(self.capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return self._observables
return pennylane_operation_set(self.capabilities.native_obs)

def apply(self, operations, **kwargs):
"""
Expand Down Expand Up @@ -202,7 +210,7 @@ def default_expand_fn(self, circuit, max_expansion=10):
raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.")

decompose_to_qubit_unitary = QJITDevice._get_operations_to_convert_to_matrix(
self.target_config
self.capabilities
)

def _decomp_to_unitary(self, *_args, **_kwargs):
Expand Down Expand Up @@ -251,7 +259,6 @@ class QJITDeviceNewAPI(qml.devices.Device):
def __init__(
self,
original_device,
target_config: TOMLDocument,
backend: Optional[BackendInfo] = None,
):
self.original_device = original_device
Expand All @@ -264,25 +271,27 @@ def __init__(

super().__init__(wires=original_device.wires, shots=original_device.shots)

self.target_config = target_config
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

shots_present = original_device.shots is not None
self._operations = get_qjit_pennylane_operations(target_config, shots_present, device_name)
self._observables = get_pennylane_observables(target_config, shots_present, device_name)
target_config = device_get_toml_config(original_device)
program_features = ProgramFeatures(original_device.shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return self._operations
return pennylane_operation_set(self.capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return self._observables
return pennylane_operation_set(self.capabilities.native_obs)

def preprocess(
self,
Expand Down
Loading

0 comments on commit 08e4917

Please sign in to comment.