Skip to content

Commit

Permalink
[Frontend] Make tests toml-schema independent (#712)
Browse files Browse the repository at this point in the history
**Context:** Transition to the quantum device config schema 2

**Description of the Change:** Solve a regarding toml schema 2 udpate in
tests by switching our test custom devices from toml text manipulations
to the device capability manipulations

**Benefits:** 
* Tests no longer require toml text manipulations.
* Tests now contain simple examples of custom devices.
* toml-specific code is now locates in `catalyst.utils.toml`. 

**Possible Drawbacks:**

**Related GitHub Issues:**
PennyLaneAI/pennylane-lightning#642

---------

Co-authored-by: David Ittah <dime10@users.noreply.github.com>
  • Loading branch information
Sergei Mironov and dime10 authored May 24, 2024
1 parent 47b3482 commit 715523e
Show file tree
Hide file tree
Showing 17 changed files with 704 additions and 696 deletions.
12 changes: 8 additions & 4 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@
annotations.
[(#751)](https://github.com/PennyLaneAI/catalyst/pull/751)

* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
class that implements the decorator's logic. This prevents having to excessively define
* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
class that implements the decorator's logic. This prevents having to excessively define
functions in a nested fashion.
[(#758)](https://github.com/PennyLaneAI/catalyst/pull/758)

* Catalyst tests now manipulate device capabilities rather than text configurations files.
[(#712)](https://github.com/PennyLaneAI/catalyst/pull/712)

<h3>Breaking changes</h3>

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
Expand Down Expand Up @@ -198,11 +201,12 @@
This release contains contributions from (in alphabetical order):

David Ittah,
Mehrdad Malekmohammadi,
Erick Ochoa,
Haochen Paul Wang,
Lee James O'Riordan,
Mehrdad Malekmohammadi,
Raul Torres,
Haochen Paul Wang.
Sergei Mironov.

# Release 0.6.0

Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory
from catalyst.utils.runtime import get_lib_path
from catalyst.utils.runtime_environment import get_lib_path

package_root = os.path.dirname(__file__)

Expand Down
11 changes: 10 additions & 1 deletion frontend/catalyst/device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,18 @@
Internal API for the device module.
"""

from catalyst.device.qjit_device import QJITDevice, QJITDeviceNewAPI
from catalyst.device.qjit_device import (
BackendInfo,
QJITDevice,
QJITDeviceNewAPI,
extract_backend_info,
validate_device_capabilities,
)

__all__ = (
"QJITDevice",
"QJITDeviceNewAPI",
"BackendInfo",
"extract_backend_info",
"validate_device_capabilities",
)
220 changes: 195 additions & 25 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
This module contains device stubs for the old and new PennyLane device API, which facilitate
the application of decomposition and other device pre-processing routines.
"""

import os
import pathlib
import platform
import re
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Optional, Set
from typing import Any, Dict, Optional, Set

import pennylane as qml
from pennylane.measurements import MidMeasureMP
Expand All @@ -32,13 +36,10 @@
)
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import BackendInfo, device_get_toml_config
from catalyst.utils.runtime_environment import get_lib_path
from catalyst.utils.toml import (
DeviceCapabilities,
OperationProperties,
ProgramFeatures,
TOMLDocument,
get_device_capabilities,
intersect_operations,
pennylane_operation_set,
)
Expand Down Expand Up @@ -84,6 +85,87 @@
for op in RUNTIME_OPERATIONS
}

# TODO: This should be removed after implementing `get_c_interface`
# for the following backend devices:
SUPPORTED_RT_DEVICES = {
"lightning.qubit": ("LightningSimulator", "librtd_lightning"),
"lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"),
"braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"),
"braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"),
}


@dataclass
class BackendInfo:
"""Backend information"""

device_name: str
c_interface_name: str
lpath: str
kwargs: Dict[str, Any]


def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo:
"""Extract the backend info from a quantum device. The device is expected to carry a reference
to a valid TOML config file."""
# pylint: disable=too-many-branches

dname = device.name
if isinstance(device, qml.Device):
dname = device.short_name

device_name = ""
device_lpath = ""
device_kwargs = {}

if dname in SUPPORTED_RT_DEVICES:
# Support backend devices without `get_c_interface`
device_name = SUPPORTED_RT_DEVICES[dname][0]
device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR")
sys_platform = platform.system()

if sys_platform == "Linux":
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so")
elif sys_platform == "Darwin": # pragma: no cover
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib")
else: # pragma: no cover
raise NotImplementedError(f"Platform not supported: {sys_platform}")
elif hasattr(device, "get_c_interface"):
# Support third party devices with `get_c_interface`
device_name, device_lpath = device.get_c_interface()
else:
raise CompileError(f"The {dname} device does not provide C interface for compilation.")

if not pathlib.Path(device_lpath).is_file():
raise CompileError(f"Device at {device_lpath} cannot be found!")

if hasattr(device, "shots"):
if isinstance(device, qml.Device):
device_kwargs["shots"] = device.shots if device.shots else 0
else:
# TODO: support shot vectors
device_kwargs["shots"] = device.shots.total_shots if device.shots else 0

if dname == "braket.local.qubit": # pragma: no cover
device_kwargs["device_type"] = dname
device_kwargs["backend"] = (
# pylint: disable=protected-access
device._device._delegate.DEVICE_ID
)
elif dname == "braket.aws.qubit": # pragma: no cover
device_kwargs["device_type"] = dname
device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
if device._s3_folder: # pylint: disable=protected-access
device_kwargs["s3_destination_folder"] = str(
device._s3_folder # pylint: disable=protected-access
)

for k, v in capabilities.options.items():
if hasattr(device, v):
device_kwargs[k] = getattr(device, v)

return BackendInfo(dname, device_name, device_lpath, device_kwargs)


def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]:
"""Calculate the set of supported quantum gates for the QJIT device from the gates
Expand Down Expand Up @@ -165,7 +247,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S

def __init__(
self,
target_config: TOMLDocument,
original_device_capabilities: DeviceCapabilities,
shots=None,
wires=None,
backend: Optional[BackendInfo] = None,
Expand All @@ -175,23 +257,18 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

program_features = ProgramFeatures(shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations using PennyLane's syntax"""
return pennylane_operation_set(self.capabilities.native_ops)
return pennylane_operation_set(self.qjit_capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return pennylane_operation_set(self.capabilities.native_obs)
return pennylane_operation_set(self.qjit_capabilities.native_obs)

def apply(self, operations, **kwargs):
"""
Expand Down Expand Up @@ -270,6 +347,7 @@ class QJITDeviceNewAPI(qml.devices.Device):
def __init__(
self,
original_device,
original_device_capabilities: DeviceCapabilities,
backend: Optional[BackendInfo] = None,
):
self.original_device = original_device
Expand All @@ -285,29 +363,23 @@ def __init__(
self.backend_name = backend.c_interface_name if backend else "default"
self.backend_lib = backend.lpath if backend else ""
self.backend_kwargs = backend.kwargs if backend else {}
device_name = backend.device_name if backend else "default"

target_config = device_get_toml_config(original_device)
program_features = ProgramFeatures(original_device.shots is not None)
target_device_capabilities = get_device_capabilities(
target_config, program_features, device_name
)
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)

@property
def operations(self) -> Set[str]:
"""Get the device operations"""
return pennylane_operation_set(self.capabilities.native_ops)
return pennylane_operation_set(self.qjit_capabilities.native_ops)

@property
def observables(self) -> Set[str]:
"""Get the device observables"""
return pennylane_operation_set(self.capabilities.native_obs)
return pennylane_operation_set(self.qjit_capabilities.native_obs)

@property
def measurement_processes(self) -> Set[str]:
"""Get the device measurement processes"""
return self.capabilities.measurement_processes
return self.qjit_capabilities.measurement_processes

def preprocess(
self,
Expand All @@ -334,3 +406,101 @@ def execute(self, circuits, execution_config):
Raises: RuntimeError
"""
raise RuntimeError("QJIT devices cannot execute tapes.")


def filter_out_adjoint(operations):
"""Remove Adjoint from operations.
Args:
operations (List[Str]): List of strings with names of supported operations
Returns:
List: A list of strings with names of supported operations with Adjoint and C gates
removed.
"""
adjoint = re.compile(r"^Adjoint\(.*\)$")

def is_not_adj(op):
return not re.match(adjoint, op)

operations_no_adj = filter(is_not_adj, operations)
return set(operations_no_adj)


def check_no_overlap(*args, device_name):
"""Check items in *args are mutually exclusive.
Args:
*args (List[Str]): List of strings.
device_name (str): Device name for error reporting.
Raises:
CompileError
"""
set_of_sets = [set(arg) for arg in args]
union = set.union(*set_of_sets)
len_of_sets = [len(arg) for arg in args]
if sum(len_of_sets) == len(union):
return

overlaps = set()
for s in set_of_sets:
overlaps.update(s - union)
union = union - s

msg = f"Device '{device_name}' has overlapping gates: {overlaps}"
raise CompileError(msg)


def validate_device_capabilities(
device: qml.QubitDevice, device_capabilities: DeviceCapabilities
) -> None:
"""Validate configuration document against the device attributes.
Raise CompileError in case of mismatch:
* If device is not qjit-compatible.
* If configuration file does not exists.
* If decomposable, matrix, and native gates have some overlap.
* If decomposable, matrix, and native gates do not match gates in ``device.operations`` and
``device.observables``.
Args:
device (qml.Device): An instance of a quantum device.
config (TOMLDocument): A TOML document representation.
Raises: CompileError
"""

if not device_capabilities.qjit_compatible_flag:
raise CompileError(
f"Attempting to compile program for incompatible device '{device.name}': "
f"Config is not marked as qjit-compatible"
)

device_name = device.short_name if isinstance(device, qml.Device) else device.name

native = pennylane_operation_set(device_capabilities.native_ops)
decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops)
matrix = pennylane_operation_set(device_capabilities.to_matrix_ops)

check_no_overlap(native, decomposable, matrix, device_name=device_name)

if hasattr(device, "operations") and hasattr(device, "observables"):
# For gates, we require strict match
device_gates = filter_out_adjoint(set(device.operations))
spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable))
if device_gates != spec_gates:
raise CompileError(
"Gates in qml.device.operations and specification file do not match.\n"
f"Gates that present only in the device: {device_gates - spec_gates}\n"
f"Gates that present only in spec: {spec_gates - device_gates}\n"
)

# For observables, we do not have `non-native` section in the config, so we check that
# device data supercedes the specification.
device_observables = set(device.observables)
spec_observables = pennylane_operation_set(device_capabilities.native_obs)
if (spec_observables - device_observables) != set():
raise CompileError(
"Observables in qml.device.observables and specification file do not match.\n"
f"Observables that present only in spec: {spec_observables - device_observables}\n"
)
Loading

0 comments on commit 715523e

Please sign in to comment.