Skip to content

Commit

Permalink
Fixes for device wire handling (#42)
Browse files Browse the repository at this point in the history
* Fixes for device wire handling

* update changelog

* update changelog

* Update pennylane_cirq/simulator_device.py

Co-authored-by: Nathan Killoran <co9olguy@users.noreply.github.com>
  • Loading branch information
josh146 and co9olguy authored Aug 19, 2020
1 parent 2cbdb37 commit 7cfb6fc
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 22 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* Made plugin device compatible with new PennyLane wire management.
[(#37)](https://github.com/PennyLaneAI/pennylane-cirq/pull/37)
[(#42)](https://github.com/PennyLaneAI/pennylane-cirq/pull/42)

One can now specify any string or number as a custom wire label,
and use these labels to address subsystems on the device:
Expand All @@ -21,7 +22,7 @@

This release contains contributions from (in alphabetical order):

Maria Schuld
Josh Izaac, Nathan Killoran, Maria Schuld

---

Expand Down
53 changes: 40 additions & 13 deletions pennylane_cirq/cirq_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@
~~~~~~~~~~~~
"""
import abc
from collections.abc import Iterable # pylint: disable=no-name-in-module
from collections import OrderedDict

import cirq
import numpy as np
import pennylane as qml
from pennylane import QubitDevice
from pennylane.operation import Operation
from pennylane.wires import Wires

from ._version import __version__
from .cirq_operation import CirqOperation
Expand All @@ -51,12 +55,17 @@ class CirqDevice(QubitDevice, abc.ABC):
or strings (``['ancilla', 'q1', 'q2']``).
shots (int): Number of circuit evaluations/random samples used
to estimate expectation values of observables. Shots need to be >= 1.
qubits (List[cirq.Qubit]): a list of Cirq qubits that are used
as wires. The wire number corresponds to the index in the list.
By default, an array of ``cirq.LineQubit`` instances is created.
qubits (List[cirq.Qubit]): A list of Cirq qubits that are used
as wires. By default, an array of ``cirq.LineQubit`` instances is created.
Wires are mapped to qubits using Cirq's internal mechanism for ordering
qubits. For example, if ``wires=2`` and ``qubits=[q1, q2]``, with
``q1>q2``, then the wire indices 0 and 1 are mapped to q2 and q1, respectively.
If the user provides their own wire labels, e.g., ``wires=["alice", "bob"]``, and the
qubits are the same as the previous example, then "alice" would map to qubit q2
and "bob" would map to qubit q1.
"""

name = "Cirq Abstract PennyLane plugin baseclass"
name = "Cirq Abstract PennyLane plugin base class"
pennylane_requires = ">=0.11.0"
version = __version__
author = "Xanadu Inc"
Expand All @@ -69,23 +78,31 @@ class CirqDevice(QubitDevice, abc.ABC):
short_name = "cirq.base_device"

def __init__(self, wires, shots, analytic, qubits=None):
super().__init__(wires, shots, analytic)

self.circuit = None

device_wires = self.map_wires(self.wires)
if not isinstance(wires, Iterable):
# interpret wires as the number of consecutive wires
wires = range(wires)
num_wires = len(wires)

if qubits:
if wires != len(qubits):
if num_wires != len(qubits):
raise qml.DeviceError(
"The number of given qubits and the specified number of wires have to match. Got {} wires and {} qubits.".format(
wires, len(qubits)
)
)

self.qubits = qubits
else:
self.qubits = [cirq.LineQubit(wire) for wire in device_wires.labels]
qubits = [cirq.LineQubit(idx) for idx in range(num_wires)]

# cirq orders the subsystems based on a total order defined on qubits.
# For consistency, this plugin uses that same total order
self._unsorted_qubits = qubits
self.qubits = sorted(qubits)

super().__init__(wires, shots, analytic)

self.circuit = None
self.cirq_device = None

# Add inverse operations
self._inverse_operation_map = {}
Expand Down Expand Up @@ -149,7 +166,10 @@ def reset(self):
# pylint: disable=missing-function-docstring
super().reset()

self.circuit = cirq.Circuit()
if self.cirq_device:
self.circuit = cirq.Circuit(device=self.cirq_device)
else:
self.circuit = cirq.Circuit()

@property
def observables(self):
Expand Down Expand Up @@ -226,3 +246,10 @@ def apply(self, operations, **kwargs):
# Diagonalize the given observables
for operation in rotations:
self._apply_operation(operation)

def define_wire_map(self, wires): # pylint: disable=missing-function-docstring
cirq_order = np.argsort(self._unsorted_qubits)
consecutive_wires = Wires(cirq_order)

wire_map = zip(wires, consecutive_wires)
return OrderedDict(wire_map)
6 changes: 4 additions & 2 deletions pennylane_cirq/simulator_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def apply(self, operations, **kwargs):
# pylint: disable=missing-function-docstring
super().apply(operations, **kwargs)

# We apply an identity gate to all wires, otherwise Cirq would ignore
# TODO: remove the need for this hack by keeping better track of unused wires
# We apply identity gates to all wires, otherwise Cirq would ignore
# wires that are not acted upon
self.circuit.append(cirq.IdentityGate(len(self.qubits))(*self.qubits))
for q in self.qubits:
self.circuit.append(cirq.IdentityGate(1)(q))

if self.analytic:
self._result = self._simulator.simulate(self.circuit, initial_state=self._initial_state)
Expand Down
78 changes: 72 additions & 6 deletions tests/test_cirq_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
import pennylane as qml
from pennylane.wires import Wires
import pytest
import numpy as np

Expand Down Expand Up @@ -51,8 +52,8 @@ def test_default_init_of_qubits(self):
assert dev.qubits[1] == cirq.LineQubit(1)
assert dev.qubits[2] == cirq.LineQubit(2)

def test_outer_init_of_qubits(self):
"""Tests that giving qubits as parameters to CirqDevice works."""
def test_outer_init_of_qubits_ordered(self):
"""Tests that giving qubits as parameters to CirqDevice works when the qubits are already ordered consistently with Cirq's convention."""

qubits = [
cirq.GridQubit(0, 0),
Expand All @@ -63,10 +64,21 @@ def test_outer_init_of_qubits(self):

dev = CirqDevice(4, 100, False, qubits=qubits)
assert len(dev.qubits) == 4
assert dev.qubits[0] == cirq.GridQubit(0, 0)
assert dev.qubits[1] == cirq.GridQubit(0, 1)
assert dev.qubits[2] == cirq.GridQubit(1, 0)
assert dev.qubits[3] == cirq.GridQubit(1, 1)
assert dev.qubits == qubits

def test_outer_init_of_qubits_unordered(self):
"""Tests that giving qubits as parameters to CirqDevice works when the qubits are not ordered consistently with Cirq's convention."""

qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

dev = CirqDevice(4, 100, False, qubits=qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(qubits)

def test_outer_init_of_qubits_error(self):
"""Tests that giving the wrong number of qubits as parameters to CirqDevice raises an error."""
Expand All @@ -85,6 +97,60 @@ def test_outer_init_of_qubits_error(self):
dev = CirqDevice(3, 100, False, qubits=qubits)


class TestCirqDeviceIntegration:
"""Integration tests for Cirq devices"""

def test_outer_init_of_qubits_with_wire_number(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user provides a number of wires."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

dev = qml.device("cirq.simulator", wires=4, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)

def test_outer_init_of_qubits_with_wire_label_strings(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user also provides custom string wire labels."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

user_labels = ["alice", "bob", "charlie", "david"]
sort_order = [2,0,1,3]

dev = qml.device("cirq.simulator", wires=user_labels, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)
assert all(dev.map_wires(Wires(label)) == Wires(idx) for label, idx in zip(user_labels, sort_order))

def test_outer_init_of_qubits_with_wire_label_ints(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user also provides custom integer wire labels."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

user_labels = [-1,1,66,0]
sort_order = [2,0,1,3]

dev = qml.device("cirq.simulator", wires=user_labels, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)
assert all(dev.map_wires(Wires(label)) == Wires(idx) for label, idx in zip(user_labels, sort_order))


@pytest.fixture(scope="function")
def cirq_device_1_wire(shots):
"""A mock instance of the abstract Device class"""
Expand Down

0 comments on commit 7cfb6fc

Please sign in to comment.