Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Controlled operations rework Part 1 #5125

Merged
merged 33 commits into from
Feb 1, 2024
Merged

Controlled operations rework Part 1 #5125

merged 33 commits into from
Feb 1, 2024

Conversation

astralcai
Copy link
Contributor

@astralcai astralcai commented Jan 30, 2024

Context:
All controlled operations should inherit from the general Controlled class, and the decomposition of controlled operations is not consistent for custom and non-custom controlled operations. This is a continuation of #5069

This is the first PR out of two for this rework. The second PR will focus on making sure that all custom controlled operations inherit from Controlled for more consistent inheritance structure.

Description of the Change:

  • Make MultiControlledX inherit from ControlledOp.
  • qml.ctrl called on operators with custom controlled versions will return instances of the custom class.
  • Special handling of PauliX based controlled operations (PauliX, CNOT, Toffoli, MultiControlledX)
    • Calling qml.ctrl on one of these operators will always resolve to the best option in CNOT, Toffoli, or MultiControlledX depending on the number of control wires and control values.
  • qml.ctrl will flatten nested controlled operators to a single multi-controlled operation.
  • Controlled operators with a custom controlled version decomposes like how their controlled counterpart decomposes, as opposed to decomposing into their controlled version.
    • Special handling of PauliX based controlled operations: e.g., Controlled(CNOT([0, 1]), [2, 3]) will have the same decomposition behaviour as a MultiControlledX([2, 3, 0, 1])

Benefits:
Cleaner code and more consistent behaviour

Possible Drawbacks:
Change of decomposition behaviour may cause issues.
For MultiControlledX, the wires attribute now refers to all wires, as in control_wires + target_wire + work_wires, to access only the control_wires + target_wires, use the active_wires attribute.

Related GitHub Issues:
#5069
#1447

Related Shortcut Stories
[sc-55949]
[sc-55131]
[sc-55358]

@astralcai astralcai requested review from dime10 and a team January 30, 2024 15:32
Copy link

codecov bot commented Jan 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (9555cdd) 99.69% compared to head (92baab0) 99.68%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5125      +/-   ##
==========================================
- Coverage   99.69%   99.68%   -0.01%     
==========================================
  Files         394      394              
  Lines       36022    35797     -225     
==========================================
- Hits        35911    35685     -226     
- Misses        111      112       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

astralcai and others added 2 commits January 30, 2024 17:30
Co-authored-by: Christina Lee <christina@xanadu.ai>
@astralcai astralcai marked this pull request as ready for review January 31, 2024 16:05
astralcai and others added 2 commits January 31, 2024 14:11
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great 🎉 Thanks for all your hard work :)

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great so far! most of my questions are style-related, but the bulk of this code looks super nice. a lot of work in this I see 🏆

It's a little hard to see what's actually changed in MultiControlledX because of the file move, so I've generated the (reverse of the) diff of the class and attached it here:

The MCX diff, reversed
-class MultiControlledX(ControlledOp):
+class MultiControlledX(Operation):
     r"""MultiControlledX(control_wires, wires, control_values)
     Apply a Pauli X gate controlled on an arbitrary computational basis state.
 
@@ -385,8 +385,8 @@ class MultiControlledX(ControlledOp):
             Now users should use "wires" to indicate both the control wires and the target wire.
         wires (Union[Wires, Sequence[int], or int]): control wire(s) followed by a single target wire where
             the operation acts on
-        control_values (Union[bool, list[bool], int, list[int]]): The value(s) the control wire(s)
-                should take. Integers other than 0 or 1 will be treated as ``int(bool(x))``.
+        control_values (str): a string of bits representing the state of the control
+            wires to control on (default is the all 1s state)
         work_wires (Union[Wires, Sequence[int], or int]): optional work wires used to decompose
             the operation into a series of Toffoli gates
 
@@ -417,66 +417,80 @@ class MultiControlledX(ControlledOp):
     """
 
     is_self_inverse = True
-    """bool: Whether or not the operator is self-inverse."""
 
     num_wires = AnyWires
-    """int: Number of wires the operation acts on."""
 
     num_params = 0
     """int: Number of trainable parameters that the operator depends on."""
 
-    ndim_params = ()
-    """tuple[int]: Number of dimensions per trainable parameter that the operator depends on."""
-
-    name = "MultiControlledX"
+    grad_method = None
 
     def _flatten(self):
-        return (), (self.active_wires, tuple(self.control_values), self.work_wires)
+        hyperparameters = (
+            ("wires", self.wires),
+            ("control_values", self.hyperparameters["control_values"]),
+            ("work_wires", self.hyperparameters["work_wires"]),
+        )
+        return tuple(), hyperparameters
 
     @classmethod
     def _unflatten(cls, _, metadata):
-        return cls(wires=metadata[0], control_values=metadata[1], work_wires=metadata[2])
+        return cls(**dict(metadata))
 
     # pylint: disable=too-many-arguments
     def __init__(self, control_wires=None, wires=None, control_values=None, work_wires=None):
         if wires is None:
             raise ValueError("Must specify the wires where the operation acts on")
-        wires = wires if isinstance(wires, Wires) else Wires(wires)
-        if control_wires is not None:
+        if control_wires is None:
+            if len(wires) > 1:
+                control_wires = Wires(wires[:-1])
+                wires = Wires(wires[-1])
+            else:
+                raise ValueError(
+                    "MultiControlledX: wrong number of wires. "
+                    f"{len(wires)} wire(s) given. Need at least 2."
+                )
+        else:
+            wires = Wires(wires)
+            control_wires = Wires(control_wires)
+
             warnings.warn(
-                "The control_wires keyword will be removed soon. Use wires = (control_wires, "
-                "target_wire) instead. See the documentation for more information.",
-                UserWarning,
+                "The control_wires keyword will be removed soon. "
+                "Use wires = (control_wires, target_wire) instead. "
+                "See the documentation for more information.",
+                category=UserWarning,
             )
+
             if len(wires) != 1:
                 raise ValueError("MultiControlledX accepts a single target wire.")
-        else:
-            if len(wires) < 2:
-                raise ValueError(
-                    f"MultiControlledX: wrong number of wires. {len(wires)} wire(s) given. Need at least 2."
-                )
-            control_wires = wires[:-1]
-            wires = wires[-1:]
 
-        control_values = _check_and_convert_control_values(control_values, control_wires)
+        work_wires = Wires([]) if work_wires is None else Wires(work_wires)
+        total_wires = control_wires + wires
 
-        super().__init__(
-            qml.PauliX(wires=wires),
-            control_wires=control_wires,
-            control_values=control_values,
-            work_wires=work_wires,
-        )
+        if Wires.shared_wires([total_wires, work_wires]):
+            raise ValueError("The work wires must be different from the control and target wires")
+
+        if not control_values:
+            control_values = "1" * len(control_wires)
+
+        self.hyperparameters["control_wires"] = control_wires
+        self.hyperparameters["work_wires"] = work_wires
+        self.hyperparameters["control_values"] = control_values
+        self.total_wires = total_wires
+
+        super().__init__(wires=self.total_wires)
 
     def __repr__(self):
-        return f"MultiControlledX(wires={self.active_wires.tolist()}, control_values={self.control_values})"
+        return f'MultiControlledX(wires={list(self.total_wires._labels)}, control_values="{self.hyperparameters["control_values"]}")'
 
-    @property
-    def wires(self):
-        return self.active_wires
+    def label(self, decimals=None, base_label=None, cache=None):
+        return base_label or "X"
 
-    # pylint: disable=unused-argument, arguments-differ
+    # pylint: disable=unused-argument
     @staticmethod
-    def compute_matrix(control_wires, control_values=None, **kwargs):
+    def compute_matrix(
+        control_wires, control_values=None, **kwargs
+    ):  # pylint: disable=arguments-differ
         r"""Representation of the operator as a canonical matrix in the computational basis (static method).
 
         The canonical matrix is the textbook matrix representation that does not consider wires.
@@ -486,40 +500,57 @@ class MultiControlledX(ControlledOp):
 
         Args:
             control_wires (Any or Iterable[Any]): wires to place controls on
-            control_values (Union[bool, list[bool], int, list[int]]): The value(s) the control wire(s)
-                should take. Integers other than 0 or 1 will be treated as ``int(bool(x))``.
-
+            control_values (str): string of bits determining the controls
         Returns:
-            tensor_like: matrix representation
+           tensor_like: matrix representation
 
         **Example**
 
-        >>> print(qml.MultiControlledX.compute_matrix([0], 1))
+        >>> print(qml.MultiControlledX.compute_matrix([0], '1'))
         [[1. 0. 0. 0.]
          [0. 1. 0. 0.]
          [0. 0. 0. 1.]
          [0. 0. 1. 0.]]
-        >>> print(qml.MultiControlledX.compute_matrix([1], 0))
+        >>> print(qml.MultiControlledX.compute_matrix([1], '0'))
         [[0. 1. 0. 0.]
          [1. 0. 0. 0.]
          [0. 0. 1. 0.]
          [0. 0. 0. 1.]]
 
         """
+        if control_values is None:
+            control_values = "1" * len(control_wires)
+
+        if isinstance(control_values, str):
+            if len(control_values) != len(control_wires):
+                raise ValueError("Length of control bit string must equal number of control wires.")
+
+            # Make sure all values are either 0 or 1
+            if not set(control_values).issubset({"1", "0"}):
+                raise ValueError("String of control values can contain only '0' or '1'.")
 
-        control_values = _check_and_convert_control_values(control_values, control_wires)
-        padding_left = sum(2**i * int(val) for i, val in enumerate(reversed(control_values))) * 2
+            control_int = int(control_values, 2)
+        else:
+            raise ValueError("Control values must be passed as a string.")
+
+        padding_left = control_int * 2
         padding_right = 2 ** (len(control_wires) + 1) - 2 - padding_left
-        return block_diag(np.eye(padding_left), qml.PauliX.compute_matrix(), np.eye(padding_right))
+        cx = block_diag(np.eye(padding_left), PauliX.compute_matrix(), np.eye(padding_right))
+        return cx
 
-    def matrix(self, wire_order=None):
-        canonical_matrix = self.compute_matrix(self.control_wires, self.control_values)
-        wire_order = wire_order or self.wires
-        return qml.math.expand_matrix(
-            canonical_matrix, wires=self.active_wires, wire_order=wire_order
+    @property
+    def control_wires(self):
+        return self.wires[:~0]
+
+    def adjoint(self):
+        return MultiControlledX(
+            wires=self.wires,
+            control_values=self.hyperparameters["control_values"],
         )
 
-    # pylint: disable=unused-argument, arguments-differ
+    def pow(self, z):
+        return super().pow(z % 2)
+
     @staticmethod
     def compute_decomposition(wires=None, work_wires=None, control_values=None, **kwargs):
         r"""Representation of the operator as a product of other operators (static method).
@@ -532,16 +563,14 @@ class MultiControlledX(ControlledOp):
             wires (Iterable[Any] or Wires): wires that the operation acts on
             work_wires (Wires): optional work wires used to decompose
                 the operation into a series of Toffoli gates.
-            control_values (Union[bool, list[bool], int, list[int]]): The value(s) the control wire(s)
-                should take. Integers other than 0 or 1 will be treated as ``int(bool(x))``.
-
+            control_values (str): a string of bits representing the state of the control
+                wires to control on (default is the all 1s state)
         Returns:
             list[Operator]: decomposition into lower level operations
 
         **Example:**
 
-        >>> print(qml.MultiControlledX.compute_decomposition(
-        ...     wires=[0,1,2,3], control_values=[1,1,1], work_wires=qml.wires.Wires("aux")))
+        >>> print(qml.MultiControlledX.compute_decomposition(wires=[0,1,2,3],control_values="111", work_wires=qml.wires.Wires("aux")))
         [Toffoli(wires=[2, 'aux', 3]),
         Toffoli(wires=[0, 1, 'aux']),
         Toffoli(wires=[2, 'aux', 3]),
@@ -549,36 +578,121 @@ class MultiControlledX(ControlledOp):
 
         """
 
-        if len(wires) < 2:
-            raise ValueError(f"Wrong number of wires. {len(wires)} given. Need at least 2.")
-
-        target_wire = wires[-1]
-        control_wires = wires[:-1]
+        target_wire = wires[~0]
+        control_wires = wires[:~0]
 
         if control_values is None:
-            control_values = [True] * len(control_wires)
+            control_values = "1" * len(control_wires)
 
-        work_wires = work_wires or []
         if len(control_wires) > 2 and len(work_wires) == 0:
             raise ValueError(
                 "At least one work wire is required to decompose operation: MultiControlledX"
             )
 
-        flips1 = [qml.PauliX(wires=w) for w, val in zip(control_wires, control_values) if not val]
+        flips1 = [
+            qml.PauliX(control_wires[i]) for i, val in enumerate(control_values) if val == "0"
+        ]
 
         if len(control_wires) == 1:
-            decomp = [qml.CNOT(wires=wires)]
+            decomp = [qml.CNOT(wires=[control_wires[0], target_wire])]
         elif len(control_wires) == 2:
-            decomp = qml.Toffoli.compute_decomposition(wires=wires)
+            decomp = [qml.Toffoli(wires=[*control_wires, target_wire])]
         else:
-            decomp = decompose_mcx(control_wires, target_wire, work_wires)
+            num_work_wires_needed = len(control_wires) - 2
+
+            if len(work_wires) >= num_work_wires_needed:
+                decomp = MultiControlledX._decomposition_with_many_workers(
+                    control_wires, target_wire, work_wires
+                )
+            else:
+                work_wire = work_wires[0]
+                decomp = MultiControlledX._decomposition_with_one_worker(
+                    control_wires, target_wire, work_wire
+                )
 
-        flips2 = [qml.PauliX(wires=w) for w, val in zip(control_wires, control_values) if not val]
+        flips2 = [
+            qml.PauliX(control_wires[i]) for i, val in enumerate(control_values) if val == "0"
+        ]
 
         return flips1 + decomp + flips2
 
-    def decomposition(self):
-        return self.compute_decomposition(self.active_wires, self.work_wires, self.control_values)
+    @staticmethod
+    def _decomposition_with_many_workers(control_wires, target_wire, work_wires):
+        """Decomposes the multi-controlled PauliX gate using the approach in Lemma 7.2 of
+        https://arxiv.org/abs/quant-ph/9503016, which requires a suitably large register of
+        work wires"""
+        num_work_wires_needed = len(control_wires) - 2
+        work_wires = work_wires[:num_work_wires_needed]
+
+        work_wires_reversed = list(reversed(work_wires))
+        control_wires_reversed = list(reversed(control_wires))
+
+        gates = []
+
+        for i in range(len(work_wires)):
+            ctrl1 = control_wires_reversed[i]
+            ctrl2 = work_wires_reversed[i]
+            t = target_wire if i == 0 else work_wires_reversed[i - 1]
+            gates.append(qml.Toffoli(wires=[ctrl1, ctrl2, t]))
+
+        gates.append(qml.Toffoli(wires=[*control_wires[:2], work_wires[0]]))
+
+        for i in reversed(range(len(work_wires))):
+            ctrl1 = control_wires_reversed[i]
+            ctrl2 = work_wires_reversed[i]
+            t = target_wire if i == 0 else work_wires_reversed[i - 1]
+            gates.append(qml.Toffoli(wires=[ctrl1, ctrl2, t]))
+
+        for i in range(len(work_wires) - 1):
+            ctrl1 = control_wires_reversed[i + 1]
+            ctrl2 = work_wires_reversed[i + 1]
+            t = work_wires_reversed[i]
+            gates.append(qml.Toffoli(wires=[ctrl1, ctrl2, t]))
+
+        gates.append(qml.Toffoli(wires=[*control_wires[:2], work_wires[0]]))
+
+        for i in reversed(range(len(work_wires) - 1)):
+            ctrl1 = control_wires_reversed[i + 1]
+            ctrl2 = work_wires_reversed[i + 1]
+            t = work_wires_reversed[i]
+            gates.append(qml.Toffoli(wires=[ctrl1, ctrl2, t]))
+
+        return gates
+
+    @staticmethod
+    def _decomposition_with_one_worker(control_wires, target_wire, work_wire):
+        """Decomposes the multi-controlled PauliX gate using the approach in Lemma 7.3 of
+        https://arxiv.org/abs/quant-ph/9503016, which requires a single work wire"""
+        tot_wires = len(control_wires) + 2
+        partition = int(np.ceil(tot_wires / 2))
+
+        first_part = control_wires[:partition]
+        second_part = control_wires[partition:]
+
+        gates = [
+            MultiControlledX(
+                wires=first_part + work_wire,
+                work_wires=second_part + target_wire,
+            ),
+            MultiControlledX(
+                wires=second_part + work_wire + target_wire,
+                work_wires=first_part,
+            ),
+            MultiControlledX(
+                wires=first_part + work_wire,
+                work_wires=second_part + target_wire,
+            ),
+            MultiControlledX(
+                wires=second_part + work_wire + target_wire,
+                work_wires=first_part,
+            ),
+        ]
+
+        return gates
+
+    @property
+    def is_hermitian(self):
+        return True

pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled_ops.py Show resolved Hide resolved
pennylane/ops/op_math/controlled_ops.py Outdated Show resolved Hide resolved
pennylane/ops/op_math/controlled_decompositions.py Outdated Show resolved Hide resolved
@astralcai astralcai requested a review from timmysilv February 1, 2024 15:44
Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome!

pennylane/ops/op_math/controlled.py Outdated Show resolved Hide resolved
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
@astralcai astralcai enabled auto-merge (squash) February 1, 2024 19:25
@astralcai astralcai merged commit d0c435d into master Feb 1, 2024
35 checks passed
@astralcai astralcai deleted the ctrl-rework branch February 1, 2024 19:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants