Skip to content

Commit

Permalink
Merge pull request #3 from maxrjones/check_writable
Browse files Browse the repository at this point in the history
Get most of the tests to pass
  • Loading branch information
kylebarron authored Dec 16, 2024
2 parents fb8b16d + 40e1b25 commit 3aa3578
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
25 changes: 18 additions & 7 deletions src/zarr/storage/object_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypedDict
Expand Down Expand Up @@ -70,17 +71,22 @@ async def get(
return prototype.buffer.from_bytes(await resp.bytes_async())

start, end = byte_range
if (start is None or start == 0) and end is None:
resp = await obs.get_async(self.store, key)
return prototype.buffer.from_bytes(await resp.bytes_async())
if start is not None and end is not None:
resp = await obs.get_range_async(self.store, key, start=start, end=end)
return prototype.buffer.from_bytes(memoryview(resp))
elif start is not None:
if start >= 0:
if start > 0:
# Offset request
resp = await obs.get_async(self.store, key, options={"range": {"offset": start}})
else:
resp = await obs.get_async(self.store, key, options={"range": {"suffix": start}})

return prototype.buffer.from_bytes(await resp.bytes_async())
elif end is not None:
resp = await obs.get_range_async(self.store, key, start=0, end=end)
return prototype.buffer.from_bytes(memoryview(resp))
else:
raise ValueError(f"Unexpected input to `get`: {start=}, {end=}")

Expand All @@ -104,18 +110,22 @@ def supports_writes(self) -> bool:
return True

async def set(self, key: str, value: Buffer) -> None:
self._check_writable()
buf = value.to_bytes()
await obs.put_async(self.store, key, buf)

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
self._check_writable()
buf = value.to_bytes()
await obs.put_async(self.store, key, buf, mode="create")
with contextlib.suppress(obs.exceptions.AlreadyExistsError):
await obs.put_async(self.store, key, buf, mode="create")

@property
def supports_deletes(self) -> bool:
return True

async def delete(self, key: str) -> None:
self._check_writable()
await obs.delete_async(self.store, key)

@property
Expand Down Expand Up @@ -158,12 +168,13 @@ async def _transform_list_dir(
# We assume that the underlying object-store implementation correctly handles the
# prefix, so we don't double-check that the returned results actually start with the
# given prefix.
prefix_len = len(prefix)
prefix_len = len(prefix) + 1 # If one is not added to the length, all items will contain "/"
async for batch in list_stream:
for item in batch:
# Yield this item if "/" does not exist after the prefix.
if "/" not in item["path"][prefix_len:]:
yield item["path"]
# Yield this item if "/" does not exist after the prefix
item_path = item["path"][prefix_len:]
if "/" not in item_path:
yield item_path


class _BoundedRequest(TypedDict):
Expand Down
8 changes: 4 additions & 4 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ class StoreTests(Generic[S, B]):
async def set(self, store: S, key: str, value: Buffer) -> None:
"""
Insert a value into a storage backend, with a specific key.
This should not not use any store methods. Bypassing the store methods allows them to be
This should not use any store methods. Bypassing the store methods allows them to be
tested.
"""
raise NotImplementedError

async def get(self, store: S, key: str) -> Buffer:
"""
Retrieve a value from a storage backend, by key.
This should not not use any store methods. Bypassing the store methods allows them to be
This should not use any store methods. Bypassing the store methods allows them to be
tested.
"""

Expand Down Expand Up @@ -103,14 +103,14 @@ def test_store_supports_listing(self, store: S) -> None:
raise NotImplementedError

@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
@pytest.mark.parametrize("byte_range", [None, (0, None), (1, None), (1, 2), (None, 1)])
async def test_get(
self, store: S, key: str, data: bytes, byte_range: None | tuple[int | None, int | None]
self, store: S, key: str, byte_range: None | tuple[int | None, int | None]
) -> None:
"""
Ensure that data can be read from the store using the store.get method.
"""
data = b"\x01\x02\x03\x04"
data_buf = self.buffer_cls.from_bytes(data)
await self.set(store, key, data_buf)
observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_store/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

obstore = pytest.importorskip("obstore")

from zarr.core.buffer import cpu
import re

from zarr.core.buffer import Buffer, cpu
from zarr.storage.object_store import ObjectStore
from zarr.testing.store import StoreTests

PATTERN = r"file://(/[\w/.-]+)"


class TestObjectStore(StoreTests[ObjectStore, cpu.Buffer]):
store_cls = ObjectStore
Expand All @@ -20,3 +24,30 @@ def store_kwargs(self, tmpdir) -> dict[str, str | bool]:
@pytest.fixture
def store(self, store_kwargs: dict[str, str | bool]) -> ObjectStore:
return self.store_cls(**store_kwargs)

async def get(self, store: ObjectStore, key: str) -> Buffer:
# TODO: There must be a better way to get the path to the store
store_path = re.search(PATTERN, str(store)).group(1)
new_local_store = obstore.store.LocalStore(prefix=store_path)
return self.buffer_cls.from_bytes(obstore.get(new_local_store, key).bytes())

async def set(self, store: ObjectStore, key: str, value: Buffer) -> None:
# TODO: There must be a better way to get the path to the store
store_path = re.search(PATTERN, str(store)).group(1)
new_local_store = obstore.store.LocalStore(prefix=store_path)
obstore.put(new_local_store, key, value.to_bytes())

def test_store_repr(self, store: ObjectStore) -> None:
from fnmatch import fnmatch

pattern = "ObjectStore(object://LocalStore(file:///*))"
assert fnmatch(f"{store!r}", pattern)

def test_store_supports_writes(self, store: ObjectStore) -> None:
assert store.supports_writes

def test_store_supports_partial_writes(self, store: ObjectStore) -> None:
assert not store.supports_partial_writes

def test_store_supports_listing(self, store: ObjectStore) -> None:
assert store.supports_listing

0 comments on commit 3aa3578

Please sign in to comment.