Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for device wire handling #42

Merged
merged 4 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* Made plugin device compatible with new PennyLane wire management.
[(#37)](https://github.com/PennyLaneAI/pennylane-cirq/pull/37)
[(#42)](https://github.com/PennyLaneAI/pennylane-cirq/pull/42)

One can now specify any string or number as a custom wire label,
and use these labels to address subsystems on the device:
Expand All @@ -21,7 +22,7 @@

This release contains contributions from (in alphabetical order):

Maria Schuld
Josh Izaac, Nathan Killoran, Maria Schuld

---

Expand Down
53 changes: 40 additions & 13 deletions pennylane_cirq/cirq_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@
~~~~~~~~~~~~
"""
import abc
from collections.abc import Iterable # pylint: disable=no-name-in-module
from collections import OrderedDict

import cirq
import numpy as np
import pennylane as qml
from pennylane import QubitDevice
from pennylane.operation import Operation
from pennylane.wires import Wires

from ._version import __version__
from .cirq_operation import CirqOperation
Expand All @@ -51,12 +55,17 @@ class CirqDevice(QubitDevice, abc.ABC):
or strings (``['ancilla', 'q1', 'q2']``).
shots (int): Number of circuit evaluations/random samples used
to estimate expectation values of observables. Shots need to be >= 1.
qubits (List[cirq.Qubit]): a list of Cirq qubits that are used
as wires. The wire number corresponds to the index in the list.
By default, an array of ``cirq.LineQubit`` instances is created.
qubits (List[cirq.Qubit]): A list of Cirq qubits that are used
as wires. By default, an array of ``cirq.LineQubit`` instances is created.
Wires are mapped to qubits using Cirq's internal mechanism for ordering
qubits. For example, if ``wires=2`` and ``qubits=[q1, q2]``, with
``q1>q2``, then the wire indices 0 and 1 are mapped to q2 and q1, respectively.
If the user provides their own wire labels, e.g., ``wires=["alice", "bob"]``, and the
qubits are the same as the previous example, then "alice" would map to qubit q2
and "bob" would map to qubit q1.
"""

name = "Cirq Abstract PennyLane plugin baseclass"
name = "Cirq Abstract PennyLane plugin base class"
pennylane_requires = ">=0.11.0"
version = __version__
author = "Xanadu Inc"
Expand All @@ -69,23 +78,31 @@ class CirqDevice(QubitDevice, abc.ABC):
short_name = "cirq.base_device"

def __init__(self, wires, shots, analytic, qubits=None):
super().__init__(wires, shots, analytic)

self.circuit = None

device_wires = self.map_wires(self.wires)
if not isinstance(wires, Iterable):
# interpret wires as the number of consecutive wires
wires = range(wires)
num_wires = len(wires)

if qubits:
if wires != len(qubits):
if num_wires != len(qubits):
raise qml.DeviceError(
"The number of given qubits and the specified number of wires have to match. Got {} wires and {} qubits.".format(
wires, len(qubits)
)
)

self.qubits = qubits
else:
self.qubits = [cirq.LineQubit(wire) for wire in device_wires.labels]
qubits = [cirq.LineQubit(idx) for idx in range(num_wires)]

# cirq orders the subsystems based on a total order defined on qubits.
# For consistency, this plugin uses that same total order
self._unsorted_qubits = qubits
self.qubits = sorted(qubits)

super().__init__(wires, shots, analytic)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why super is not called earlier? That would remove a lot of code above, because num_wires can be extracted from len(self.wires) or self.num_wires. One could still raise an error if qubits has a different length.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mariaschuld, here is @co9olguy's original comment: #40 (comment)

I had to move down the superclass initialization here, since it is the thing that creates the wire_map, but the wire_map can't be determined until the cirq qubits have been assigned

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah maybe that is due to other changes made in the original PR? In that case we can merge as is, to not to create conflicts!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define_wire_map is called in the parent class's __init__ method.

We don't need to know the number of qubits (as you say, that can be inferred), but we do need to do all this qubit pre-processing before define_wire_map makes sense to be called


self.circuit = None
self.cirq_device = None

# Add inverse operations
self._inverse_operation_map = {}
Expand Down Expand Up @@ -149,7 +166,10 @@ def reset(self):
# pylint: disable=missing-function-docstring
super().reset()

self.circuit = cirq.Circuit()
if self.cirq_device:
self.circuit = cirq.Circuit(device=self.cirq_device)
else:
self.circuit = cirq.Circuit()

@property
def observables(self):
Expand Down Expand Up @@ -226,3 +246,10 @@ def apply(self, operations, **kwargs):
# Diagonalize the given observables
for operation in rotations:
self._apply_operation(operation)

def define_wire_map(self, wires): # pylint: disable=missing-function-docstring
cirq_order = np.argsort(self._unsorted_qubits)
consecutive_wires = Wires(cirq_order)

wire_map = zip(wires, consecutive_wires)
return OrderedDict(wire_map)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worked really nicely once I figured out that Cirq always uses an ordering internally (the edge case of unordered qubits was never caught before in this plugin)

5 changes: 3 additions & 2 deletions pennylane_cirq/simulator_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ def apply(self, operations, **kwargs):
# pylint: disable=missing-function-docstring
super().apply(operations, **kwargs)

# We apply an identity gate to all wires, otherwise Cirq would ignore
# We apply identity gates to all wires, otherwise Cirq would ignore
co9olguy marked this conversation as resolved.
Show resolved Hide resolved
# wires that are not acted upon
self.circuit.append(cirq.IdentityGate(len(self.qubits))(*self.qubits))
for q in self.qubits:
self.circuit.append(cirq.IdentityGate(1)(q))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand why len(self.qubits) can be replaced by 1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, it looks like you are applying len(self.qubit) single-qubit identity gates, instead of one large identity gate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this is just putting a band-aid over an existing hack. Would be better for us to keep better track of the unused wires in the plugin (which we can now do easily!)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have added a TODO in the codebase


if self.analytic:
self._result = self._simulator.simulate(self.circuit, initial_state=self._initial_state)
Expand Down
78 changes: 72 additions & 6 deletions tests/test_cirq_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
import pennylane as qml
from pennylane.wires import Wires
import pytest
import numpy as np

Expand Down Expand Up @@ -51,8 +52,8 @@ def test_default_init_of_qubits(self):
assert dev.qubits[1] == cirq.LineQubit(1)
assert dev.qubits[2] == cirq.LineQubit(2)

def test_outer_init_of_qubits(self):
"""Tests that giving qubits as parameters to CirqDevice works."""
def test_outer_init_of_qubits_ordered(self):
"""Tests that giving qubits as parameters to CirqDevice works when the qubits are already ordered consistently with Cirq's convention."""

qubits = [
cirq.GridQubit(0, 0),
Expand All @@ -63,10 +64,21 @@ def test_outer_init_of_qubits(self):

dev = CirqDevice(4, 100, False, qubits=qubits)
assert len(dev.qubits) == 4
assert dev.qubits[0] == cirq.GridQubit(0, 0)
assert dev.qubits[1] == cirq.GridQubit(0, 1)
assert dev.qubits[2] == cirq.GridQubit(1, 0)
assert dev.qubits[3] == cirq.GridQubit(1, 1)
assert dev.qubits == qubits

def test_outer_init_of_qubits_unordered(self):
"""Tests that giving qubits as parameters to CirqDevice works when the qubits are not ordered consistently with Cirq's convention."""

qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

dev = CirqDevice(4, 100, False, qubits=qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(qubits)

def test_outer_init_of_qubits_error(self):
"""Tests that giving the wrong number of qubits as parameters to CirqDevice raises an error."""
Expand All @@ -85,6 +97,60 @@ def test_outer_init_of_qubits_error(self):
dev = CirqDevice(3, 100, False, qubits=qubits)


class TestCirqDeviceIntegration:
"""Integration tests for Cirq devices"""

def test_outer_init_of_qubits_with_wire_number(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user provides a number of wires."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

dev = qml.device("cirq.simulator", wires=4, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)

def test_outer_init_of_qubits_with_wire_label_strings(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user also provides custom string wire labels."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

user_labels = ["alice", "bob", "charlie", "david"]
sort_order = [2,0,1,3]

dev = qml.device("cirq.simulator", wires=user_labels, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)
assert all(dev.map_wires(Wires(label)) == Wires(idx) for label, idx in zip(user_labels, sort_order))

def test_outer_init_of_qubits_with_wire_label_ints(self):
"""Tests that giving qubits as parameters to CirqDevice works when the user also provides custom integer wire labels."""

unordered_qubits = [
cirq.GridQubit(0, 1),
cirq.GridQubit(1, 0),
cirq.GridQubit(0, 0),
cirq.GridQubit(1, 1),
]

user_labels = [-1,1,66,0]
sort_order = [2,0,1,3]

dev = qml.device("cirq.simulator", wires=user_labels, qubits=unordered_qubits)
assert len(dev.qubits) == 4
assert dev.qubits == sorted(unordered_qubits)
assert all(dev.map_wires(Wires(label)) == Wires(idx) for label, idx in zip(user_labels, sort_order))


@pytest.fixture(scope="function")
def cirq_device_1_wire(shots):
"""A mock instance of the abstract Device class"""
Expand Down