Skip to content

Commit

Permalink
Bitarray postselect (Qiskit#12693)
Browse files Browse the repository at this point in the history
* define BitArray.postselect()

* add test for BitArray.postselect()

* lint

* remove redundant docstring text

* Update qiskit/primitives/containers/bit_array.py

Co-authored-by: Ian Hincks <ian.hincks@gmail.com>

* docstring ticks (BitArray.postselect())

Co-authored-by: Ian Hincks <ian.hincks@gmail.com>

* Simpler tests for BitArray.postselect

* lint

* add release note

* check postselect() arg lengths match

* fix postselect tests

- fix bugs with checking that ValueError is raised.
- addtionally run all tests on a "flat" data input

* lint

* Fix type-hint

We immediately check the lengths of these args, so they should be Sequences, not Iterables.

* remove spurious print()

* lint

* lint

* use bitwise operations for faster postselect

- Also added support for negative indices
- Also updated tests

* remove spurious print()

* end final line of release note

* try to fix docstring formatting

* fix bitarray test assertion

Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com>

* disallow postselect positional kwarg

Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com>

* fix numpy dtype args

* Simpler kwarg: "assume_unique"

* lint (line too long)

* simplification: remove assume_unique kwarg

* improve misleading comment

* raise IndexError if indices out of range

- Change ValueError to IndexError.
- Add check for out-of-range negative indices.
- Simplify use of mod
- Update test conditions (include checks for off-by-one errors)

* lint

* add negative-contradiction test

* Update docstring with IndexErrors

* lint

* change slice_bits error from ValueError to IndexError

* update slice_bits test to use IndexError

* change ValueError to IndexError in slice_shots

also update tests for this error

* update error type in slice_shots docstring

* Revert ValueError to IndexError changes

Reverting these changes as they will instead be made in a separate PR.

This reverts commit 8f32178.

Revert "update error type in slice_shots docstring"

This reverts commit 50545ef.

Revert "change ValueError to IndexError in slice_shots"

This reverts commit c4becd9.

Revert "update slice_bits test to use IndexError"

This reverts commit c2b0039.

* fix docstring formatting

Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com>

* allow selection to be int instead of bool

* In tests, give selection as type int

* lint

* add example to release note

* fix typo in test case

* add check of test

Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com>

* lint

---------

Co-authored-by: Ian Hincks <ian.hincks@gmail.com>
Co-authored-by: Takashi Imamichi <31178928+t-imamichi@users.noreply.github.com>
  • Loading branch information
3 people authored and Procatv committed Aug 1, 2024
1 parent 131756a commit dc7722e
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 0 deletions.
91 changes: 91 additions & 0 deletions qiskit/primitives/containers/bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,97 @@ def slice_shots(self, indices: int | Sequence[int]) -> "BitArray":
arr = arr[..., indices, :]
return BitArray(arr, self.num_bits)

def postselect(
self,
indices: Sequence[int] | int,
selection: Sequence[bool | int] | bool | int,
) -> BitArray:
"""Post-select this bit array based on sliced equality with a given bitstring.
.. note::
If this bit array contains any shape axes, it is first flattened into a long list of shots
before applying post-selection. This is done because :class:`~BitArray` cannot handle
ragged numbers of shots across axes.
Args:
indices: A list of the indices of the cbits on which to postselect.
If this bit array was produced by a sampler, then an index ``i`` corresponds to the
:class:`~.ClassicalRegister` location ``creg[i]`` (as in :meth:`~slice_bits`).
Negative indices are allowed.
selection: A list of binary values (will be cast to ``bool``) of length matching
``indices``, with ``indices[i]`` corresponding to ``selection[i]``. Shots will be
discarded unless all cbits specified by ``indices`` have the values given by
``selection``.
Returns:
A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.
Raises:
IndexError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`.
IndexError: If ``min(indices)`` is less than negative :attr:`num_bits`.
ValueError: If the lengths of ``selection`` and ``indices`` do not match.
"""
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, (bool, int)):
selection = (selection,)
selection = np.asarray(selection, dtype=bool)

num_indices = len(indices)

if len(selection) != num_indices:
raise ValueError("Lengths of indices and selection do not match.")

num_bytes = self._array.shape[-1]
indices = np.asarray(indices)

if num_indices > 0:
if indices.max() >= self.num_bits:
raise IndexError(
f"index {int(indices.max())} out of bounds for the number of bits {self.num_bits}."
)
if indices.min() < -self.num_bits:
raise IndexError(
f"index {int(indices.min())} out of bounds for the number of bits {self.num_bits}."
)

flattened = self.reshape((), self.size * self.num_shots)

# If no conditions, keep all data, but flatten as promised:
if num_indices == 0:
return flattened

# Make negative bit indices positive:
indices %= self.num_bits

# Handle special-case of contradictory conditions:
if np.intersect1d(indices[selection], indices[np.logical_not(selection)]).size > 0:
return BitArray(np.empty((0, num_bytes), dtype=np.uint8), num_bits=self.num_bits)

# Recall that creg[0] is the LSb:
byte_significance, bit_significance = np.divmod(indices, 8)
# least-significant byte is at last position:
byte_idx = (num_bytes - 1) - byte_significance
# least-significant bit is at position 0:
bit_offset = bit_significance.astype(np.uint8)

# Get bitpacked representation of `indices` (bitmask):
bitmask = np.zeros(num_bytes, dtype=np.uint8)
np.bitwise_or.at(bitmask, byte_idx, np.uint8(1) << bit_offset)

# Get bitpacked representation of `selection` (desired bitstring):
selection_bytes = np.zeros(num_bytes, dtype=np.uint8)
## This assumes no contradictions present, since those were already checked for:
np.bitwise_or.at(
selection_bytes, byte_idx, np.asarray(selection, dtype=np.uint8) << bit_offset
)

return BitArray(
flattened._array[((flattened._array & bitmask) == selection_bytes).all(axis=-1)],
num_bits=self.num_bits,
)

def expectation_values(self, observables: ObservablesArrayLike) -> NDArray[np.float64]:
"""Compute the expectation values of the provided observables, broadcasted against
this bit array.
Expand Down
11 changes: 11 additions & 0 deletions releasenotes/notes/bitarray-postselect-659b8f7801ccaa60.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
features_primitives:
- |
Added a new method :meth:`.BitArray.postselect` that returns all shots containing specified bit values.
Example usage::
from qiskit.primitives.containers import BitArray
ba = BitArray.from_counts({'110': 2, '100': 4, '000': 3})
print(ba.postselect([0,2], [0,1]).get_counts())
# {'110': 2, '100': 4}
79 changes: 79 additions & 0 deletions test/python/primitives/containers/test_bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,82 @@ def test_expectation_values(self):
_ = ba.expectation_values("Z")
with self.assertRaisesRegex(ValueError, "is not diagonal"):
_ = ba.expectation_values("X" * ba.num_bits)

def test_postselection(self):
"""Test the postselection method."""

flat_data = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
dtype=bool,
)

shaped_data = np.array(
[
[
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
],
[
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
]
],
dtype=bool,
)

for dataname, bool_array in zip(["flat", "shaped"], [flat_data, shaped_data]):

bit_array = BitArray.from_bool_array(bool_array, order="little")
# indices value of i <-> creg[i] <-> bool_array[..., i]

num_bits = bool_array.shape[-1]
bool_array = bool_array.reshape(-1, num_bits)

test_cases = [
("basic", [0, 1], [0, 0]),
("multibyte", [0, 9], [0, 1]),
("repeated", [5, 5, 5], [0, 0, 0]),
("contradict", [5, 5, 5], [1, 0, 0]),
("unsorted", [5, 0, 9, 3], [1, 0, 1, 0]),
("negative", [-5, 1, -2, -10], [1, 0, 1, 0]),
("negcontradict", [4, -6], [1, 0]),
("trivial", [], []),
("bareindex", 6, 0),
]

for name, indices, selection in test_cases:
with self.subTest("_".join([dataname, name])):
postselected_bools = np.unpackbits(
bit_array.postselect(indices, selection).array[:, ::-1],
count=num_bits,
axis=-1,
bitorder="little",
).astype(bool)
if isinstance(indices, int):
indices = (indices,)
if isinstance(selection, bool):
selection = (selection,)
answer = bool_array[np.all(bool_array[:, indices] == selection, axis=-1)]
if name in ["contradict", "negcontradict"]:
self.assertEqual(len(answer), 0)
else:
self.assertGreater(len(answer), 0)
np.testing.assert_equal(postselected_bools, answer)

error_cases = [
("aboverange", [0, 6, 10], [True, True, False], IndexError),
("belowrange", [0, 6, -11], [True, True, False], IndexError),
("mismatch", [0, 1, 2], [False, False], ValueError),
]
for name, indices, selection, error in error_cases:
with self.subTest(dataname + "_" + name):
with self.assertRaises(error):
bit_array.postselect(indices, selection)

0 comments on commit dc7722e

Please sign in to comment.