Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt commutation checker to abstract circuits #11948

Merged
merged 6 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 85 additions & 50 deletions qiskit/circuit/commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
from typing import List, Union
import numpy as np

from qiskit import QiskitError
from qiskit.circuit import Qubit
from qiskit.circuit.operation import Operation
from qiskit.circuit.controlflow import ControlFlowOp
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit.quantum_info.operators import Operator

_skipped_op_names = {"measure", "reset", "delay", "initialize"}
_no_cache_op_names = {"annotated"}


@lru_cache(maxsize=None)
def _identity_op(num_qubits):
Expand Down Expand Up @@ -94,8 +98,11 @@ def commute(
)
first_op, first_qargs, _ = first_op_tuple
second_op, second_qargs, _ = second_op_tuple
first_params = first_op.params
second_params = second_op.params

skip_cache = first_op.name in _no_cache_op_names or second_op.name in _no_cache_op_names

if skip_cache:
return _commute_matmul(first_op, first_qargs, second_op, second_qargs)

commutation_lookup = self.check_commutation_entries(
first_op, first_qargs, second_op, second_qargs
Expand All @@ -113,6 +120,8 @@ def commute(
if self._current_cache_entries >= self._cache_max_entries:
self.clear_cached_commutations()

first_params = getattr(first_op, "params", [])
second_params = getattr(second_op, "params", [])
if len(first_params) > 0 or len(second_params) > 0:
self._cached_commutations.setdefault((first_op.name, second_op.name), {}).setdefault(
_get_relative_placement(first_qargs, second_qargs), {}
Expand Down Expand Up @@ -184,7 +193,11 @@ def check_commutation_entries(


def _hashable_parameters(params):
"""Convert the parameters of a gate into a hashable format for lookup in a dictionary."""
"""Convert the parameters of a gate into a hashable format for lookup in a dictionary.

This aims to be fast in common cases, and is not intended to work outside of the lifetime of a
single commutation pass; it does not handle mutable state correctly if the state is actually
changed."""
try:
hash(params)
return params
Expand All @@ -201,7 +214,53 @@ def _hashable_parameters(params):
return ("fallback", str(params))


_skipped_op_names = {"measure", "reset", "delay"}
def is_commutation_supported(op):
"""
Filter operations whose commutation is not supported due to bugs in transpiler passes invoking
commutation analysis.
Args:
op (Operation): operation to be checked for commutation relation
Return:
True if determining the commutation of op is currently supported
"""
# Bug in CommutativeCancellation, e.g. see gh-8553
if getattr(op, "condition", False):
return False

# Commutation of ControlFlow gates also not supported yet. This may be pending a control flow graph.
if op.name in CONTROL_FLOW_OP_NAMES:
return False

return True


def is_commutation_skipped(op, qargs, max_num_qubits):
"""
Filter operations whose commutation will not be determined.
Args:
op (Operation): operation to be checked for commutation relation
qargs (List): operation qubits
max_num_qubits (int): the maximum number of qubits to consider, the check may be skipped if
the number of qubits for either operation exceeds this amount.
Return:
True if determining the commutation of op is currently not supported
"""
if (
len(qargs) > max_num_qubits
or getattr(op, "_directive", False)
or op.name in _skipped_op_names
):
return True

if getattr(op, "is_parameterized", False) and op.is_parameterized():
return True

# we can proceed if op has defined: to_operator, to_matrix and __array__, or if its definition can be
# recursively resolved by operations that have a matrix. We check this by constructing an Operator.
if (hasattr(op, "to_matrix") and hasattr(op, "__array__")) or hasattr(op, "to_operator"):
return False

return False


def _commutation_precheck(
Expand All @@ -213,43 +272,14 @@ def _commutation_precheck(
cargs2: List,
max_num_qubits,
):
# pylint: disable=too-many-return-statements

# We don't support commutation of conditional gates for now due to bugs in
# CommutativeCancellation. See gh-8553.
if getattr(op1, "condition", None) is not None or getattr(op2, "condition", None) is not None:
if not is_commutation_supported(op1) or not is_commutation_supported(op2):
return False

# Commutation of ControlFlow gates also not supported yet. This may be
# pending a control flow graph.
if isinstance(op1, ControlFlowOp) or isinstance(op2, ControlFlowOp):
return False

# These lines are adapted from dag_dependency and say that two gates over
# different quantum and classical bits necessarily commute. This is more
# permissive that the check from commutation_analysis, as for example it
# allows to commute X(1) and Measure(0, 0).
# Presumably this check was not present in commutation_analysis as
# it was only called on pairs of connected nodes from DagCircuit.
intersection_q = set(qargs1).intersection(set(qargs2))
intersection_c = set(cargs1).intersection(set(cargs2))
if not (intersection_q or intersection_c):
if set(qargs1).isdisjoint(qargs2) and set(cargs1).isdisjoint(cargs2):
return True

# Skip the check if the number of qubits for either operation is too large
if len(qargs1) > max_num_qubits or len(qargs2) > max_num_qubits:
return False

# These lines are adapted from commutation_analysis, which is more restrictive than the
# check from dag_dependency when considering nodes with "_directive". It would be nice to
# think which optimizations from dag_dependency can indeed be used.
if op1.name in _skipped_op_names or op2.name in _skipped_op_names:
return False

if getattr(op1, "_directive", False) or getattr(op2, "_directive", False):
return False
if (getattr(op1, "is_parameterized", False) and op1.is_parameterized()) or (
getattr(op2, "is_parameterized", False) and op2.is_parameterized()
if is_commutation_skipped(op1, qargs1, max_num_qubits) or is_commutation_skipped(
op2, qargs2, max_num_qubits
):
return False

Expand All @@ -264,13 +294,11 @@ def _get_relative_placement(first_qargs: List[Qubit], second_qargs: List[Qubit])
second_qargs (DAGOpNode): second gate

Return:
A tuple that describes the relative qubit placement. The relative placement is defined by the
gate qubit arrangements as q2^{-1}[q1[i]] where q1[i] is the ith qubit of the first gate and
q2^{-1}[q] returns the qubit index of qubit q in the second gate (possibly 'None'). E.g.
A tuple that describes the relative qubit placement: E.g.
_get_relative_placement(CX(0, 1), CX(1, 2)) would return (None, 0) as there is no overlap on
the first qubit of the first gate but there is an overlap on the second qubit of the first gate,
i.e. qubit 0 of the second gate. Likewise, _get_relative_placement(CX(1, 2), CX(0, 1)) would
return (1, None)
i.e. qubit 0 of the second gate. Likewise,
_get_relative_placement(CX(1, 2), CX(0, 1)) would return (1, None)
"""
qubits_g2 = {q_g1: i_g1 for i_g1, q_g1 in enumerate(second_qargs)}
return tuple(qubits_g2.get(q_g0, None) for q_g0 in first_qargs)
Expand Down Expand Up @@ -355,8 +383,10 @@ def _query_commutation(
# if we have another dict in commutation_after_placement, commutation depends on params
if isinstance(commutation_after_placement, dict):
# Param commutation entry exists and must be a dict
first_params = getattr(first_op, "params", [])
second_params = getattr(second_op, "params", [])
return commutation_after_placement.get(
(_hashable_parameters(first_op.params), _hashable_parameters(second_op.params)),
(_hashable_parameters(first_params), _hashable_parameters(second_params)),
None,
)
else:
Expand All @@ -379,12 +409,17 @@ def _commute_matmul(
first_qarg = tuple(qarg[q] for q in first_qargs)
second_qarg = tuple(qarg[q] for q in second_qargs)

operator_1 = Operator(
first_ops, input_dims=(2,) * len(first_qarg), output_dims=(2,) * len(first_qarg)
)
operator_2 = Operator(
second_op, input_dims=(2,) * len(second_qarg), output_dims=(2,) * len(second_qarg)
)
# try to generate an Operator out of op, if this succeeds we can determine commutativity, otherwise
# return false
try:
operator_1 = Operator(
first_ops, input_dims=(2,) * len(first_qarg), output_dims=(2,) * len(first_qarg)
)
operator_2 = Operator(
second_op, input_dims=(2,) * len(second_qarg), output_dims=(2,) * len(second_qarg)
)
except QiskitError:
return False

if first_qarg == second_qarg:
# Use full composition if possible to get the fastest matmul paths.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
features:
- |
Extended the commutation analysis performed by :class:`.CommutationChecker` to only operate on
hardware circuits to also work with abstract circuits, i.e. each operation in
the input quantum circuit is now checked for its matrix representation before proceeding to the
analysis. In addition, the operation is now checked for its ability to be cached in the session
commutation library. For example, this now enables computing whether :class:`.AnnotatedOperation`
commute. This enables transpiler passes that rely on :class:`.CommutationChecker` internally,
such as :class:`.CommutativeCancellation`, during earlier stages of a default transpilation pipeline
(prior to basis translation).

34 changes: 33 additions & 1 deletion test/python/circuit/test_commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import numpy as np

from qiskit import ClassicalRegister
from qiskit.circuit import QuantumRegister, Parameter, Qubit
from qiskit.circuit import (
QuantumRegister,
Parameter,
Qubit,
AnnotatedOperation,
InverseModifier,
ControlModifier,
)
from qiskit.circuit.commutation_library import SessionCommutationChecker as scc

from qiskit.circuit.library import (
Expand All @@ -31,6 +38,7 @@
Barrier,
Reset,
LinearFunction,
SGate,
)
from test import QiskitTestCase # pylint: disable=wrong-import-order

Expand Down Expand Up @@ -384,6 +392,30 @@ def test_complex_gates(self):
res = scc.commute(lf3, [0, 1, 2], [], lf4, [0, 1, 2], [])
self.assertTrue(res)

def test_equal_annotated_operations_commute(self):
"""Check commutativity involving the same annotated operation."""
op1 = AnnotatedOperation(SGate(), [InverseModifier(), ControlModifier(1)])
op2 = AnnotatedOperation(SGate(), [InverseModifier(), ControlModifier(1)])
# the same, so true
self.assertTrue(scc.commute(op1, [0, 1], [], op2, [0, 1], []))

def test_annotated_operations_commute_with_unannotated(self):
"""Check commutativity involving annotated operations and unannotated operations."""
op1 = AnnotatedOperation(SGate(), [InverseModifier(), ControlModifier(1)])
op2 = AnnotatedOperation(ZGate(), [InverseModifier()])
op3 = ZGate()
# all true
self.assertTrue(scc.commute(op1, [0, 1], [], op2, [1], []))
self.assertTrue(scc.commute(op1, [0, 1], [], op3, [1], []))
self.assertTrue(scc.commute(op2, [1], [], op3, [1], []))

def test_annotated_operations_no_commute(self):
"""Check non-commutativity involving annotated operations."""
op1 = AnnotatedOperation(XGate(), [InverseModifier(), ControlModifier(1)])
op2 = AnnotatedOperation(XGate(), [InverseModifier()])
# false
self.assertFalse(scc.commute(op1, [0, 1], [], op2, [0], []))

def test_c7x_gate(self):
"""Test wide gate works correctly."""
qargs = [Qubit() for _ in [None] * 8]
Expand Down
Loading