Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
normanrz committed Jun 12, 2024
1 parent 14807a5 commit 9932a1d
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions numcodecs/zarr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numcodecs

from zarr.abc.codec import ArrayArrayCodec, BytesBytesCodec
from zarr.buffer import NDBuffer, Buffer, as_numpy_array_wrapper
from zarr.buffer import NDBuffer, Buffer, BufferPrototype, as_numpy_array_wrapper
from zarr.array_spec import ArraySpec
from zarr.common import (
JSON,
Expand Down Expand Up @@ -64,7 +64,6 @@ def __init__(

@cached_property
def _codec(self) -> numcodecs.abc.Codec:
print(self.codec_config)
return numcodecs.get_codec(self.codec_config)

@classmethod
Expand Down Expand Up @@ -92,20 +91,25 @@ def __init__(self, *, codec_id: str, codec_config: dict[str, JSON]) -> None:
super().__init__(codec_id=codec_id, codec_config=codec_config)

async def _decode_single(
self, chunk_bytes: Buffer, _chunk_spec: ArraySpec
self, chunk_bytes: Buffer, chunk_spec: ArraySpec
) -> Buffer:
return await to_thread(as_numpy_array_wrapper, self._codec.decode, chunk_bytes)
return await to_thread(
as_numpy_array_wrapper,
self._codec.decode,
chunk_bytes,
chunk_spec.prototype,
)

def _encode(self, chunk_bytes: Buffer) -> Buffer:
def _encode(self, chunk_bytes: Buffer, prototype: BufferPrototype) -> Buffer:
encoded = self._codec.encode(chunk_bytes.as_array_like())
if isinstance(encoded, np.ndarray): # Required for checksum codecs
return encoded.tobytes()
return Buffer.from_bytes(encoded)
return prototype.buffer.from_bytes(encoded.tobytes())
return prototype.buffer.from_bytes(encoded)

async def _encode_single(
self, chunk_bytes: Buffer, _chunk_spec: ArraySpec
self, chunk_bytes: Buffer, chunk_spec: ArraySpec
) -> Buffer:
return await to_thread(self._encode, chunk_bytes)
return await to_thread(self._encode, chunk_bytes, chunk_spec.prototype)


class NumcodecsArrayArrayCodec(NumcodecsCodec, ArrayArrayCodec):
Expand All @@ -117,14 +121,16 @@ async def _decode_single(
) -> NDBuffer:
chunk_ndarray = chunk_array.as_ndarray_like()
out = await to_thread(self._codec.decode, chunk_ndarray)
return NDBuffer.from_ndarray_like(out.reshape(chunk_spec.shape))
return chunk_spec.prototype.nd_buffer.from_ndarray_like(
out.reshape(chunk_spec.shape)
)

async def _encode_single(
self, chunk_array: NDBuffer, _chunk_spec: ArraySpec
self, chunk_array: NDBuffer, chunk_spec: ArraySpec
) -> NDBuffer:
chunk_ndarray = chunk_array.as_ndarray_like()
out = await to_thread(self._codec.encode, chunk_ndarray)
return NDBuffer.from_ndarray_like(out)
return chunk_spec.prototype.nd_buffer.from_ndarray_like(out)


def make_bytes_bytes_codec(
Expand Down

0 comments on commit 9932a1d

Please sign in to comment.