-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes for device wire handling #42
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,11 +32,15 @@ | |
~~~~~~~~~~~~ | ||
""" | ||
import abc | ||
from collections.abc import Iterable # pylint: disable=no-name-in-module | ||
from collections import OrderedDict | ||
|
||
import cirq | ||
import numpy as np | ||
import pennylane as qml | ||
from pennylane import QubitDevice | ||
from pennylane.operation import Operation | ||
from pennylane.wires import Wires | ||
|
||
from ._version import __version__ | ||
from .cirq_operation import CirqOperation | ||
|
@@ -51,12 +55,17 @@ class CirqDevice(QubitDevice, abc.ABC): | |
or strings (``['ancilla', 'q1', 'q2']``). | ||
shots (int): Number of circuit evaluations/random samples used | ||
to estimate expectation values of observables. Shots need to be >= 1. | ||
qubits (List[cirq.Qubit]): a list of Cirq qubits that are used | ||
as wires. The wire number corresponds to the index in the list. | ||
By default, an array of ``cirq.LineQubit`` instances is created. | ||
qubits (List[cirq.Qubit]): A list of Cirq qubits that are used | ||
as wires. By default, an array of ``cirq.LineQubit`` instances is created. | ||
Wires are mapped to qubits using Cirq's internal mechanism for ordering | ||
qubits. For example, if ``wires=2`` and ``qubits=[q1, q2]``, with | ||
``q1>q2``, then the wire indices 0 and 1 are mapped to q2 and q1, respectively. | ||
If the user provides their own wire labels, e.g., ``wires=["alice", "bob"]``, and the | ||
qubits are the same as the previous example, then "alice" would map to qubit q2 | ||
and "bob" would map to qubit q1. | ||
""" | ||
|
||
name = "Cirq Abstract PennyLane plugin baseclass" | ||
name = "Cirq Abstract PennyLane plugin base class" | ||
pennylane_requires = ">=0.11.0" | ||
version = __version__ | ||
author = "Xanadu Inc" | ||
|
@@ -69,23 +78,31 @@ class CirqDevice(QubitDevice, abc.ABC): | |
short_name = "cirq.base_device" | ||
|
||
def __init__(self, wires, shots, analytic, qubits=None): | ||
super().__init__(wires, shots, analytic) | ||
|
||
self.circuit = None | ||
|
||
device_wires = self.map_wires(self.wires) | ||
if not isinstance(wires, Iterable): | ||
# interpret wires as the number of consecutive wires | ||
wires = range(wires) | ||
num_wires = len(wires) | ||
|
||
if qubits: | ||
if wires != len(qubits): | ||
if num_wires != len(qubits): | ||
raise qml.DeviceError( | ||
"The number of given qubits and the specified number of wires have to match. Got {} wires and {} qubits.".format( | ||
wires, len(qubits) | ||
) | ||
) | ||
|
||
self.qubits = qubits | ||
else: | ||
self.qubits = [cirq.LineQubit(wire) for wire in device_wires.labels] | ||
qubits = [cirq.LineQubit(idx) for idx in range(num_wires)] | ||
|
||
# cirq orders the subsystems based on a total order defined on qubits. | ||
# For consistency, this plugin uses that same total order | ||
self._unsorted_qubits = qubits | ||
self.qubits = sorted(qubits) | ||
|
||
super().__init__(wires, shots, analytic) | ||
|
||
self.circuit = None | ||
self.cirq_device = None | ||
|
||
# Add inverse operations | ||
self._inverse_operation_map = {} | ||
|
@@ -149,7 +166,10 @@ def reset(self): | |
# pylint: disable=missing-function-docstring | ||
super().reset() | ||
|
||
self.circuit = cirq.Circuit() | ||
if self.cirq_device: | ||
self.circuit = cirq.Circuit(device=self.cirq_device) | ||
else: | ||
self.circuit = cirq.Circuit() | ||
|
||
@property | ||
def observables(self): | ||
|
@@ -226,3 +246,10 @@ def apply(self, operations, **kwargs): | |
# Diagonalize the given observables | ||
for operation in rotations: | ||
self._apply_operation(operation) | ||
|
||
def define_wire_map(self, wires): # pylint: disable=missing-function-docstring | ||
cirq_order = np.argsort(self._unsorted_qubits) | ||
consecutive_wires = Wires(cirq_order) | ||
|
||
wire_map = zip(wires, consecutive_wires) | ||
return OrderedDict(wire_map) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worked really nicely once I figured out that Cirq always uses an ordering internally (the edge case of unordered qubits was never caught before in this plugin) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,9 +132,11 @@ def apply(self, operations, **kwargs): | |
# pylint: disable=missing-function-docstring | ||
super().apply(operations, **kwargs) | ||
|
||
# We apply an identity gate to all wires, otherwise Cirq would ignore | ||
# TODO: remove the need for this hack by keeping better track of unused wires | ||
# We apply identity gates to all wires, otherwise Cirq would ignore | ||
co9olguy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# wires that are not acted upon | ||
self.circuit.append(cirq.IdentityGate(len(self.qubits))(*self.qubits)) | ||
for q in self.qubits: | ||
self.circuit.append(cirq.IdentityGate(1)(q)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, it looks like you are applying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but this is just putting a band-aid over an existing hack. Would be better for us to keep better track of the unused wires in the plugin (which we can now do easily!) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have added a TODO in the codebase |
||
|
||
if self.analytic: | ||
self._result = self._simulator.simulate(self.circuit, initial_state=self._initial_state) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why
super
is not called earlier? That would remove a lot of code above, becausenum_wires
can be extracted fromlen(self.wires)
orself.num_wires
. One could still raise an error ifqubits
has a different length.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mariaschuld, here is @co9olguy's original comment: #40 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah maybe that is due to other changes made in the original PR? In that case we can merge as is, to not to create conflicts!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
define_wire_map
is called in the parent class's__init__
method.We don't need to know the number of qubits (as you say, that can be inferred), but we do need to do all this qubit pre-processing before
define_wire_map
makes sense to be called