Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri committed Oct 24, 2024
1 parent a8277a6 commit 8f17426
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
57 changes: 40 additions & 17 deletions cirq-core/cirq/transformers/insertion_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@

"""Transformer that sorts commuting operations in increasing order of their `.qubits` tuple."""

from typing import Optional, TYPE_CHECKING, List, Tuple
from typing import Optional, TYPE_CHECKING, List, Tuple, FrozenSet, Union

from cirq import protocols, circuits
from cirq.transformers import transformer_api

if TYPE_CHECKING:
import cirq

Check warning on line 23 in cirq-core/cirq/transformers/insertion_sort.py

View check run for this annotation

Codecov / codecov/patch

cirq-core/cirq/transformers/insertion_sort.py#L23

Added line #L23 was not covered by tests


def _id(op: 'cirq.Operation') -> Tuple['cirq.Qid', ...]:
return tuple(sorted(op.qubits))
_MAX_QUBIT_COUNT_FOR_MASK = 64


@transformer_api.transformer(add_deep_support=True)
Expand All @@ -39,21 +37,46 @@ def insertion_sort_transformer(
circuit: input circuit.
context: optional TransformerContext (not used),
"""
operations_with_key: List[Tuple[Tuple['cirq.Qid', ...], 'cirq.Operation']] = [
(_id(op), op) for op in circuit.all_operations()
]
for i in range(len(operations_with_key)):
j = i
all_operations = [*circuit.all_operations()]
relative_order = {
qs: i for i, qs in enumerate(sorted(tuple(sorted(op.qubits)) for op in all_operations))
}
if len(circuit.all_qubits()) <= _MAX_QUBIT_COUNT_FOR_MASK:
# use bitmasks.
q_index = {q: i for i, q in enumerate(circuit.all_qubits())}

def _msk(qs: Tuple['cirq.Qid', ...]) -> int:
msk = 0
for q in qs:
msk |= 1 << q_index[q]
return msk

operations_with_info: Union[
List[Tuple['cirq.Operation', int, int]], List[Tuple['cirq.Operation', int, FrozenSet]]
] = [
(op, relative_order[tuple(sorted(op.qubits))], _msk(op.qubits)) for op in all_operations
]
else:
# use sets.
operations_with_info = [
(op, relative_order[tuple(sorted(op.qubits))], frozenset(op.qubits))
for op in all_operations
]
sorted_info: Union[
List[Tuple['cirq.Operation', int, int]], List[Tuple['cirq.Operation', int, FrozenSet]]
] = []
for i in range(len(all_operations)):
j = len(sorted_info)
while (
j
and operations_with_key[j][0] < operations_with_key[j - 1][0]
and protocols.commutes(
operations_with_key[j][1], operations_with_key[j - 1][1], default=False
and operations_with_info[i][1] < sorted_info[j - 1][1]
and (
not (operations_with_info[i][2] & sorted_info[j - 1][2]) # type: ignore[operator]
or protocols.commutes(
operations_with_info[i][0], sorted_info[j - 1][0], default=False
)
)
):
operations_with_key[j], operations_with_key[j - 1] = (
operations_with_key[j - 1],
operations_with_key[j],
)
j -= 1
return circuits.Circuit(op for _, op in operations_with_key)
sorted_info.insert(j, operations_with_info[i]) # type: ignore[arg-type]
return circuits.Circuit(op for op, _, _ in sorted_info)
6 changes: 5 additions & 1 deletion cirq-core/cirq/transformers/insertion_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import cirq
import cirq.transformers
from cirq.transformers import insertion_sort


def test_insertion_sort():
@pytest.mark.parametrize('qubit_threshold', [1, 64])
def test_insertion_sort(qubit_threshold):
insertion_sort._MAX_QUBIT_COUNT_FOR_MASK = qubit_threshold
c = cirq.Circuit(
cirq.CZ(cirq.q(2), cirq.q(1)),
cirq.CZ(cirq.q(2), cirq.q(4)),
Expand Down

0 comments on commit 8f17426

Please sign in to comment.