Skip to content

Commit

Permalink
type-golf-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan King committed Feb 2, 2024
1 parent 75bf0b9 commit d39557b
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 33 deletions.
21 changes: 9 additions & 12 deletions hail/python/hailtop/aiocloud/aioaws/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def __str__(self) -> str:
return f's3://{self._bucket}/{self._path}'


class S3AsyncFS(AsyncFS):
class S3AsyncFS(AsyncFS[S3AsyncFSURL]):
def __init__(
self,
thread_pool: Optional[ThreadPoolExecutor] = None,
Expand Down Expand Up @@ -429,18 +429,18 @@ def get_bucket_and_name(url: str) -> Tuple[str, str]:

return (bucket, name)

async def open(self, url: str) -> ReadableStream:
bucket, name = self.get_bucket_and_name(url)
if name == '':
raise IsABucketError(url)
async def _open(self, url: S3AsyncFSURL) -> ReadableStream:
bucket = url._bucket
name = url._path
try:
resp = await blocking_to_async(self._thread_pool, self._s3.get_object, Bucket=bucket, Key=name)
return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body']))
except self._s3.exceptions.NoSuchKey as e:
raise FileNotFoundError(url) from e

async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
bucket, name = self.get_bucket_and_name(url)
async def _open_from(self, url: S3AsyncFSURL, start: int, *, length: Optional[int] = None) -> ReadableStream:
bucket = url._bucket
name = url._path
range_str = f'bytes={start}-'
if length is not None:
assert length >= 1
Expand All @@ -457,7 +457,7 @@ async def _open_from(self, url: str, start: int, *, length: Optional[int] = None
raise UnexpectedEOFError from e
raise

async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManager: # pylint: disable=unused-argument
async def create(self, url: S3AsyncFSURL, *, retry_writes: bool = True) -> S3CreateManager: # pylint: disable=unused-argument
# It may be possible to write a more efficient version of this
# that takes advantage of retry_writes=False. Here's the
# background information:
Expand Down Expand Up @@ -498,10 +498,7 @@ async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManage
# interface. This has the disadvantage that the read must
# complete before the write can begin (unlike the current
# code, that copies 128MB parts in 256KB chunks).
bucket, name = self.get_bucket_and_name(url)
if name == '':
raise IsABucketError(url)
return S3CreateManager(self, bucket, name)
return S3CreateManager(self, url._bucket, url._path)

async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate:
bucket, name = self.get_bucket_and_name(url)
Expand Down
47 changes: 34 additions & 13 deletions hail/python/hailtop/aiotools/fs/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datetime
from hailtop.utils import retry_transient_errors, OnlineBoundedGather2
from .stream import EmptyReadableStream, ReadableStream, WritableStream
from .exceptions import FileAndDirectoryError
from .exceptions import FileAndDirectoryError, IsABucketError


T = TypeVar("T")
Expand Down Expand Up @@ -222,6 +222,10 @@ def with_path(self, path) -> "AsyncFSURL":
def with_root_path(self) -> "AsyncFSURL":
pass

@abc.abstractmethod
def is_bucket(self) -> bool:
pass

def with_new_path_component(self, new_path_component: str) -> "AsyncFSURL":
if new_path_component == '':
raise ValueError('new path component must be non-empty')
Expand All @@ -246,7 +250,10 @@ def __str__(self) -> str:
pass


class AsyncFS(abc.ABC):
URL = TypeVar('URL', bound=AsyncFSURL)


class AsyncFS(abc.ABC, Generic[URL]):
FILE = "file"
DIR = "dir"

Expand All @@ -268,22 +275,33 @@ def valid_url(url: str) -> bool:

@staticmethod
@abc.abstractmethod
def parse_url(url: str) -> AsyncFSURL:
def parse_url(url: str) -> URL:
pass

@classmethod
def _ensure_url_and_not_bucket(cls, url: Union[str, URL]) -> URL:
if isinstance(url, str):
url = cls.parse_url(url)
if url.is_bucket():
raise IsABucketError(str(url))
return url

async def open(self, url: Union[str, URL]) -> ReadableStream:
return await self._open(self._ensure_url_and_not_bucket(url))

@abc.abstractmethod
async def open(self, url: str) -> ReadableStream:
async def _open(self, url: URL) -> ReadableStream:
pass

async def open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
async def open_from(self, url: Union[str, URL], start: int, *, length: Optional[int] = None) -> ReadableStream:
url = self._ensure_url_and_not_bucket(url)
if length == 0:
fs_url = self.parse_url(url)
if fs_url.path.endswith("/"):
file_url = str(fs_url.with_path(fs_url.path.rstrip("/")))
dir_url = str(fs_url)
if url.path.endswith("/"):
file_url = str(url.with_path(url.path.rstrip("/")))
dir_url = str(url)
else:
file_url = str(fs_url)
dir_url = str(fs_url.with_path(fs_url.path + "/"))
file_url = str(url)
dir_url = str(url.with_path(url.path + "/"))
isfile, isdir = await asyncio.gather(self.isfile(file_url), self.isdir(dir_url))
if isfile:
if isdir:
Expand All @@ -295,11 +313,14 @@ async def open_from(self, url: str, start: int, *, length: Optional[int] = None)
return await self._open_from(url, start, length=length)

@abc.abstractmethod
async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
async def _open_from(self, url: URL, start: int, *, length: Optional[int] = None) -> ReadableStream:
pass

async def create(self, url: Union[str, URL], *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]:
return await self._create(self._ensure_url_and_not_bucket(url), retry_writes=retry_writes)

@abc.abstractmethod
async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]:
async def _create(self, url: URL, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]:
pass

@abc.abstractmethod
Expand Down
13 changes: 7 additions & 6 deletions hail/python/hailtop/aiotools/router_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hailtop.config import ConfigVariable, configuration_of


class RouterAsyncFS(AsyncFS):
class RouterAsyncFS(AsyncFS[AsyncFSURL]):
FS_CLASSES: ClassVar[List[type[AsyncFS]]] = [
LocalAsyncFS,
aiogoogle.GoogleStorageAsyncFS,
Expand Down Expand Up @@ -72,7 +72,7 @@ def valid_url(url) -> bool:
or aioaws.S3AsyncFS.valid_url(url)
)

async def _get_fs(self, url: str):
async def _get_fs(self, url: AsyncFSURL) -> AsyncFS[AsyncFSURL]:
if LocalAsyncFS.valid_url(url):
if self._local_fs is None:
self._local_fs = LocalAsyncFS(**self._local_kwargs)
Expand All @@ -90,22 +90,23 @@ async def _get_fs(self, url: str):
self._azure_fs = aioazure.AzureAsyncFS(**self._azure_kwargs)
self._exit_stack.push_async_callback(self._azure_fs.close)
return self._azure_fs
if aioaws.S3AsyncFS.valid_url(url):
# if aioaws.S3AsyncFS.valid_url(url):
if isinstance(url, aioaws.S3AsyncFSURL):
if self._s3_fs is None:
self._s3_fs = aioaws.S3AsyncFS(**self._s3_kwargs)
self._exit_stack.push_async_callback(self._s3_fs.close)
return self._s3_fs
raise ValueError(f'no file system found for url {url}')

async def open(self, url: str) -> ReadableStream:
async def _open(self, url: AsyncFSURL) -> ReadableStream:
fs = await self._get_fs(url)
return await fs.open(url)

async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
async def _open_from(self, url: AsyncFSURL, start: int, *, length: Optional[int] = None) -> ReadableStream:
fs = await self._get_fs(url)
return await fs.open_from(url, start, length=length)

async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]:
async def _create(self, url: AsyncFSURL, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]:
fs = await self._get_fs(url)
return await fs.create(url, retry_writes=retry_writes)

Expand Down
4 changes: 2 additions & 2 deletions hail/python/test/hailtop/inter_cloud/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, AsyncIterator, Dict
from typing import Tuple, AsyncIterator, Dict, Any
import secrets
import os
import asyncio
Expand All @@ -13,7 +13,7 @@


@pytest.fixture(scope='module')
async def router_filesystem() -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]]:
async def router_filesystem() -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS[Any], Dict[str, str]]]:
token = secrets.token_hex(16)

async with RouterAsyncFS() as fs:
Expand Down

0 comments on commit d39557b

Please sign in to comment.