Skip to content

Commit

Permalink
Controlled operations rework Part 1 (#5125)
Browse files Browse the repository at this point in the history
**Context:**
All controlled operations should inherit from the general Controlled
class, and the decomposition of controlled operations is not consistent
for custom and non-custom controlled operations. This is a continuation
of #5069

This is the first PR out of two for this rework. The second PR will
focus on making sure that all custom controlled operations inherit from
Controlled for more consistent inheritance structure.

**Description of the Change:**
- Make `MultiControlledX` inherit from ControlledOp.
- `qml.ctrl` called on operators with custom controlled versions will
return instances of the custom class.
- Special handling of `PauliX` based controlled operations (`PauliX`,
`CNOT`, `Toffoli`, `MultiControlledX`)
- Calling `qml.ctrl` on one of these operators will always resolve to
the best option in `CNOT`, `Toffoli`, or `MultiControlledX` depending on
the number of control wires and control values.
- `qml.ctrl` will flatten nested controlled operators to a single
multi-controlled operation.
- Controlled operators with a custom controlled version decomposes like
how their controlled counterpart decomposes, as opposed to decomposing
into their controlled version.
- Special handling of `PauliX` based controlled operations: e.g.,
`Controlled(CNOT([0, 1]), [2, 3])` will have the same decomposition
behaviour as a `MultiControlledX([2, 3, 0, 1])`

**Benefits:**
Cleaner code and more consistent behaviour

**Possible Drawbacks:**
Change of decomposition behaviour may cause issues.
~For `MultiControlledX`, the `wires` attribute now refers to all wires,
as in `control_wires + target_wire + work_wires`, to access only the
`control_wires + target_wires`, use the `active_wires` attribute.~

**Related GitHub Issues:**
#5069
#1447

**Related Shortcut Stories**
[sc-55949]
[sc-55131]
[sc-55358]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
  • Loading branch information
3 people authored Feb 1, 2024
1 parent 9555cdd commit d0c435d
Show file tree
Hide file tree
Showing 20 changed files with 1,565 additions and 1,253 deletions.
16 changes: 14 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,17 @@
and the codecov check itself would never execute.
[(#5101)](https://github.com/PennyLaneAI/pennylane/pull/5101)

* `qml.ctrl` called on operators with custom controlled versions will return instances
of the custom class, and it will also flatten nested controlled operators to a single
multi-controlled operation. For `PauliX`, `CNOT`, `Toffoli`, and `MultiControlledX`,
calling `qml.ctrl` will always resolve to the best option in `CNOT`, `Toffoli`, or
`MultiControlledX` depending on the number of control wires and control values.
[(#5125)](https://github.com/PennyLaneAI/pennylane/pull/5125/)

* `qml.Identity()` can be initialized without wires. Measuring it is currently not possible though.
[(#5106)](https://github.com/PennyLaneAI/pennylane/pull/5106)


<h4>Community contributions 🥳</h4>

* The transform `split_non_commuting` now accepts measurements of type `probs`, `sample` and `counts` which accept both wires and observables.
Expand Down Expand Up @@ -203,7 +211,12 @@
(with potentially negative eigenvalues) has been implemented.
[(#5048)](https://github.com/PennyLaneAI/pennylane/pull/5048)

* The decomposition of an operator created with calling `qml.ctrl` on a parametric operator (specifically `RX`, `RY`, `RZ`, `Rot`, `PhaseShift`) with a single control wire will now be the full decomposition instead of a single controlled gate. For example:
* Controlled operators with a custom controlled version decomposes like how their
controlled counterpart decomposes, as opposed to decomposing into their controlled version.
[(#5069)](https://github.com/PennyLaneAI/pennylane/pull/5069)
[(#5125)](https://github.com/PennyLaneAI/pennylane/pull/5125/)

For example:
```
>>> qml.ctrl(qml.RX(0.123, wires=1), control=0).decomposition()
[
Expand All @@ -215,7 +228,6 @@
RZ(-1.5707963267948966, wires=[1])
]
```
[(#5069)](https://github.com/PennyLaneAI/pennylane/pull/5069)

* `QuantumScript.is_sampled` and `QuantumScript.all_sampled` have been removed. Users should now
validate these properties manually.
Expand Down
4 changes: 1 addition & 3 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,7 @@ def apply_multicontrolledx(
return _apply_operation_default(op, state, is_state_batched, debugger)
ctrl_wires = [w + is_state_batched for w in op.control_wires]
# apply x on all control wires with control value 0
roll_axes = [
w for val, w in zip(op.hyperparameters["control_values"], ctrl_wires) if val == "0"
]
roll_axes = [w for val, w in zip(op.control_values, ctrl_wires) if val is False]
for ax in roll_axes:
state = math.roll(state, 1, ax)

Expand Down
3 changes: 1 addition & 2 deletions pennylane/drawer/tape_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def _(op: ops.Toffoli, drawer, layer, _):

@_add_operation_to_drawer.register
def _(op: ops.MultiControlledX, drawer, layer, _):
control_values = [(i == "1") for i in op.hyperparameters["control_values"]]
drawer.CNOT(layer, op.wires, control_values=control_values)
drawer.CNOT(layer, op.active_wires, control_values=op.control_values)


@_add_operation_to_drawer.register
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/functions/bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def bind_new_parameters_composite_op(op: CompositeOp, params: Sequence[TensorLik

@bind_new_parameters.register(qml.CY)
@bind_new_parameters.register(qml.CZ)
@bind_new_parameters.register(qml.MultiControlledX)
def bind_new_parameters_copy(
op: Union[qml.CY, qml.CZ], params: Sequence[TensorLike]
op: Union[qml.CY, qml.CZ, qml.MultiControlledX], params: Sequence[TensorLike]
): # pylint:disable=unused-argument
return copy.copy(op)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/ops/op_math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
CRZ,
CY,
CZ,
MultiControlledX,
)
from .decompositions import one_qubit_decomposition, two_qubit_decomposition, sk_decomposition
from .evolution import Evolution
Expand All @@ -124,6 +125,7 @@
"ControlledQubitUnitary",
"CY",
"CZ",
"MultiControlledX",
"CRX",
"CRY",
"CRZ",
Expand Down
217 changes: 184 additions & 33 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This submodule defines the symbolic operation that indicates the control of an operator.
"""
import warnings
import functools
from copy import copy
from functools import wraps
from inspect import signature
Expand Down Expand Up @@ -127,36 +128,41 @@ def cond_fn():
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.ctrl(op, control, control_values=control_values, work_wires=work_wires)

custom_ops = {
(qml.PauliZ, 1): qml.CZ,
(qml.PauliY, 1): qml.CY,
(qml.PauliX, 1): qml.CNOT,
(qml.PauliX, 2): qml.Toffoli,
(qml.RX, 1): qml.CRX,
(qml.RY, 1): qml.CRY,
(qml.RZ, 1): qml.CRZ,
(qml.Rot, 1): qml.CRot,
(qml.PhaseShift, 1): qml.ControlledPhaseShift,
}
control_values = [control_values] if isinstance(control_values, (int, bool)) else control_values
control = qml.wires.Wires(control)
custom_key = (type(op), len(control))
if isinstance(control_values, (int, bool)):
control_values = [control_values]
elif control_values is None:
control_values = [True] * len(control)

if custom_key in custom_ops and (control_values is None or all(control_values)):
qml.QueuingManager.remove(op)
return custom_ops[custom_key](*op.data, control + op.wires)
if isinstance(op, qml.PauliX):
ctrl_op = _try_wrap_in_custom_ctrl_op(
op, control, control_values=control_values, work_wires=work_wires
)
if ctrl_op is not None:
return ctrl_op

pauli_x_based_ctrl_ops = _get_pauli_x_based_ops()

# Special handling for PauliX-based controlled operations
if isinstance(op, pauli_x_based_ctrl_ops):
qml.QueuingManager.remove(op)
control_string = (
None if control_values is None else "".join([str(int(v)) for v in control_values])
)
return qml.MultiControlledX(
wires=control + op.wires, control_values=control_string, work_wires=work_wires
return _handle_pauli_x_based_controlled_ops(op, control, control_values, work_wires)

# Flatten nested controlled operations to a multi-controlled operation for better
# decomposition algorithms. This includes special cases like CRX, CRot, etc.
if isinstance(op, Controlled):
work_wires = work_wires or []
return ctrl(
op.base,
control=control + op.control_wires,
control_values=control_values + op.control_values,
work_wires=work_wires + op.work_wires,
)

if isinstance(op, Operator):
return Controlled(
op, control_wires=control, control_values=control_values, work_wires=work_wires
)

if not callable(op):
raise ValueError(
f"The object {op} of type {type(op)} is not an Operator or callable. "
Expand Down Expand Up @@ -190,6 +196,115 @@ def wrapper(*args, **kwargs):
return wrapper


@functools.lru_cache()
def _get_special_ops():
"""Gets a list of special operations with custom controlled versions.
This is placed inside a function to avoid circular imports.
"""

ops_with_custom_ctrl_ops = {
(qml.PauliZ, 1): qml.CZ,
(qml.PauliZ, 2): qml.CCZ,
(qml.PauliY, 1): qml.CY,
(qml.CZ, 1): qml.CCZ,
(qml.SWAP, 1): qml.CSWAP,
(qml.Hadamard, 1): qml.CH,
(qml.RX, 1): qml.CRX,
(qml.RY, 1): qml.CRY,
(qml.RZ, 1): qml.CRZ,
(qml.Rot, 1): qml.CRot,
(qml.PhaseShift, 1): qml.ControlledPhaseShift,
}
return ops_with_custom_ctrl_ops


@functools.lru_cache()
def _get_pauli_x_based_ops():
"""Gets a list of pauli-x based operations
This is placed inside a function to avoid circular imports.
"""
return qml.PauliX, qml.CNOT, qml.Toffoli, qml.MultiControlledX


def _try_wrap_in_custom_ctrl_op(op, control, control_values=None, work_wires=None):
"""Wraps a controlled operation in custom ControlledOp, returns None if not applicable."""

ops_with_custom_ctrl_ops = _get_special_ops()
custom_key = (type(op), len(control))

if custom_key in ops_with_custom_ctrl_ops and all(control_values):
qml.QueuingManager.remove(op)
return ops_with_custom_ctrl_ops[custom_key](*op.data, control + op.wires)

if isinstance(op, qml.QubitUnitary):
return qml.ControlledQubitUnitary(
op, control_wires=control, control_values=control_values, work_wires=work_wires
)

# A controlled ControlledPhaseShift should not be compressed to a multi controlled
# PhaseShift because the decomposition of PhaseShift contains a GlobalPhase that we
# do not have a controlled version of.
# TODO: remove this special case when we support ControlledGlobalPhase (sc-44933)
if isinstance(op, qml.ControlledPhaseShift):
return Controlled(
op, control_wires=control, control_values=control_values, work_wires=work_wires
)
# Similarly, compress the bottom levels of a multi-controlled PhaseShift to a
# ControlledPhaseShift if possible to avoid dealing with a controlled GlobalPhase
# during decomposition. This should also be removed in the future.
if isinstance(op, qml.PhaseShift) and control_values[-1]:
op = qml.ControlledPhaseShift(*op.data, wires=control[-1:] + op.wires)
return Controlled(
op,
control_wires=control[:-1],
control_values=control_values[:-1],
work_wires=work_wires,
)

return None


def _handle_pauli_x_based_controlled_ops(op, control, control_values, work_wires):
"""Handles PauliX-based controlled operations."""

op_map = {
(qml.PauliX, 1): qml.CNOT,
(qml.PauliX, 2): qml.Toffoli,
(qml.CNOT, 1): qml.Toffoli,
}

custom_key = (type(op), len(control))
if custom_key in op_map and all(control_values):
qml.QueuingManager.remove(op)
return op_map[custom_key](wires=control + op.wires)

if isinstance(op, qml.PauliX):
return qml.MultiControlledX(
wires=control + op.wires, control_values=control_values, work_wires=work_wires
)

# TODO: remove special handling of CNOT and Toffoli when they inherit from Controlled
if isinstance(op, qml.CNOT):
return qml.MultiControlledX(
wires=control + op.wires, control_values=control_values + [1], work_wires=work_wires
)
if isinstance(op, qml.Toffoli):
return qml.MultiControlledX(
wires=control + op.wires, control_values=control_values + [1, 1], work_wires=work_wires
)

work_wires = work_wires or []
return qml.MultiControlledX(
wires=control + op.wires,
control_values=control_values + op.control_values,
work_wires=work_wires + op.work_wires,
)


# pylint: disable=too-many-arguments, too-many-public-methods
class Controlled(SymbolicOp):
"""Symbolic operator denoting a controlled operator.
Expand Down Expand Up @@ -531,7 +646,7 @@ def has_decomposition(self):
return True
if len(self.control_wires) == 1 and hasattr(self.base, "_controlled"):
return True
if isinstance(self.base, qml.PauliX):
if isinstance(self.base, _get_pauli_x_based_ops()):
return True
if _is_single_qubit_special_unitary(self.base):
return True
Expand Down Expand Up @@ -631,21 +746,57 @@ def _is_single_qubit_special_unitary(op):
return qmlmath.allclose(det, 1)


# pylint: disable=protected-access
def _decompose_no_control_values(op: "operation.Operator") -> List["operation.Operator"]:
"""Provides a decomposition without considering control values. Returns None if
no decomposition.
"""
def _decompose_pauli_x_based_no_control_values(op: Controlled):
"""Decomposes a PauliX-based operation"""

if isinstance(op.base, qml.PauliX) and len(op.control_wires) == 1:
return [qml.CNOT(wires=op.active_wires)]

if isinstance(op.base, qml.PauliX) and len(op.control_wires) == 2:
return qml.Toffoli.compute_decomposition(wires=op.active_wires)

if isinstance(op.base, qml.CNOT) and len(op.control_wires) == 1:
return qml.Toffoli.compute_decomposition(wires=op.active_wires)

return qml.MultiControlledX.compute_decomposition(
wires=op.active_wires,
work_wires=op.work_wires,
)


def _decompose_custom_ops(op: Controlled) -> List["operation.Operator"]:
"""Custom handling for decomposing a controlled operation"""

pauli_x_based_ctrl_ops = _get_pauli_x_based_ops()
ops_with_custom_ctrl_ops = _get_special_ops()

custom_key = (type(op.base), len(op.control_wires))
if custom_key in ops_with_custom_ctrl_ops:
custom_op_cls = ops_with_custom_ctrl_ops[custom_key]
return custom_op_cls.compute_decomposition(*op.data, op.active_wires)
if type(op.base) in pauli_x_based_ctrl_ops:
# has some special case handling of its own for further decomposition
return _decompose_pauli_x_based_no_control_values(op)

# TODO: will be removed in the second part of the controlled rework
if len(op.control_wires) == 1 and hasattr(op.base, "_controlled"):
result = op.base._controlled(op.control_wires[0])
result = op.base._controlled(op.control_wires[0]) # pylint: disable=protected-access
# disallow decomposing to itself
# pylint: disable=unidiomatic-typecheck
if type(result) != type(op):
return [result]
qml.QueuingManager.remove(result)
if isinstance(op.base, qml.PauliX):
# has some special case handling of its own for further decomposition
return [qml.MultiControlledX(wires=op.active_wires, work_wires=op.work_wires)]

return None


def _decompose_no_control_values(op: Controlled) -> List["operation.Operator"]:
"""Decompose without considering control values. Returns None if no decomposition."""

decomp = _decompose_custom_ops(op)
if decomp is not None:
return decomp

if _is_single_qubit_special_unitary(op.base):
if len(op.control_wires) >= 2 and qmlmath.get_interface(*op.data) == "numpy":
return ctrl_decomp_bisect(op.base, op.control_wires)
Expand All @@ -663,7 +814,7 @@ def _decompose_no_control_values(op: "operation.Operator") -> List["operation.Op
UserWarning,
)

return [Controlled(newop, op.control_wires, work_wires=op.work_wires) for newop in base_decomp]
return [ctrl(newop, op.control_wires, work_wires=op.work_wires) for newop in base_decomp]


class ControlledOp(Controlled, operation.Operation):
Expand Down
Loading

0 comments on commit d0c435d

Please sign in to comment.