Skip to content

Commit

Permalink
feat(python): Ensure that buffer produced by `CBufferView.unpack_bits…
Browse files Browse the repository at this point in the history
…()` has a boolean type (#457)

This is a small change to ensure that
`np.array(some_buffer.unpack_bits())` "just works" without nanoarrow
having to know about numpy dtypes. Basically we just need to ensure that
we can create/export a buffer with a `"?"` format string.

```python
import nanoarrow as na
import numpy as np

bool_array = na.Array([True, True, True, False, False, True], na.bool_())
np.array(bool_array.buffer(1).unpack_bits(0, len(bool_array)))
#> array([ True,  True,  True, False, False,  True])
```
  • Loading branch information
paleolimbot authored May 10, 2024
1 parent f47e830 commit 2f2450a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
15 changes: 12 additions & 3 deletions python/src/nanoarrow/_lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ cdef c_arrow_type_from_format(format):
return item_size, NANOARROW_TYPE_DOUBLE

# Check for signed integers
if format in ("b", "?", "h", "i", "l", "q", "n"):
if format in ("b", "h", "i", "l", "q", "n"):
if item_size == 1:
return item_size, NANOARROW_TYPE_INT8
elif item_size == 2:
Expand All @@ -346,7 +346,7 @@ cdef c_arrow_type_from_format(format):
return item_size, NANOARROW_TYPE_INT64

# Check for unsigned integers
if format in ("B", "H", "I", "L", "Q", "N"):
if format in ("B", "?", "H", "I", "L", "Q", "N"):
if item_size == 1:
return item_size, NANOARROW_TYPE_UINT8
elif item_size == 2:
Expand Down Expand Up @@ -1988,7 +1988,7 @@ cdef class CBufferView:
if length is None:
length = self.n_elements

out = CBufferBuilder().set_data_type(NANOARROW_TYPE_UINT8)
out = CBufferBuilder().set_format("?")
out.reserve_bytes(length)
self.unpack_bits_into(out, offset, length)
out.advance(length)
Expand Down Expand Up @@ -2108,6 +2108,8 @@ cdef class CBuffer:
self._device
)

snprintf(self._view._format, sizeof(self._view._format), "%s", self._format)

@staticmethod
def empty():
cdef CBuffer out = CBuffer()
Expand Down Expand Up @@ -2272,6 +2274,13 @@ cdef class CBufferBuilder:
self._buffer._set_data_type(type_id, element_size_bits)
return self

def set_format(self, str format):
"""Set the Python buffer format used to interpret elements in
:meth:`write_elements`.
"""
self._buffer._set_format(format)
return self

@property
def format(self):
"""The ``struct`` format code of the underlying buffer"""
Expand Down
1 change: 1 addition & 0 deletions python/tests/test_c_buffer_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_buffer_view_bool_unpack():
unpacked_all = view.unpack_bits()
assert len(unpacked_all) == view.n_elements
assert unpacked_all.data_type == "uint8"
assert unpacked_all.format == "?"
assert list(unpacked_all) == [1, 0, 0, 1, 0, 0, 0, 0]

unpacked_some = view.unpack_bits(1, 4)
Expand Down

0 comments on commit 2f2450a

Please sign in to comment.