From fcf87c43681dde1efcb1f30570b27c6af1ecf4c9 Mon Sep 17 00:00:00 2001 From: Graeme Harris Date: Sun, 7 May 2023 16:19:20 +0200 Subject: [PATCH 1/2] Added cursor pagination for flagged messages --- backend/oasst_backend/api/v1/admin.py | 72 +++++++++++++++++++++- backend/oasst_backend/prompt_repository.py | 49 +++++++++++++++ 2 files changed, 119 insertions(+), 2 deletions(-) diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index 4974782d38..4f1bed2d48 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -3,7 +3,7 @@ from uuid import UUID import pydantic -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from loguru import logger from oasst_backend.api import deps from oasst_backend.config import Settings, settings @@ -11,7 +11,9 @@ from oasst_backend.prompt_repository import PromptRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.utils.database_utils import CommitMode, managed_tx_function -from oasst_shared.schemas.protocol import SystemStats +from oasst_shared import utils +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from oasst_shared.schemas.protocol import PageResult, SystemStats from oasst_shared.utils import ScopeTimer, unaware_to_utc router = APIRouter() @@ -171,6 +173,72 @@ class FlaggedMessageResponse(pydantic.BaseModel): created_date: Optional[datetime] +class FlaggedMessagePage(PageResult): + messages: list[FlaggedMessageResponse] + + +@router.get("/flagged_messages/cursor", response_model=FlaggedMessagePage) +def get_flagged_messages_cursor( + *, + before: Optional[str] = None, + after: Optional[str] = None, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client_id: Optional[str] = None, + only_roots: Optional[bool] = False, + include_deleted: Optional[bool] = False, + max_count: Optional[int] = Query(10, gt=0, le=1000), + desc: 
Optional[bool] = False, + lang: Optional[str] = None, + session: deps.Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> str: + assert api_client.trusted + assert max_count is not None + + def split_cursor(x: str | None) -> tuple[datetime, UUID]: + if not x: + return None, None + try: + m = utils.split_uuid_pattern.match(x) + if m: + return datetime.fromisoformat(m[2]), UUID(m[1]) + return datetime.fromisoformat(x), None + except ValueError: + raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE) + + if desc: + gte_created_date, gt_id = split_cursor(before) + lte_created_date, lt_id = split_cursor(after) + query_desc = not (before is not None and not after) + else: + lte_created_date, lt_id = split_cursor(before) + gte_created_date, gt_id = split_cursor(after) + query_desc = before is not None and not after + + logger.debug(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}") + + qry_max_count = max_count + 1 if before is None or after is None else max_count + + pr = PromptRepository(session, api_client) + items = pr.fetch_flagged_messages_by_created_date( + auth_method=auth_method, + username=username, + api_client_id=api_client_id, + gte_created_date=gte_created_date, + gt_id=gt_id, + lte_created_date=lte_created_date, + lt_id=lt_id, + only_roots=only_roots, + deleted=None if include_deleted else False, + desc=query_desc, + limit=qry_max_count, + lang=lang, + ) + # resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages] + # return FlaggedMessagePage(messages=resp, cursor=flagged_messages.cursor) + + @router.get("/flagged_messages", response_model=list[FlaggedMessageResponse]) async def get_flagged_messages( max_count: Optional[int], diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index d7b3ee4133..8b86d8f0b4 100644 --- a/backend/oasst_backend/prompt_repository.py +++ 
b/backend/oasst_backend/prompt_repository.py @@ -1266,6 +1266,55 @@ def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessag return qry.all() + def fetch_flagged_messages_by_created_date( + self, + api_client_id: Optional[UUID] = None, + gte_created_date: Optional[datetime] = None, + gt_id: Optional[UUID] = None, + lte_created_date: Optional[datetime] = None, + lt_id: Optional[UUID] = None, + desc: bool = False, + limit: Optional[int] = 100, + lang: Optional[str] = None, + ) -> list[FlaggedMessage]: + qry = self.db.query(FlaggedMessage) + + if gte_created_date is not None: + if gt_id: + qry = qry.filter( + or_( + FlaggedMessage.created_date > gte_created_date, + and_(FlaggedMessage.created_date == gte_created_date, FlaggedMessage.message_id > gt_id), + ) + ) + else: + qry = qry.filter(FlaggedMessage.created_date >= gte_created_date) + elif gt_id: + raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_created_date is not None: + if lt_id: + qry = qry.filter( + or_( + FlaggedMessage.created_date < lte_created_date, + and_(FlaggedMessage.created_date == lte_created_date, FlaggedMessage.message_id < lt_id), + ) + ) + else: + qry = qry.filter(FlaggedMessage.created_date <= lte_created_date) + elif lt_id: + raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if desc: + qry = qry.order_by(Message.created_date.desc(), Message.id.desc()) + else: + qry = qry.order_by(Message.created_date.asc(), Message.id.asc()) + + if limit is not None: + qry = qry.limit(limit) + + return qry.all() + def process_flagged_message(self, message_id: UUID) -> FlaggedMessage: message = self.db.query(FlaggedMessage).get(message_id) From 58fb9c4770f85d9676225cc4dfaceb46c5ca4675 Mon Sep 17 00:00:00 2001 From: Graeme Harris Date: Tue, 16 May 2023 15:15:50 +0200 Subject: [PATCH 2/2] Added paginated query for flagged messages --- backend/oasst_backend/api/v1/admin.py | 40 
++++++++++++++-------- backend/oasst_backend/prompt_repository.py | 6 ++-- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index 4f1bed2d48..6dd593f5d6 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -174,7 +174,7 @@ class FlaggedMessageResponse(pydantic.BaseModel): class FlaggedMessagePage(PageResult): - messages: list[FlaggedMessageResponse] + items: list[FlaggedMessageResponse] @router.get("/flagged_messages/cursor", response_model=FlaggedMessagePage) @@ -182,14 +182,8 @@ def get_flagged_messages_cursor( *, before: Optional[str] = None, after: Optional[str] = None, - auth_method: Optional[str] = None, - username: Optional[str] = None, - api_client_id: Optional[str] = None, - only_roots: Optional[bool] = False, - include_deleted: Optional[bool] = False, max_count: Optional[int] = Query(10, gt=0, le=1000), desc: Optional[bool] = False, - lang: Optional[str] = None, session: deps.Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), ) -> str: @@ -222,21 +216,37 @@ def split_cursor(x: str | None) -> tuple[datetime, UUID]: pr = PromptRepository(session, api_client) items = pr.fetch_flagged_messages_by_created_date( - auth_method=auth_method, - username=username, - api_client_id=api_client_id, gte_created_date=gte_created_date, gt_id=gt_id, lte_created_date=lte_created_date, lt_id=lt_id, - only_roots=only_roots, - deleted=None if include_deleted else False, desc=query_desc, limit=qry_max_count, - lang=lang, ) - # resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages] - # return FlaggedMessagePage(messages=resp, cursor=flagged_messages.cursor) + + num_rows = len(items) + if qry_max_count > max_count and num_rows == qry_max_count: + assert not (before and after) + items = items[:-1] + + if desc != query_desc: + items.reverse() + + n, p = None, None + if len(items) > 0: 
+ if (num_rows > max_count and before) or after: + p = str(items[0].message_id) + "$" + items[0].created_date.isoformat() + if num_rows > max_count or before: + n = str(items[-1].message_id) + "$" + items[-1].created_date.isoformat() + else: + if after: + p = lte_created_date.isoformat() if desc else gte_created_date.isoformat() + if before: + n = gte_created_date.isoformat() if desc else lte_created_date.isoformat() + + order = "desc" if desc else "asc" + print(p, n, items, order) + return FlaggedMessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items) @router.get("/flagged_messages", response_model=list[FlaggedMessageResponse]) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8b86d8f0b4..ab31f59c9f 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -1268,14 +1268,12 @@ def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessag def fetch_flagged_messages_by_created_date( self, - api_client_id: Optional[UUID] = None, gte_created_date: Optional[datetime] = None, gt_id: Optional[UUID] = None, lte_created_date: Optional[datetime] = None, lt_id: Optional[UUID] = None, desc: bool = False, limit: Optional[int] = 100, - lang: Optional[str] = None, ) -> list[FlaggedMessage]: qry = self.db.query(FlaggedMessage) @@ -1306,9 +1304,9 @@ def fetch_flagged_messages_by_created_date( raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR) if desc: - qry = qry.order_by(Message.created_date.desc(), Message.id.desc()) + qry = qry.order_by(FlaggedMessage.created_date.desc(), FlaggedMessage.message_id.desc()) else: - qry = qry.order_by(Message.created_date.asc(), Message.id.asc()) + qry = qry.order_by(FlaggedMessage.created_date.asc(), FlaggedMessage.message_id.asc()) if limit is not None: qry = qry.limit(limit)