Refactor paginated searches (#183)
epwalsh authored Dec 9, 2022
1 parent 24c8337 commit 6052f03
Showing 7 changed files with 88 additions and 138 deletions.
10 changes: 9 additions & 1 deletion beaker/data_model/base.py
@@ -4,8 +4,10 @@
 from typing import (
     Any,
     Dict,
+    Generic,
     Iterator,
     Mapping,
+    Optional,
     Sequence,
     Set,
     Tuple,
@@ -24,7 +26,7 @@
 logger = logging.getLogger("beaker")


-__all__ = ["BaseModel", "MappedSequence", "StrEnum"]
+__all__ = ["BaseModel", "MappedSequence", "StrEnum", "BasePage"]


 BUG_REPORT_URL = (
@@ -149,3 +151,9 @@ def values(self):
 class StrEnum(str, Enum):
     def __str__(self) -> str:
         return self.value
+
+
+class BasePage(BaseModel, Generic[T]):
+    data: Tuple[T, ...]
+    next_cursor: Optional[str] = None
+    next: Optional[str] = None
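
For orientation, here is a minimal, self-contained sketch of the pattern BasePage introduces, written against plain pydantic v2 generics rather than beaker's actual BaseModel (which adds field aliasing and a from_json helper); the Dataset stub is hypothetical:

    from typing import Generic, Optional, Tuple, TypeVar

    from pydantic import BaseModel

    T = TypeVar("T")


    class BasePage(BaseModel, Generic[T]):
        # Common shape of every paginated response from the API.
        data: Tuple[T, ...]
        next_cursor: Optional[str] = None
        next: Optional[str] = None


    class Dataset(BaseModel):  # hypothetical stand-in for beaker's Dataset model
        id: str


    class DatasetsPage(BasePage[Dataset]):
        # Each concrete page re-declares `data` with its item type, just as
        # the page classes in the diffs below do.
        data: Tuple[Dataset, ...]


    page = DatasetsPage.model_validate({"data": [{"id": "ds_01"}], "next_cursor": "c1"})
    assert page.data[0].id == "ds_01"
    assert page.next_cursor == "c1"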
6 changes: 2 additions & 4 deletions beaker/data_model/dataset.py
@@ -4,7 +4,7 @@
 from pydantic import validator

 from .account import Account
-from .base import BaseModel, StrEnum
+from .base import BaseModel, BasePage, StrEnum
 from .workspace import WorkspaceRef

 __all__ = [
@@ -154,10 +154,8 @@ class DatasetManifest(BaseModel):
     cursor: Optional[str] = None


-class DatasetsPage(BaseModel):
+class DatasetsPage(BasePage[Dataset]):
     data: Tuple[Dataset, ...]
-    next_cursor: Optional[str] = None
-    next: Optional[str] = None


 class DatasetSpec(BaseModel):
6 changes: 2 additions & 4 deletions beaker/data_model/experiment.py
@@ -4,7 +4,7 @@
 from pydantic import Field

 from .account import Account
-from .base import BaseModel, MappedSequence, StrEnum
+from .base import BaseModel, BasePage, MappedSequence, StrEnum
 from .job import Job
 from .workspace import WorkspaceRef

@@ -61,10 +61,8 @@ def __init__(self, tasks: List[Task]):
         super().__init__(tasks, {task.name: task for task in tasks if task.name is not None})


-class ExperimentsPage(BaseModel):
+class ExperimentsPage(BasePage[Experiment]):
     data: Tuple[Experiment, ...]
-    next_cursor: Optional[str] = None
-    next: Optional[str] = None


 class ExperimentPatch(BaseModel):
6 changes: 2 additions & 4 deletions beaker/data_model/group.py
@@ -2,7 +2,7 @@
 from typing import List, Optional, Tuple

 from .account import Account
-from .base import BaseModel, StrEnum
+from .base import BaseModel, BasePage, StrEnum
 from .workspace import WorkspaceRef

 __all__ = [
@@ -56,10 +56,8 @@ class GroupPatch(BaseModel):
     parameters: Optional[List[GroupParameter]] = None


-class GroupsPage(BaseModel):
+class GroupsPage(BasePage[Group]):
     data: Tuple[Group, ...]
-    next_cursor: Optional[str] = None
-    next: Optional[str] = None


 class GroupSort(StrEnum):
6 changes: 2 additions & 4 deletions beaker/data_model/image.py
@@ -4,7 +4,7 @@
 from pydantic import validator

 from .account import Account
-from .base import BaseModel, StrEnum
+from .base import BaseModel, BasePage, StrEnum
 from .workspace import WorkspaceRef

 __all__ = [
@@ -51,10 +51,8 @@ def _validate_datetime(cls, v: Optional[datetime]) -> Optional[datetime]:
         return v


-class ImagesPage(BaseModel):
+class ImagesPage(BasePage[Image]):
     data: Tuple[Image, ...]
-    next_cursor: Optional[str] = None
-    next: Optional[str] = None


 class ImageRepoAuth(BaseModel):
6 changes: 2 additions & 4 deletions beaker/data_model/workspace.py
@@ -2,7 +2,7 @@
 from typing import Dict, List, Optional, Tuple

 from .account import Account
-from .base import BaseModel, StrEnum
+from .base import BaseModel, BasePage, StrEnum

 __all__ = [
     "WorkspaceSize",
@@ -46,10 +46,8 @@ class WorkspaceRef(BaseModel):
     full_name: str


-class WorkspacePage(BaseModel):
+class WorkspacePage(BasePage[Workspace]):
     data: Tuple[Workspace, ...]
-    next_cursor: Optional[str] = None
-    next: Optional[str] = None
     org: Optional[str] = None

186 changes: 69 additions & 117 deletions beaker/services/workspace.py
@@ -1,12 +1,14 @@
 from collections import defaultdict
 from datetime import datetime
-from typing import Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union

 from ..data_model import *
 from ..exceptions import *
 from ..util import format_cursor
 from .service_client import ServiceClient

+T = TypeVar("T")
+

 class WorkspaceClient(ServiceClient):
     """
@@ -211,6 +213,43 @@ def move(
             },
         )

+    def _paginated_requests(
+        self,
+        page_class: Type[BasePage[T]],
+        path: str,
+        query: Dict[str, Any],
+        limit: Optional[int] = None,
+        workspace_name: Optional[str] = None,
+    ) -> Generator[T, None, None]:
+        if limit:
+            query["limit"] = str(limit)
+
+        exceptions_for_status: Optional[Dict[int, Exception]] = (
+            None
+            if workspace_name is None
+            else {404: WorkspaceNotFound(self._not_found_err_msg(workspace_name))}
+        )
+
+        count = 0
+        while True:
+            page = page_class.from_json(
+                self.request(
+                    path,
+                    method="GET",
+                    query=query,
+                    exceptions_for_status=exceptions_for_status,
+                ).json()
+            )
+            for x in page.data:
+                count += 1
+                yield x
+                if limit is not None and count >= limit:
+                    return
+
+            query["cursor"] = page.next_cursor or page.next
+            if not query["cursor"]:
+                break
+
     def iter(
         self,
         org: Optional[Union[str, Organization]] = None,
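
This new helper is the heart of the refactor: the cursor-following loop that was previously copy-pasted into every iterator now lives in one place. Below is a standalone sketch of the same control flow, with a plain fetch_page callable standing in for self.request(...) plus page_class.from_json(...); all names in it are illustrative, not beaker's API:

    from typing import Any, Callable, Dict, Generator, Optional

    Page = Dict[str, Any]  # {"data": [...], "next_cursor": ..., "next": ...}


    def paginate(
        fetch_page: Callable[[Dict[str, Any]], Page],
        query: Dict[str, Any],
        limit: Optional[int] = None,
    ) -> Generator[Any, None, None]:
        if limit:
            query["limit"] = str(limit)
        count = 0
        while True:
            page = fetch_page(dict(query))
            for item in page["data"]:
                count += 1
                yield item
                if limit is not None and count >= limit:
                    return  # stop early once the caller's limit is met
            # Advance the cursor; a missing or empty cursor means we're done.
            query["cursor"] = page.get("next_cursor") or page.get("next")
            if not query["cursor"]:
                break


    # Toy two-page backend to exercise the generator:
    pages = {None: {"data": [1, 2], "next_cursor": "c1"}, "c1": {"data": [3, 4]}}
    fetch = lambda q: pages[q.get("cursor")]
    assert list(paginate(fetch, {})) == [1, 2, 3, 4]
    assert list(paginate(fetch, {}, limit=3)) == [1, 2, 3]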
@@ -238,27 +277,8 @@ def iter(
             query["q"] = match
         if archived is not None:
             query["archived"] = str(archived).lower()
-        if limit:
-            query["limit"] = str(limit)
-        count = 0
-        while True:
-            page = WorkspacePage.from_json(
-                self.request(
-                    "workspaces",
-                    method="GET",
-                    query=query,
-                ).json()
-            )
-            for workspace in page.data:
-                count += 1
-                yield workspace
-                if limit is not None and count >= limit:
-                    return
-
-            query["cursor"] = page.next_cursor or page.next  # type: ignore
-            if not query["cursor"]:
-                break
+        yield from self._paginated_requests(WorkspacePage, "workspaces", query, limit=limit)

     def list(
         self,
@@ -342,31 +362,14 @@ def iter_images(
         }
         if match is not None:
             query["q"] = match
-        if limit:
-            query["limit"] = str(limit)
-        count = 0
-        while True:
-            page = ImagesPage.from_json(
-                self.request(
-                    f"workspaces/{self.url_quote(workspace_name)}/images",
-                    method="GET",
-                    query=query,
-                    exceptions_for_status={
-                        404: WorkspaceNotFound(self._not_found_err_msg(workspace_name))
-                    },
-                ).json()
-            )
-
-            for image in page.data:
-                count += 1
-                yield image
-                if limit is not None and count >= limit:
-                    return
-
-            query["cursor"] = page.next_cursor or page.next  # type: ignore
-            if not query["cursor"]:
-                break
+        yield from self._paginated_requests(
+            ImagesPage,
+            f"workspaces/{self.url_quote(workspace_name)}/images",
+            query,
+            limit=limit,
+            workspace_name=workspace_name,
+        )

     def images(
         self,
@@ -443,31 +446,14 @@ def iter_experiments(
         }
         if match is not None:
             query["q"] = match
-        if limit:
-            query["limit"] = str(limit)
-        count = 0
-        while True:
-            page = ExperimentsPage.from_json(
-                self.request(
-                    f"workspaces/{self.url_quote(workspace_name)}/experiments",
-                    method="GET",
-                    query=query,
-                    exceptions_for_status={
-                        404: WorkspaceNotFound(self._not_found_err_msg(workspace_name))
-                    },
-                ).json()
-            )
-
-            for experiment in page.data:
-                count += 1
-                yield experiment
-                if limit is not None and count >= limit:
-                    return
-
-            query["cursor"] = page.next_cursor or page.next  # type: ignore
-            if not query["cursor"]:
-                break
+        yield from self._paginated_requests(
+            ExperimentsPage,
+            f"workspaces/{self.url_quote(workspace_name)}/experiments",
+            query,
+            limit=limit,
+            workspace_name=workspace_name,
+        )

     def experiments(
         self,
@@ -552,31 +538,14 @@ def iter_datasets(
             query["results"] = str(results).lower()
         if uncommitted is not None:
             query["committed"] = str(not uncommitted).lower()
-        if limit:
-            query["limit"] = str(limit)
-        count = 0
-        while True:
-            page = DatasetsPage.from_json(
-                self.request(
-                    f"workspaces/{self.url_quote(workspace_name)}/datasets",
-                    method="GET",
-                    query=query,
-                    exceptions_for_status={
-                        404: WorkspaceNotFound(self._not_found_err_msg(workspace_name))
-                    },
-                ).json()
-            )
-
-            for dataset in page.data:
-                count += 1
-                yield dataset
-                if limit is not None and count >= limit:
-                    return
-
-            query["cursor"] = page.next_cursor or page.next  # type: ignore
-            if not query["cursor"]:
-                break
+        yield from self._paginated_requests(
+            DatasetsPage,
+            f"workspaces/{self.url_quote(workspace_name)}/datasets",
+            query,
+            limit=limit,
+            workspace_name=workspace_name,
+        )

     def datasets(
         self,
@@ -685,31 +654,14 @@ def iter_groups(
         }
         if match is not None:
             query["q"] = match
-        if limit:
-            query["limit"] = str(limit)
-        count = 0
-        while True:
-            page = GroupsPage.from_json(
-                self.request(
-                    f"workspaces/{self.url_quote(workspace_name)}/groups",
-                    method="GET",
-                    query=query,
-                    exceptions_for_status={
-                        404: WorkspaceNotFound(self._not_found_err_msg(workspace_name))
-                    },
-                ).json()
-            )
-
-            for group in page.data:
-                count += 1
-                yield group
-                if limit is not None and count >= limit:
-                    return
-
-            query["cursor"] = page.next_cursor or page.next  # type: ignore
-            if not query["cursor"]:
-                break
+        yield from self._paginated_requests(
+            GroupsPage,
+            f"workspaces/{self.url_quote(workspace_name)}/groups",
+            query,
+            limit=limit,
+            workspace_name=workspace_name,
+        )

     def groups(
         self,
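
With the helper in place, each public iterator reduces to a single yield from and callers see no behavioral change: pages are still fetched lazily as the generator is consumed, and limit still caps the total. A hedged usage sketch (client setup follows beaker-py's documented Beaker.from_env pattern; the workspace name is made up):

    from beaker import Beaker

    beaker = Beaker.from_env(default_workspace="my-org/my-workspace")

    # _paginated_requests drives the cursor behind the scenes; this loop
    # makes only as many requests as needed to produce five datasets.
    for dataset in beaker.workspace.iter_datasets(limit=5):
        print(dataset.id)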
