-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix issue when using a wire subset with BasisState
#61
Merged
Merged
Changes from 11 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
632c71c
add expand_state
thisac 845ff57
move basis state checks
thisac e19b129
update changelog
thisac 1a7b49c
Merge branch 'master' into fix-using-wire-subset
thisac 17c7e75
update apply state vector
thisac 53d18e8
Merge branch 'fix-using-wire-subset' of github.com:PennyLaneAI/pennyl…
thisac 4fb594d
add tests
thisac ff83b4d
add expand state test
thisac d5f2826
remove try-except
thisac 52269b6
fix single-qubit var tests
thisac dc88464
remove import
thisac ff90b90
apply changes from code review
thisac File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,8 @@ | |
---- | ||
""" | ||
import math | ||
import itertools as it | ||
|
||
import cirq | ||
import numpy as np | ||
import pennylane as qml | ||
|
@@ -81,26 +83,47 @@ def _apply_basis_state(self, basis_state_operation): | |
if not self.shots is None: | ||
raise qml.DeviceError("The operation BasisState is only supported in analytic mode.") | ||
|
||
basis_state_array = np.array(basis_state_operation.parameters[0]) | ||
wires = basis_state_operation.wires | ||
|
||
if len(basis_state_array) != len(self.qubits): | ||
if len(basis_state_operation.parameters[0]) != len(wires): | ||
raise qml.DeviceError( | ||
"For BasisState, the state has to be specified for the correct number of qubits. Got a state for {} qubits, expected {}.".format( | ||
len(basis_state_array), len(self.qubits) | ||
len(basis_state_operation.parameters[0]), len(self.qubits) | ||
) | ||
) | ||
|
||
if not np.all(np.isin(basis_state_array, np.array([0, 1]))): | ||
if not np.all(np.isin(basis_state_operation.parameters[0], np.array([0, 1]))): | ||
raise qml.DeviceError( | ||
"Argument for BasisState can only contain 0 and 1. Got {}".format( | ||
basis_state_operation.parameters[0] | ||
) | ||
) | ||
|
||
# expand basis state to device wires | ||
basis_state_array = np.zeros(self.num_wires, dtype=int) | ||
basis_state_array[wires] = basis_state_operation.parameters[0] | ||
|
||
self._initial_state = np.zeros(2 ** len(self.qubits), dtype=np.complex64) | ||
basis_state_idx = np.sum(2 ** np.argwhere(np.flip(basis_state_array) == 1)) | ||
self._initial_state[basis_state_idx] = 1.0 | ||
|
||
def _expand_state(self, state_vector, wires): | ||
"""Expands state vector to more wires""" | ||
basis_states = np.array(list(it.product([0, 1], repeat=len(wires)))) | ||
|
||
# get basis states to alter on full set of qubits | ||
unravelled_indices = np.zeros((2 ** len(wires), self.num_wires), dtype=int) | ||
unravelled_indices[:, wires] = basis_states | ||
|
||
# get indices for which the state is changed to input state vector elements | ||
ravelled_indices = np.ravel_multi_index(unravelled_indices.T, [2] * self.num_wires) | ||
|
||
state_vector = self._scatter(ravelled_indices, state_vector, [2 ** self.num_wires]) | ||
state_vector = self._reshape(state_vector, [2] * self.num_wires) | ||
state_vector = self._asarray(state_vector, dtype=self.C_DTYPE) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here as over on the Qulacs PR, recommend changing this to simply use native NumPy functions :) |
||
return state_vector.flatten() | ||
|
||
def _apply_qubit_state_vector(self, qubit_state_vector_operation): | ||
# pylint: disable=missing-function-docstring | ||
if not self.shots is None: | ||
|
@@ -109,11 +132,15 @@ def _apply_qubit_state_vector(self, qubit_state_vector_operation): | |
) | ||
|
||
state_vector = np.array(qubit_state_vector_operation.parameters[0], dtype=np.complex64) | ||
wires = self.map_wires(qubit_state_vector_operation.wires) | ||
|
||
if len(wires) != self.num_wires or sorted(wires) != wires.tolist(): | ||
state_vector = self._expand_state(state_vector, wires) | ||
|
||
if len(state_vector) != 2 ** len(self.qubits): | ||
raise qml.DeviceError( | ||
"For QubitStateVector, the state has to be specified for the correct number of qubits. Got a state of length {}, expected {}.".format( | ||
len(state_vector), 2 ** len(self.qubits) | ||
len(state_vector), 2 ** len(wires) | ||
) | ||
) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,7 @@ def test_custom_simulator(self): | |
def circuit(): | ||
qml.PauliX(0) | ||
return qml.expval(qml.PauliX(0)) | ||
|
||
assert circuit() == 0.0 | ||
|
||
|
||
|
@@ -485,6 +485,27 @@ def test_qubit_state_vector_not_at_beginning_error(self, simulator_device_1_wire | |
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"state, device_wires, op_wires, expected", | ||
[ | ||
(np.array([1, 0]), 2, [0], [1, 0, 0, 0]), | ||
(np.array([0, 1]), 2, [0], [0, 0, 1, 0]), | ||
(np.array([1, 1]) / np.sqrt(2), 2, [1], np.array([1, 1, 0, 0]) / np.sqrt(2)), | ||
(np.array([1, 1]) / np.sqrt(2), 3, [0], np.array([1, 0, 0, 0, 1, 0, 0, 0]) / np.sqrt(2)), | ||
(np.array([1, 2, 3, 4]) / np.sqrt(48), 3, [0, 1], np.array([1, 0, 2, 0, 3, 0, 4, 0]) / np.sqrt(48)), | ||
(np.array([1, 2, 3, 4]) / np.sqrt(48), 3, [1, 0], np.array([1, 0, 3, 0, 2, 0, 4, 0]) / np.sqrt(48)), | ||
(np.array([1, 2, 3, 4]) / np.sqrt(48), 3, [0, 2], np.array([1, 2, 0, 0, 3, 4, 0, 0]) / np.sqrt(48)), | ||
(np.array([1, 2, 3, 4]) / np.sqrt(48), 3, [1, 2], np.array([1, 2, 3, 4, 0, 0, 0, 0]) / np.sqrt(48)), | ||
], | ||
) | ||
@pytest.mark.parametrize("shots", [None]) | ||
def test_expand_state(state, op_wires, device_wires, expected, tol): | ||
"""Test that the expand_state method works as expected.""" | ||
dev = SimulatorDevice(device_wires) | ||
res = dev._expand_state(state, op_wires) | ||
|
||
assert np.allclose(res, expected, **tol) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great tests! |
||
|
||
@pytest.mark.parametrize("shots", [1000]) | ||
class TestStatePreparationErrorsNonAnalytic: | ||
"""Tests state preparation errors that occur for non-analytic devices.""" | ||
|
@@ -690,7 +711,7 @@ def test_var_single_wire_no_parameters( | |
|
||
simulator_device_1_wire.reset() | ||
simulator_device_1_wire.apply( | ||
[qml.QubitStateVector(np.array(input), wires=[0, 1])], | ||
[qml.QubitStateVector(np.array(input), wires=[0])], | ||
rotations=op.diagonalizing_gates(), | ||
) | ||
|
||
|
@@ -721,7 +742,7 @@ def test_var_single_wire_with_parameters( | |
|
||
simulator_device_1_wire.reset() | ||
simulator_device_1_wire.apply( | ||
[qml.QubitStateVector(np.array(input), wires=[0, 1])], | ||
[qml.QubitStateVector(np.array(input), wires=[0])], | ||
rotations=op.diagonalizing_gates(), | ||
) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, this is a permutation of converting from base2 to base10 that I hadn't seen before!
Interestingly, however, this seems to be a rare case where the native Python approach is still faster:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting! I feel like the Python solution is somewhat cleaner. The time spent there is almost all in the
"".join
while the NumPy solution spends a significant amounts of time innp.argwhere
,2 **
andnp.sum
. I guess there's just more that's going on there. 🤔