Skip to content

Commit

Permalink
Change DummyDevice's TOML file to contain per-gate flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei Mironov committed Feb 27, 2024
1 parent 50deeeb commit ddd0a87
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 113 deletions.
25 changes: 3 additions & 22 deletions frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
JaxTracingContext,
)
from catalyst.utils.exceptions import DifferentiableCompileError
from catalyst.utils.runtime import extract_backend_info, get_lib_path
from catalyst.utils.runtime import extract_backend_info, get_lib_path, load_toml_file_into

Check notice on line 102 in frontend/catalyst/pennylane_extensions.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/pennylane_extensions.py#L102

Unused get_lib_path imported from catalyst.utils.runtime (unused-import)


def _check_no_measurements(tape: QuantumTape) -> None:
Expand Down Expand Up @@ -134,26 +134,6 @@ def __init__(self, fn, device): # pragma: nocover
self.device = device
update_wrapper(self, fn)

@staticmethod
def _add_toml_file(device):
"""Temporary function. This function adds the config field to devices.
TODO: Remove this function when `qml.Device`s are guaranteed to have their own
config file field."""
if hasattr(device, "config"): # pragma: no cover
# Devices that already have a config field do not need it to be overwritten.
return
device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR"))
name = device.name
if isinstance(device, qml.Device):
name = device.short_name

# The toml files name convention we follow is to replace
# the dots with underscores in the device short name.
toml_file_name = name.replace(".", "_") + ".toml"
# And they are currently saved in the following directory.
toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name
device.config = toml_file

@staticmethod
def extract_backend_info(device):
"""Wrapper around extract_backend_info in the runtime module."""
Expand All @@ -163,7 +143,8 @@ def __call__(self, *args, **kwargs):
qnode = None
if isinstance(self, qml.QNode):
qnode = self
QFunc._add_toml_file(self.device)
if not hasattr(self.device, "config"):
load_toml_file_into(self.device)
dev_args = QFunc.extract_backend_info(self.device)
config, rest = dev_args[0], dev_args[1:]
device = QJITDevice(config, self.device.shots, self.device.wires, *rest)
Expand Down
32 changes: 2 additions & 30 deletions frontend/catalyst/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import get_native_gates_PL


class QJITDevice(qml.QubitDevice):
Expand Down Expand Up @@ -96,19 +97,10 @@ def _check_mid_circuit_measurement(config):
def _check_adjoint(config):
return config["compilation"]["quantum_adjoint"]

@staticmethod
def _check_quantum_control(config):
# TODO: Remove the special case when the
# https://github.com/PennyLaneAI/pennylane-lightning/pull/615
# is merged.
if config["device"]["name"] == "lightning.qubit":
return True
return config["compilation"]["quantum_control"]

@staticmethod
def _set_supported_operations(config):
"""Override the set of supported operations."""
native_gates = set(config["operators"]["gates"][0]["native"])
native_gates = get_native_gates_PL(config)
qir_gates = QJITDevice.operations_supported_by_QIR_runtime
supported_native_gates = list(set.intersection(native_gates, qir_gates))
QJITDevice.operations = supported_native_gates
Expand All @@ -122,26 +114,6 @@ def _set_supported_operations(config):
if QJITDevice._check_adjoint(config):
QJITDevice.operations += ["Adjoint"]

if QJITDevice._check_quantum_control(config): # pragma: nocover
# TODO: Once control is added on the frontend.
gates_to_be_decomposed_if_controlled = [
"Identity",
"CNOT",
"CY",
"CZ",
"CSWAP",
"CRX",
"CRY",
"CRZ",
"CRot",
]
native_controlled_gates = [
f"C({gate})"
for gate in native_gates
if gate not in gates_to_be_decomposed_if_controlled
]
QJITDevice.operations += native_controlled_gates

@staticmethod
def _set_supported_observables(config):
"""Override the set of supported observables."""
Expand Down
78 changes: 61 additions & 17 deletions frontend/catalyst/utils/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import pathlib
import platform
import re
from typing import Set, Dict, Any, Tuple

import pennylane as qml

from catalyst._configuration import INSTALLED
from catalyst.utils.exceptions import CompileError
from catalyst.utils.toml import toml_load
from catalyst.utils.toml import toml_load, TOMLDocument

package_root = os.path.dirname(__file__)

Expand Down Expand Up @@ -91,14 +92,24 @@ def check_device_config(device):
raise CompileError(msg)


def get_native_gates(config):
def get_native_gates_PL(config) -> Set[str]:
"""Get gates that are natively supported by the device and therefore do not need to be
decomposed.
Args:
config (Dict[Str, Any]): Configuration dictionary
Returns:
List[str]: List of gate names in the PennyLane format.
"""
return config["operators"]["gates"][0]["native"]
gates = config["operators"]["gates"]["named"]
# import pdb; pdb.set_trace()
gates_PL = set()
for gate_name in [str(g) for g in gates]:
gates_PL.add(f"{gate_name}")
if gates[gate_name].get('controllable', False):
gates_PL.add(f"C({gate_name})")
return gates_PL


def get_decomposable_gates(config):
Expand All @@ -107,7 +118,7 @@ def get_decomposable_gates(config):
Args:
config (Dict[Str, Any]): Configuration dictionary
"""
return config["operators"]["gates"][0]["decomp"]
return config["operators"]["gates"]["decomp"]


def get_matrix_decomposable_gates(config):
Expand All @@ -116,7 +127,7 @@ def get_matrix_decomposable_gates(config):
Args:
config (Dict[Str, Any]): Configuration dictionary
"""
return config["operators"]["gates"][0]["matrix"]
return config["operators"]["gates"]["matrix"]


def check_no_overlap(*args):
Expand Down Expand Up @@ -171,14 +182,20 @@ def check_full_overlap(device, *args):
Raises: CompileError
"""
operations = filter_out_adjoint_and_control(device.operations)
gates_in_device = set(operations)
# operations = filter_out_adjoint_and_control(device.operations)
gates_in_device = set(device.operations)
set_of_sets = [set(arg) for arg in args]
union = set.union(*set_of_sets)
if gates_in_device == union:
gates_in_spec = set.union(*set_of_sets)
if gates_in_device == gates_in_spec:
return

msg = "Gates in qml.device.operations and specification file do not match"
import pdb; pdb.set_trace()

Check notice on line 192 in frontend/catalyst/utils/runtime.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/utils/runtime.py#L192

Import outside toplevel (pdb) (import-outside-toplevel)

Check notice on line 192 in frontend/catalyst/utils/runtime.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/utils/runtime.py#L192

More than one statement on a single line (multiple-statements)

Check notice on line 192 in frontend/catalyst/utils/runtime.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/utils/runtime.py#L192

Leaving functions creating breakpoints in production code is not recommended (forgotten-debug-statement)

msg = (
"Gates in qml.device.operations and specification file do not match.\n"
f"Gates that present only in the device: {gates_in_device - gates_in_spec}\n"
f"Gates that present only in spec: {gates_in_spec - gates_in_device}\n"
)
raise CompileError(msg)


Expand All @@ -191,17 +208,15 @@ def check_gates_are_compatible_with_device(device, config):
Raises: CompileError
"""

native = get_native_gates(config)
native = get_native_gates_PL(config)
decomposable = get_decomposable_gates(config)
matrix = get_matrix_decomposable_gates(config)
check_no_overlap(native, decomposable, matrix)
if not hasattr(device, "operations"): # pragma: nocover

if hasattr(device, "operations"): # pragma: nocover
# The new device API has no "operations" field
# so we cannot check that there's an overlap or not.
return

check_full_overlap(device, native, decomposable, matrix)
check_full_overlap(device, native)


def validate_config_with_device(device):
Expand All @@ -226,7 +241,36 @@ def validate_config_with_device(device):
check_gates_are_compatible_with_device(device, config)


def extract_backend_info(device):
def load_toml_file_into(device, toml_file_name=None):
"""Temporary function. This function adds the `config` field to devices containing the path to
it's TOML configuration file.
TODO: Remove this function when `qml.Device`s are guaranteed to have their own
config file field."""
device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR"))
name = device.name
if isinstance(device, qml.Device):
name = device.short_name

# The toml files name convention we follow is to replace
# the dots with underscores in the device short name.
if toml_file_name is None:
toml_file_name = name.replace(".", "_") + ".toml"
# And they are currently saved in the following directory.
toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name

with open(toml_file, "rb") as f:
config = toml_load(f)

toml_operations = get_native_gates_PL(config)
device.operations = toml_operations
# if not hasattr(device, "operations") or device.operations is None:
# else:
# # TODO: make sure toml_operations matches the device operations
# pass
device.config = toml_file


def extract_backend_info(device) -> Tuple[TOMLDocument, str, str, Dict[str, Any]]:
"""Extract the backend info as a tuple of (name, lib, kwargs)."""

validate_config_with_device(device)
Expand Down
6 changes: 4 additions & 2 deletions frontend/catalyst/utils/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import importlib.util
from typing import Any

# TODO:
# Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit
Expand All @@ -33,7 +34,8 @@
# Give preference to tomllib
if tomllib:
from tomllib import load as toml_load # pragma: nocover
TOMLDocument = Any
else:
from tomlkit import load as toml_load # pragma: nocover
from tomlkit import load as toml_load, TOMLDocument # pragma: nocover

__all__ = ["toml_load"]
__all__ = ["toml_load", "TOMLDocument"]
12 changes: 7 additions & 5 deletions frontend/test/pytest/test_custom_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from catalyst import measure, qjit
from catalyst.compiler import get_lib_path
from catalyst.utils.exceptions import CompileError
from catalyst.utils.runtime import get_native_gates_PL

Check notice on line 24 in frontend/test/pytest/test_custom_devices.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_custom_devices.py#L24

Unused get_native_gates_PL imported from catalyst.utils.runtime (unused-import)

# These have to match the ones in the configuration file.
OPERATIONS = [
Expand Down Expand Up @@ -97,14 +98,16 @@
"Exp",
]

RUNTIME_LIB_PATH = get_lib_path("runtime", "RUNTIME_LIB_DIR")

@pytest.mark.skipif(
not pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so").is_file(),
not pathlib.Path(RUNTIME_LIB_PATH + "/libdummy_device.so").is_file(),
reason="lib_dummydevice.so was not found.",
)
def test_custom_device():
def test_custom_device_load():
"""Test that custom device can run using Catalyst."""


class DummyDevice(qml.QubitDevice):
"""Dummy Device"""

Expand All @@ -114,9 +117,8 @@ class DummyDevice(qml.QubitDevice):
version = "0.0.1"
author = "Dummy"

# Doesn't matter as at the moment it is dictated by QJITDevice
operations = OPERATIONS
observables = OBSERVABLES
operations = [] # To be loaded from the toml file
observables = [] # To be loaded from the toml file

def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
Expand Down
Loading

0 comments on commit ddd0a87

Please sign in to comment.