Skip to content

Commit

Permalink
Native quantum control toml fixup (#600)
Browse files Browse the repository at this point in the history
**Context:** Ongoing transition to the schema 2 of the quantum device
toml specification

**Description of the Change:** This is a minor source code clean up PR.
* Remove force addition of the ControlQubitUnitary gate to the set of
QJITDevice operations. This gate is handled by the generic logic in
schema 2 config loader. In schema 1 loader, the sources already have all
the required patches.
* Explicitly list the runtime-supported controlled gates
* Fix missing Projector observable in the lightning-kokkos device config
* Open toml files in binary mode (fixes wheel builds)
* Update some comments and simplify the set calculations.
* Includes Cuda hotfix, to be removed after
#617 is merged

**Benefits:** Cleaner sources

**Possible Drawbacks:** N/A

**Related GitHub Issues:**
#554
  • Loading branch information
Sergei Mironov authored Mar 20, 2024
1 parent 1f46e2c commit fe916ea
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 88 deletions.
101 changes: 63 additions & 38 deletions frontend/catalyst/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from catalyst.utils.patching import Patcher
from catalyst.utils.runtime import (
BackendInfo,
deduce_native_controlled_gates,
get_pennylane_observables,
get_pennylane_operations,
)
Expand All @@ -34,62 +33,83 @@
)

RUNTIME_OPERATIONS = {
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"Hadamard",
"S",
"T",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"Rot",
"CNOT",
"CY",
"CZ",
"S",
"SWAP",
"IsingXX",
"IsingYY",
"IsingXY",
"ControlledPhaseShift",
"CRX",
"CRY",
"CRZ",
"CRot",
"CSWAP",
"T",
"Toffoli",
"MultiRZ",
"QubitUnitary",
"ISWAP",
"PSWAP",
"GlobalPhase",
"C(GlobalPhase)",
"C(Hadamard)",
"C(IsingXX)",
"C(IsingXY)",
"C(IsingYY)",
"C(ISWAP)",
"C(MultiRZ)",
"ControlledQubitUnitary",
"C(PauliX)",
"C(PauliY)",
"C(PauliZ)",
"C(PhaseShift)",
"C(PSWAP)",
"C(Rot)",
"C(RX)",
"C(RY)",
"C(RZ)",
"C(S)",
"C(SWAP)",
"C(T)",
}


def get_qjit_pennylane_operations(config: TOMLDocument, shots_present, device_name) -> Set[str]:
"""Get set of supported operations for the QJIT device in the PennyLane format. Take the target
device's config into account."""
def get_qjit_pennylane_operations(
config: TOMLDocument, shots_present: bool, device_name: str
) -> Set[str]:
"""Calculate the set of supported quantum gates for the QJIT device from the gates
allowed on the target quantum device."""
# Supported gates of the target PennyLane's device
native_gates = get_pennylane_operations(config, shots_present, device_name)
qir_gates = set.union(
QJITDeviceNewAPI.operations_supported_by_QIR_runtime,
deduce_native_controlled_gates(QJITDeviceNewAPI.operations_supported_by_QIR_runtime),
)
supported_gates = list(set.intersection(native_gates, qir_gates))
# Gates that Catalyst runtime supports
qir_gates = RUNTIME_OPERATIONS
supported_gates = set.intersection(native_gates, qir_gates)

# These are added unconditionally.
supported_gates += ["Cond", "WhileLoop", "ForLoop"]
# Control-flow gates to be lowered down to the LLVM control-flow instructions
supported_gates.update({"Cond", "WhileLoop", "ForLoop"})

# Optionally enable runtime-powered mid-circuit measurments
if check_mid_circuit_measurement_flag(config): # pragma: no branch
supported_gates += ["MidCircuitMeasure"]
supported_gates.update({"MidCircuitMeasure"})

# Optionally enable runtime-powered quantum gate adjointing (inversions)
if check_adjoint_flag(config, shots_present):
supported_gates += ["Adjoint"]
supported_gates.update({"Adjoint"})

supported_gates += ["ControlledQubitUnitary"]
return set(supported_gates)
return supported_gates


class QJITDevice(qml.QubitDevice):
Expand All @@ -114,8 +134,6 @@ class QJITDevice(qml.QubitDevice):
version = "0.0.1"
author = ""

operations_supported_by_QIR_runtime = RUNTIME_OPERATIONS

@staticmethod
def _get_operations_to_convert_to_matrix(_config: TOMLDocument) -> Set[str]:
# We currently override and only set a few gates to preserve existing behaviour.
Expand Down Expand Up @@ -228,7 +246,14 @@ class QJITDeviceNewAPI(qml.devices.Device):
backend_kwargs (Dict(str, AnyType)): An optional dictionary of the device specifications
"""

operations_supported_by_QIR_runtime = RUNTIME_OPERATIONS
@staticmethod
def _get_operations_to_convert_to_matrix(_config: TOMLDocument) -> Set[str]: # pragma: no cover
# We currently override and only set a few gates to preserve existing behaviour.
# We could choose to read from config and use the "matrix" gates.
# However, that affects differentiability.
# None of the "matrix" gates with more than 2 qubits parameters are differentiable.
# TODO: https://github.com/PennyLaneAI/catalyst/issues/398
return {"MultiControlledX", "BlockEncode"}

def __init__(
self,
Expand Down
31 changes: 19 additions & 12 deletions frontend/catalyst/utils/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
get_matrix_decomposable_gates,
get_native_gates,
get_observables,
toml_load,
read_toml_file,
)

package_root = os.path.dirname(__file__)
Expand Down Expand Up @@ -67,10 +67,16 @@ def get_lib_path(project, env_var):
return os.getenv(env_var, DEFAULT_LIB_PATHS.get(project, ""))


def deduce_native_controlled_gates(native_gates: Set[str]) -> Set[str]:
"""Return the set of controlled gates given the set of nativly supported gates. This function
is used with the toml config schema 1. Later schemas provide the required information directly
def deduce_schema1_native_controlled_gates(native_gates: Set[str]) -> Set[str]:
"""Calculate the set of controlled gates given the set of natively supported gates. This
function is used with the toml config schema 1 which did not support per-gate control
specifications. Later schemas provide the required information directly.
"""
# The deduction logic is the following:
# * Most of the gates have their `C(Gate)` controlled counterparts.
# * Some gates have to be decomposed if controlled version is used. Typically these are gates
# which are already controlled but have well-known names.
# * Few gates, like `QubitUnitary`, have separate classes for their controlled versions.
gates_to_be_decomposed_if_controlled = [
"Identity",
"CNOT",
Expand All @@ -87,7 +93,6 @@ def deduce_native_controlled_gates(native_gates: Set[str]) -> Set[str]:
]
native_controlled_gates = set(
[f"C({gate})" for gate in native_gates if gate not in gates_to_be_decomposed_if_controlled]
# TODO: remove after PR #642 is merged in lightning
+ [f"Controlled{gate}" for gate in native_gates if gate in ["QubitUnitary"]]
)
return native_controlled_gates
Expand Down Expand Up @@ -116,7 +121,7 @@ def get_pennylane_operations(
native_gates = set(native_gates_attrs)
supports_controlled = check_quantum_control_flag(config)
native_controlled_gates = (
deduce_native_controlled_gates(native_gates) if supports_controlled else set()
deduce_schema1_native_controlled_gates(native_gates) if supports_controlled else set()
)

# TODO: remove after PR #642 is merged in lightning
Expand Down Expand Up @@ -145,9 +150,12 @@ def get_pennylane_observables(

observables = set(get_observables(config, shots_present))

# TODO: remove after PR #642 is merged in lightning
if device_name == "lightning.kokkos": # pragma: nocover
observables.update({"Projector"})
schema = int(config["schema"])

if schema == 1:
# TODO: remove after PR #642 is merged in lightning
if device_name == "lightning.kokkos": # pragma: nocover
observables.update({"Projector"})

return observables

Expand Down Expand Up @@ -250,7 +258,7 @@ def validate_config_with_device(device: qml.QubitDevice, config: TOMLDocument) -
# listed in either matrix or decomposable sections. This is a contradiction, because condition
# (1) means that `ControlledQubitUnitary` is also in the native set. We solve it here by
# applying a fixup.
# TODO: Remove when the transition to the toml schema 2 is complete.
# TODO: remove after PR #642 is merged in lightning
if "ControlledQubitUnitary" in native:
matrix = matrix - {"ControlledQubitUnitary"}
decomposable = decomposable - {"ControlledQubitUnitary"}
Expand Down Expand Up @@ -283,8 +291,7 @@ def device_get_toml_config(device) -> Path:
toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name

try:
with open(toml_file, "rb") as f:
config = toml_load(f)
config = read_toml_file(toml_file)
except FileNotFoundError as e:
raise CompileError(
"Attempting to compile program for incompatible device: "
Expand Down
7 changes: 6 additions & 1 deletion frontend/catalyst/utils/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
from tomlkit import TOMLDocument
from tomlkit import load as toml_load

__all__ = ["toml_load", "TOMLDocument"]

def read_toml_file(toml_file: str) -> TOMLDocument:
"""Helper function opening toml file properly and reading it into a document"""
with open(toml_file, "rb") as f:
config = toml_load(f)
return config


def check_mid_circuit_measurement_flag(config: TOMLDocument) -> bool:
Expand Down
Loading

0 comments on commit fe916ea

Please sign in to comment.