diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md
index ed657f711a..b34c054563 100644
--- a/argilla-server/CHANGELOG.md
+++ b/argilla-server/CHANGELOG.md
@@ -40,6 +40,12 @@ These are the section headers that we use:
 - [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
 - Change search index mapping for responses (reindex is required). ([#5228](https://github.com/argilla-io/argilla/pull/5228))
 
+### Changed
+
+- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126))
+- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
+- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
+
 ### Fixed
 
 - Fixed SQLite connection settings not working correctly due to an outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149))
@@ -48,6 +54,15 @@ These are the section headers that we use:
 
 ### Removed
 
+- [breaking] Removed deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
+- [breaking] Removed deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
+- [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
+- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
+- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
+- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166))
+
+## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)
+
 - [breaking] Removed deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
 - [breaking] Removed deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
 - [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py
index 0fca256da4..ec5d811b3c 100644
--- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py
+++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py
@@ -12,53 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
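
The hunk below strips the handler module down to orchestration: filter translation, scoring, and metadata masking move out to `SearchService` and the new repositories, and the handlers receive their collaborators through FastAPI dependency injection. For reference, a minimal self-contained sketch of that class-as-dependency pattern; all names here (`get_session`, `ItemsRepository`, the SQLite URL) are illustrative stand-ins, not Argilla code:

```python
# Sketch of FastAPI's class-as-dependency pattern (assumed names throughout).
from typing import AsyncIterator
from uuid import UUID

from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///:memory:")  # driver is illustrative
session_factory = async_sessionmaker(engine, expire_on_commit=False)


async def get_session() -> AsyncIterator[AsyncSession]:
    # One session per request; FastAPI finalizes the generator after the response.
    async with session_factory() as session:
        yield session


class ItemsRepository:
    # Depends() defaults in __init__ are resolved just like route parameters.
    def __init__(self, db: AsyncSession = Depends(get_session)):
        self.db = db


app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: UUID, items: ItemsRepository = Depends()):
    # A bare Depends() instantiates the annotated class, wiring its own deps.
    return {"id": str(item_id), "repository": type(items).__name__}
```

This is why `list_dataset_records` below can declare `DatasetsRepository = Depends()` and `RecordsRepository = Depends()` without passing a session explicitly.
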
-import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -import argilla_server.search_engine as search_engine -from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized +from argilla_server.api.policies.v1 import DatasetPolicy, authorize from argilla_server.api.schemas.v1.records import ( - Filters, - FilterScope, - MetadataFilterScope, - Order, - RangeFilter, - RecordFilterScope, RecordIncludeParam, Records, - SearchRecord, SearchRecordsQuery, SearchRecordsResult, - TermsFilter, ) -from argilla_server.api.schemas.v1.records import Record as RecordSchema -from argilla_server.api.schemas.v1.responses import ResponseFilterScope from argilla_server.api.schemas.v1.suggestions import ( SearchSuggestionOptions, SearchSuggestionOptionsQuestion, SearchSuggestionsOptions, - SuggestionFilterScope, ) from argilla_server.contexts import datasets, search from argilla_server.database import get_async_db -from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder -from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError -from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE +from argilla_server.enums import RecordSortField +from argilla_server.errors.future import UnprocessableEntityError +from argilla_server.repositories import DatasetsRepository, RecordsRepository from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( - AndFilter, SearchEngine, - SearchResponses, - UserResponseStatusFilter, get_search_engine, ) from argilla_server.security import auth -from argilla_server.telemetry import TelemetryClient, get_telemetry_client +from argilla_server.services.search import SearchService from argilla_server.utils import parse_query_param, parse_uuids LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50 @@ -73,198 +57,35 @@ router = APIRouter() -async def _filter_records_using_search_engine( - db: "AsyncSession", - search_engine: "SearchEngine", - dataset: Dataset, - limit: int, - offset: int, - user: Optional[User] = None, - include: Optional[RecordIncludeParam] = None, -) -> Tuple[List[Record], int]: - search_responses = await _get_search_responses( - db=db, - search_engine=search_engine, - dataset=dataset, - limit=limit, - offset=offset, - user=user, - ) - - record_ids = [response.record_id for response in search_responses.items] - user_id = user.id if user else None - - return ( - await datasets.get_records_by_ids( - db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include - ), - search_responses.total, - ) - - -def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope: - if isinstance(scope, RecordFilterScope): - return search_engine.RecordFilterScope(property=scope.property) - elif isinstance(scope, MetadataFilterScope): - return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property) - elif isinstance(scope, SuggestionFilterScope): - return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property) - elif isinstance(scope, ResponseFilterScope): - return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user) - else: - raise Exception(f"Unknown 
scope type {type(scope)}") - - -def _to_search_engine_filter(filters: Filters, user: Optional[User]) -> search_engine.Filter: - engine_filters = [] - - for filter in filters.and_: - engine_scope = _to_search_engine_filter_scope(filter.scope, user=user) - - if isinstance(filter, TermsFilter): - engine_filter = search_engine.TermsFilter(scope=engine_scope, values=filter.values) - elif isinstance(filter, RangeFilter): - engine_filter = search_engine.RangeFilter(scope=engine_scope, ge=filter.ge, le=filter.le) - else: - raise Exception(f"Unknown filter type {type(filter)}") - - engine_filters.append(engine_filter) - - return AndFilter(filters=engine_filters) - - -def _to_search_engine_sort(sort: List[Order], user: Optional[User]) -> List[search_engine.Order]: - engine_sort = [] - - for order in sort: - engine_scope = _to_search_engine_filter_scope(order.scope, user=user) - engine_sort.append(search_engine.Order(scope=engine_scope, order=order.order)) - - return engine_sort - - -async def _get_search_responses( - db: "AsyncSession", - search_engine: "SearchEngine", - dataset: Dataset, - limit: int, - offset: int, - search_records_query: Optional[SearchRecordsQuery] = None, - user: Optional[User] = None, -) -> "SearchResponses": - search_records_query = search_records_query or SearchRecordsQuery() - - text_query = None - vector_query = None - if search_records_query.query: - text_query = search_records_query.query.text - vector_query = search_records_query.query.vector - - filters = search_records_query.filters - sort = search_records_query.sort - - vector_settings = None - record = None - - if vector_query: - vector_settings = await VectorSettings.get_by(db, name=vector_query.name, dataset_id=dataset.id) - if vector_settings is None: - raise UnprocessableEntityError(f"Vector `{vector_query.name}` not found in dataset `{dataset.id}`.") - - if vector_query.record_id is not None: - record = await Record.get_by(db, id=vector_query.record_id, dataset_id=dataset.id) - if record is None: - raise UnprocessableEntityError( - f"Record with id `{vector_query.record_id}` not found in dataset `{dataset.id}`." 
- ) - - await record.awaitable_attrs.vectors - - if not record.vector_value_by_vector_settings(vector_settings): - # TODO: Once we move to v2.0 we can use here UnprocessableEntityError instead of MissingVectorError - raise MissingVectorError( - message=f"Record `{record.id}` does not have a vector for vector settings `{vector_settings.name}`", - code=MISSING_VECTOR_ERROR_CODE, - ) - - if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): - raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - - if vector_query and vector_settings: - similarity_search_params = { - "dataset": dataset, - "vector_settings": vector_settings, - "value": vector_query.value, - "record": record, - "query": text_query, - "order": vector_query.order, - "max_results": limit, - } - - if filters: - similarity_search_params["filter"] = _to_search_engine_filter(filters, user=user) - - return await search_engine.similarity_search(**similarity_search_params) - else: - search_params = { - "dataset": dataset, - "query": text_query, - "offset": offset, - "limit": limit, - } - - if user is not None: - search_params["user_id"] = user.id - - if filters: - search_params["filter"] = _to_search_engine_filter(filters, user=user) - if sort: - search_params["sort"] = _to_search_engine_sort(sort, user=user) - - return await search_engine.search(**search_params) - - -async def _build_response_status_filter_for_search( - response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None -) -> Optional[UserResponseStatusFilter]: - user_response_status_filter = None - - if response_statuses: - # TODO(@frascuchon): user response and status responses should be split into different filter types - user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses) - - return user_response_status_filter - - -async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): - try: - await search.validate_search_records_query(db, query, dataset_id) - except (ValueError, NotFoundError) as e: - raise UnprocessableEntityError(str(e)) - - @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), + datasets_repository: DatasetsRepository = Depends(), + records_repository: RecordsRepository = Depends(), dataset_id: UUID, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), ): - dataset = await Dataset.get_or_raise(db, dataset_id) - + dataset = await datasets_repository.get(dataset_id) await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - limit=limit, + include_args = ( + dict( + with_responses=include.with_responses, + with_suggestions=include.with_suggestions, + with_vectors=include.with_all_vectors or include.vectors, + ) + if include + else {} + ) + + records, total = await records_repository.list_by_dataset_id( + dataset_id=dataset.id, offset=offset, - include=include, + limit=limit, + **include_args, ) return Records(items=records, 
total=total) @@ -303,9 +124,9 @@ async def delete_dataset_records( ) async def search_current_user_dataset_records( *, + datasets: DatasetsRepository = Depends(), db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), + engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), @@ -313,54 +134,24 @@ async def search_current_user_dataset_records( limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), ): - dataset = await Dataset.get_or_raise( - db, - dataset_id, - options=[ - selectinload(Dataset.fields), - selectinload(Dataset.metadata_properties), - ], - ) - + dataset = await datasets.get(dataset_id) await authorize(current_user, DatasetPolicy.search_records(dataset)) - await _validate_search_records_query(db, body, dataset_id) - - search_responses = await _get_search_responses( + search_service = SearchService( db=db, - search_engine=search_engine, - dataset=dataset, - search_records_query=body, - limit=limit, - offset=offset, - user=current_user, + engine=engine, + records=RecordsRepository(db), + datasets=DatasetsRepository(db), ) - record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { - response.record_id: {"query_score": response.score, "search_record": None} - for response in search_responses.items - } - - records = await datasets.get_records_by_ids( - db=db, - dataset_id=dataset_id, - records_ids=list(record_id_score_map.keys()), + return await search_service.search_records( + user=current_user, + dataset=dataset, + search_query=body, + offset=offset, + limit=limit, include=include, - user_id=current_user.id, - ) - - for record in records: - record.dataset = dataset - record.metadata_ = await _filter_record_metadata_for_user(record, current_user) - - record_id_score_map[record.id]["search_record"] = SearchRecord( - record=RecordSchema.from_orm(record), - query_score=record_id_score_map[record.id]["query_score"], - ) - - return SearchRecordsResult( - items=[record["search_record"] for record in record_id_score_map.values()], - total=search_responses.total, + search_bounded_to_user=True, ) @@ -373,7 +164,7 @@ async def search_current_user_dataset_records( async def search_dataset_records( *, db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), + engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), @@ -381,44 +172,22 @@ async def search_dataset_records( limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), ): - dataset = await Dataset.get_or_raise(db, dataset_id, options=[selectinload(Dataset.fields)]) + dataset_repository = DatasetsRepository(db) + dataset = await dataset_repository.get(dataset_id) await authorize(current_user, DatasetPolicy.search_records_with_all_responses(dataset)) - await _validate_search_records_query(db, body, dataset_id) + search_service = SearchService(db=db, engine=engine, records=RecordsRepository(db), datasets=dataset_repository) - search_responses = await _get_search_responses( - db=db, - search_engine=search_engine, + return await 
search_service.search_records( + user=current_user, dataset=dataset, - search_records_query=body, - limit=limit, + search_query=body, offset=offset, - ) - - record_id_score_map = { - response.record_id: {"query_score": response.score, "search_record": None} - for response in search_responses.items - } - - records = await datasets.get_records_by_ids( - db=db, - dataset_id=dataset_id, - records_ids=list(record_id_score_map.keys()), + limit=limit, include=include, ) - for record in records: - record_id_score_map[record.id]["search_record"] = SearchRecord( - record=RecordSchema.from_orm(record), - query_score=record_id_score_map[record.id]["query_score"], - ) - - return SearchRecordsResult( - items=[record["search_record"] for record in record_id_score_map.values()], - total=search_responses.total, - ) - @router.get( "/datasets/{dataset_id}/records/search/suggestions/options", @@ -446,14 +215,3 @@ async def list_dataset_records_search_suggestions_options( for sa in suggestion_agents_by_question ] ) - - -async def _filter_record_metadata_for_user(record: Record, user: User) -> Optional[Dict[str, Any]]: - if record.metadata_ is None: - return None - - metadata = {} - for metadata_name in list(record.metadata_.keys()): - if await is_authorized(user, RecordPolicy.get_metadata(record, metadata_name)): - metadata[metadata_name] = record.metadata_[metadata_name] - return metadata diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index b5ff7c3f4c..afe4c48adf 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -194,7 +194,9 @@ def _has_relationships(self): class RecordFilterScope(BaseModel): entity: Literal["record"] - property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]] + property: Union[ + Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal[RecordSortField.status] + ] class Records(BaseModel): @@ -298,6 +300,16 @@ class SearchRecordsQuery(BaseModel): None, min_items=SEARCH_RECORDS_QUERY_SORT_MIN_ITEMS, max_items=SEARCH_RECORDS_QUERY_SORT_MAX_ITEMS ) + @property + def text_query(self) -> Optional[TextQuery]: + if self.query: + return self.query.text + + @property + def vector_query(self) -> Optional[VectorQuery]: + if self.query: + return self.query.vector + class SearchRecord(BaseModel): record: Record diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 2ca212855b..505d3423ad 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -311,6 +311,7 @@ async def create_vector_settings( return vector_settings +# TODO: Remove this function when using the equivalent repository method async def get_records_by_ids( db: AsyncSession, records_ids: Iterable[UUID], diff --git a/argilla-server/src/argilla_server/contexts/search.py b/argilla-server/src/argilla_server/contexts/search.py index 8d10b891c8..797214e318 100644 --- a/argilla-server/src/argilla_server/contexts/search.py +++ b/argilla-server/src/argilla_server/contexts/search.py @@ -19,58 +19,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from argilla_server.api.schemas.v1.records import ( - FilterScope, - MetadataFilterScope, - RecordFilterScope, SearchRecordsQuery, ) -from argilla_server.api.schemas.v1.responses import ResponseFilterScope 
-from argilla_server.api.schemas.v1.suggestions import SuggestionFilterScope -from argilla_server.models import MetadataProperty, Question, Suggestion - - -class SearchRecordsQueryValidator: - def __init__(self, db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID): - self._db = db - self._query = query - self._dataset_id = dataset_id - - async def validate(self) -> None: - if self._query.filters: - for filter in self._query.filters.and_: - await self._validate_filter_scope(filter.scope) - - if self._query.sort: - for order in self._query.sort: - await self._validate_filter_scope(order.scope) - - async def _validate_filter_scope(self, filter_scope: FilterScope) -> None: - if isinstance(filter_scope, RecordFilterScope): - return - elif isinstance(filter_scope, ResponseFilterScope): - await self._validate_response_filter_scope(filter_scope) - elif isinstance(filter_scope, SuggestionFilterScope): - await self._validate_suggestion_filter_scope(filter_scope) - elif isinstance(filter_scope, MetadataFilterScope): - await self._validate_metadata_filter_scope(filter_scope) - else: - raise ValueError(f"Unknown filter scope entity `{filter_scope.entity}`") - - async def _validate_response_filter_scope(self, filter_scope: ResponseFilterScope) -> None: - if filter_scope.question is None: - return - - await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id) - - async def _validate_suggestion_filter_scope(self, filter_scope: SuggestionFilterScope) -> None: - await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id) - - async def _validate_metadata_filter_scope(self, filter_scope: MetadataFilterScope) -> None: - await MetadataProperty.get_by_or_raise( - self._db, - name=filter_scope.metadata_property, - dataset_id=self._dataset_id, - ) +from argilla_server.models import Question, Suggestion +from argilla_server.validators.search import SearchRecordsQueryValidator async def validate_search_records_query(db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID) -> None: diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 2edc53d28f..b34e2e97bf 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -81,6 +81,7 @@ class MetadataPropertyType(str, Enum): class RecordSortField(str, Enum): inserted_at = "inserted_at" updated_at = "updated_at" + status = "status" class SortOrder(str, Enum): diff --git a/argilla-server/src/argilla_server/repositories/__init__.py b/argilla-server/src/argilla_server/repositories/__init__.py new file mode 100644 index 0000000000..98424d94a6 --- /dev/null +++ b/argilla-server/src/argilla_server/repositories/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
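
With `SearchRecordsQueryValidator` relocated to `argilla_server.validators.search`, `contexts/search.py` keeps only its module-level entry point. Its `def` line is visible at the end of the contexts/search.py hunk above, but the body is not shown, so the following is an assumed reconstruction:

```python
# Assumed shape of the wrapper left behind in contexts/search.py after the
# move; only the import swap is visible in the diff, so treat this body as a
# best guess rather than the actual implementation.
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.api.schemas.v1.records import SearchRecordsQuery
from argilla_server.validators.search import SearchRecordsQueryValidator


async def validate_search_records_query(db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID) -> None:
    await SearchRecordsQueryValidator(db, query, dataset_id).validate()
```
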
+
+from argilla_server.repositories.datasets import DatasetsRepository
+from argilla_server.repositories.records import RecordsRepository
+
+__all__ = ["DatasetsRepository", "RecordsRepository"]
diff --git a/argilla-server/src/argilla_server/repositories/datasets.py b/argilla-server/src/argilla_server/repositories/datasets.py
new file mode 100644
index 0000000000..d03b893a31
--- /dev/null
+++ b/argilla-server/src/argilla_server/repositories/datasets.py
@@ -0,0 +1,30 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from argilla_server.database import get_async_db
+from argilla_server.models import Dataset
+
+
+class DatasetsRepository:
+    def __init__(self, db: AsyncSession = Depends(get_async_db)):
+        self.db = db
+
+    async def get(self, dataset_id: UUID, options: Optional[list] = None) -> Dataset:
+        return await Dataset.get_or_raise(db=self.db, id=dataset_id, options=options or [])
diff --git a/argilla-server/src/argilla_server/repositories/records.py b/argilla-server/src/argilla_server/repositories/records.py
new file mode 100644
index 0000000000..a3d8839898
--- /dev/null
+++ b/argilla-server/src/argilla_server/repositories/records.py
@@ -0,0 +1,90 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
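
A short usage sketch for the `DatasetsRepository` defined above (it assumes an `AsyncSession` named `db` is already available; the eager-loading options mirror what `SearchService.search_records` passes later in this diff):

```python
# Usage sketch only: `db` comes from get_async_db elsewhere; get() raises via
# Dataset.get_or_raise when the dataset id does not exist.
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from argilla_server.models import Dataset
from argilla_server.repositories import DatasetsRepository


async def load_dataset_for_search(db: AsyncSession, dataset_id: UUID) -> Dataset:
    repo = DatasetsRepository(db)
    return await repo.get(
        dataset_id,
        options=[selectinload(Dataset.fields), selectinload(Dataset.metadata_properties)],
    )
```
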
+ +from typing import Union, List, Tuple, Sequence, Optional +from uuid import UUID + +from fastapi import Depends +from sqlalchemy import select, and_, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload, contains_eager + +from argilla_server.database import get_async_db +from argilla_server.models import Record, VectorSettings, Vector, Response + + +class RecordsRepository: + def __init__( + self, + db: AsyncSession = Depends(get_async_db), + ): + self.db = db + + async def list_by_dataset_id( + self, + dataset_id: UUID, + offset: int, + limit: int, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, + ) -> Tuple[Sequence[Record], int]: + query = select(Record).filter_by(dataset_id=dataset_id) + query = await self._configure_record_load_relationships(dataset_id, query, with_suggestions, with_vectors) + + if with_responses: + query = query.options(selectinload(Record.responses)) + + records = (await self.db.scalars(query.offset(offset).limit(limit).order_by(Record.inserted_at))).unique().all() + total = await self.db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) + + return records, total + + async def list_by_dataset_id_and_ids( + self, + dataset_id: UUID, + ids: List[UUID], + user_id: Optional[UUID] = None, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, + ): + query = select(Record).filter_by(dataset_id=dataset_id).filter(Record.id.in_(ids)) + query = await self._configure_record_load_relationships(dataset_id, query, with_suggestions, with_vectors) + + if with_responses: + if user_id: + query = query.outerjoin( + Response, and_(Response.record_id == Record.id, Response.user_id == user_id) + ).options(contains_eager(Record.responses)) + else: + query = query.options(selectinload(Record.responses)) + + return (await self.db.scalars(query)).unique().all() + + async def _configure_record_load_relationships(self, dataset_id, query, with_suggestions, with_vectors): + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + return query diff --git a/argilla-server/src/argilla_server/services/__init__.py b/argilla-server/src/argilla_server/services/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/src/argilla_server/services/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
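
Before the new `SearchService` below, a usage sketch for the `RecordsRepository` defined above: `with_vectors` accepts either `True` (load all vectors) or a list of vector-settings names, which drives the `contains_eager`/`outerjoin` branch in `_configure_record_load_relationships`. The vector names here are illustrative:

```python
# Usage sketch only: page through a dataset's records with responses and a
# subset of vectors eagerly loaded; `db` is an AsyncSession obtained elsewhere.
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.repositories import RecordsRepository


async def first_page(db: AsyncSession, dataset_id: UUID):
    repo = RecordsRepository(db)
    records, total = await repo.list_by_dataset_id(
        dataset_id=dataset_id,
        offset=0,
        limit=50,
        with_responses=True,
        with_vectors=["vector-a", "vector-b"],  # or True to load every vector
    )
    return records, total
```
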
+ diff --git a/argilla-server/src/argilla_server/services/search.py b/argilla-server/src/argilla_server/services/search.py new file mode 100644 index 0000000000..2f6b21423f --- /dev/null +++ b/argilla-server/src/argilla_server/services/search.py @@ -0,0 +1,216 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any, Dict, Optional, List, Sequence + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +import argilla_server.search_engine as search_engine +from argilla_server.api.policies.v1 import is_authorized, RecordPolicy +from argilla_server.api.schemas.v1.records import ( + MetadataFilterScope, + RangeFilter, + RecordFilterScope, + SearchRecord, + SearchRecordsQuery, + SearchRecordsResult, + FilterScope, + Filters, + TermsFilter, + Order, + RecordIncludeParam, +) +from argilla_server.api.schemas.v1.responses import ResponseFilterScope +from argilla_server.api.schemas.v1.suggestions import ( + SuggestionFilterScope, +) +from argilla_server.models import Dataset, User, VectorSettings, Record +from argilla_server.repositories import RecordsRepository, DatasetsRepository +from argilla_server.search_engine import ( + AndFilter, + SearchEngine, + SearchResponses, +) +from argilla_server.validators.search import SearchRecordsQueryValidator + + +class SearchService: + def __init__( + self, db: AsyncSession, engine: SearchEngine, records: RecordsRepository, datasets: DatasetsRepository + ): + self.db = db + self.engine = engine + self.records = records + self.datasets = datasets + + async def search_records( + self, + user: User, + dataset: Dataset, + search_query: SearchRecordsQuery, + offset: int, + limit: int, + include: Optional[RecordIncludeParam] = None, + search_bounded_to_user: bool = False, + ) -> Any: + dataset = await self.datasets.get( + dataset.id, + options=[selectinload(Dataset.fields), selectinload(Dataset.metadata_properties)], + ) + + search_query = search_query or SearchRecordsQuery() + await SearchRecordsQueryValidator(self.db, search_query, dataset.id).validate() + + if search_query.vector_query: + results = await self._similarity_search( + dataset=dataset, search_query=search_query, user=user, max_results=limit + ) + else: + results = await self._search( + dataset=dataset, + search_query=search_query, + user=user if search_bounded_to_user else None, + offset=offset, + limit=limit, + ) + + include = include or RecordIncludeParam() + + records = await self.records.list_by_dataset_id_and_ids( + ids=[r.record_id for r in results.items], + dataset_id=dataset.id, + user_id=user.id if search_bounded_to_user else None, + with_responses=include.with_responses, + with_suggestions=include.with_suggestions, + with_vectors=include.with_all_vectors or include.vectors, + ) + await self._filter_records_metadata_for_user(records, user) + + records_by_id = {record.id: record for record in records} + return SearchRecordsResult( + total=results.total, + items=[ + 
SearchRecord(record=records_by_id[response.record_id], query_score=response.score) + for response in results.items + ], + ) + + async def _similarity_search( + self, dataset: Dataset, search_query: SearchRecordsQuery, user: User, max_results: int + ) -> SearchResponses: + filters = self._to_search_engine_filter(search_query.filters, user=user) + + text_query = search_query.text_query + vector_query = search_query.vector_query + + vector_settings = await VectorSettings.get_by(self.db, name=vector_query.name, dataset_id=dataset.id) + record = (await Record.get_by(self.db, id=vector_query.record_id)) if vector_query.record_id else None + + return await self.engine.similarity_search( + dataset=dataset, + vector_settings=vector_settings, + value=vector_query.value, + record=record, + query=text_query, + filter=filters, + order=vector_query.order, + max_results=max_results, + ) + + async def _search( + self, + dataset: Dataset, + search_query: SearchRecordsQuery, + offset: int, + limit: int, + user: Optional[User] = None, + ) -> SearchResponses: + filters = self._to_search_engine_filter(search_query.filters, user=user) + sort = self._to_search_engine_sort(search_query.sort, user) + text_query = search_query.text_query + + return await self.engine.search( + dataset=dataset, + query=text_query, + filter=filters, + sort=sort, + user_id=user.id if user else None, + offset=offset, + limit=limit, + ) + + async def _filter_records_metadata_for_user(self, records: Sequence[Record], user: User) -> None: + records_metadata = await asyncio.gather( + *[self._filter_record_metadata_for_user(record, user) for record in records] + ) + + for record, metadata in zip(records, records_metadata): + record.metadata_ = metadata + + @staticmethod + def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope: + if isinstance(scope, RecordFilterScope): + return search_engine.RecordFilterScope(property=scope.property.value) + elif isinstance(scope, MetadataFilterScope): + return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property) + elif isinstance(scope, SuggestionFilterScope): + return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property)) + elif isinstance(scope, ResponseFilterScope): + return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user) + else: + raise Exception(f"Unknown scope type {type(scope)}") + + def _to_search_engine_filter(self, filters: Filters, user: Optional[User]) -> Optional[search_engine.Filter]: + if filters is None: + return None + + engine_filters = [] + for filter in filters.and_: + engine_scope = self._to_search_engine_filter_scope(filter.scope, user=user) + + if isinstance(filter, TermsFilter): + engine_filter = search_engine.TermsFilter(scope=engine_scope, values=filter.values) + elif isinstance(filter, RangeFilter): + engine_filter = search_engine.RangeFilter(scope=engine_scope, ge=filter.ge, le=filter.le) + else: + raise Exception(f"Unknown filter type {type(filter)}") + + engine_filters.append(engine_filter) + + return AndFilter(filters=engine_filters) + + def _to_search_engine_sort(self, sort: List[Order], user: Optional[User]) -> Optional[List[search_engine.Order]]: + if sort is None: + return None + + engine_sort = [] + for order in sort: + engine_scope = self._to_search_engine_filter_scope(order.scope, user=user) + engine_sort.append(search_engine.Order(scope=engine_scope, order=order.order)) + + return engine_sort + + @staticmethod + 
async def _filter_record_metadata_for_user(record: Record, user: User) -> Optional[Dict[str, Any]]: + if record.metadata_ is None: + return None + + metadata = {} + for metadata_name in list(record.metadata_.keys()): + if await is_authorized(user, RecordPolicy.get_metadata(record, metadata_name)): + metadata[metadata_name] = record.metadata_[metadata_name] + + return metadata diff --git a/argilla-server/src/argilla_server/validators/search.py b/argilla-server/src/argilla_server/validators/search.py new file mode 100644 index 0000000000..6353ce0d30 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/search.py @@ -0,0 +1,117 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.api.schemas.v1.records import ( + SearchRecordsQuery, + Query, + VectorQuery, + FilterScope, + RecordFilterScope, + MetadataFilterScope, +) +from argilla_server.api.schemas.v1.responses import ResponseFilterScope +from argilla_server.api.schemas.v1.suggestions import SuggestionFilterScope +from argilla_server.errors.future import NotFoundError, UnprocessableEntityError, MissingVectorError +from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE +from argilla_server.models import VectorSettings, Record, Field, Question, MetadataProperty +from argilla_server.search_engine import TextQuery + + +class SearchRecordsQueryValidator: + def __init__(self, db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID): + self._db = db + self._query = query + self._dataset_id = dataset_id + + async def validate(self) -> None: + try: + if self._query.filters: + for filter in self._query.filters.and_: + await self._validate_filter_scope(filter.scope) + + if self._query.sort: + for order in self._query.sort: + await self._validate_filter_scope(order.scope) + + if self._query.query: + await self._validate_search_query(self._query.query) + except NotFoundError as ex: + raise UnprocessableEntityError(str(ex)) from ex + + async def _validate_search_query(self, query: Query): + if query.text: + await self._validate_text(query.text) + if query.vector: + await self._validate_vector(query.vector) + + async def _validate_vector(self, vector_query: VectorQuery): + vector_settings = await VectorSettings.get_by(self._db, name=vector_query.name, dataset_id=self._dataset_id) + if vector_settings is None: + raise UnprocessableEntityError(f"Vector `{vector_query.name}` not found in dataset `{self._dataset_id}`.") + + if vector_query.record_id is not None: + record = await Record.get_by(self._db, id=vector_query.record_id, dataset_id=self._dataset_id) + if record is None: + raise UnprocessableEntityError( + f"Record with id `{vector_query.record_id}` not found in dataset `{self._dataset_id}`." 
+ ) + + await record.awaitable_attrs.vectors + + if not record.vector_value_by_vector_settings(vector_settings): + # TODO: Once we move to v2.0 we can use here UnprocessableEntityError instead of MissingVectorError + raise MissingVectorError( + message=f"Record `{record.id}` does not have a vector for vector settings `{vector_settings.name}`", + code=MISSING_VECTOR_ERROR_CODE, + ) + + async def _validate_text(self, text_query: TextQuery): + if ( + text_query + and text_query.field + and not await Field.get_by(self._db, name=text_query.field, dataset_id=self._dataset_id) + ): + raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{self._dataset_id}`.") + + async def _validate_filter_scope(self, filter_scope: FilterScope) -> None: + if isinstance(filter_scope, RecordFilterScope): + return + elif isinstance(filter_scope, ResponseFilterScope): + await self._validate_response_filter_scope(filter_scope) + elif isinstance(filter_scope, SuggestionFilterScope): + await self._validate_suggestion_filter_scope(filter_scope) + elif isinstance(filter_scope, MetadataFilterScope): + await self._validate_metadata_filter_scope(filter_scope) + else: + raise ValueError(f"Unknown filter scope entity `{filter_scope.entity}`") + + async def _validate_response_filter_scope(self, filter_scope: ResponseFilterScope) -> None: + if filter_scope.question is None: + return + + await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id) + + async def _validate_suggestion_filter_scope(self, filter_scope: SuggestionFilterScope) -> None: + await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id) + + async def _validate_metadata_filter_scope(self, filter_scope: MetadataFilterScope) -> None: + await MetadataProperty.get_by_or_raise( + self._db, + name=filter_scope.metadata_property, + dataset_id=self._dataset_id, + ) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 5e3c6653de..ad7d0a259a 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -319,6 +319,8 @@ async def test_with_filter( offset=0, limit=50, query=None, + sort=None, + user_id=None, ) async def test_with_sort( @@ -367,6 +369,8 @@ async def test_with_sort( offset=0, limit=50, query=None, + filter=None, + user_id=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index a259baa773..b8610719fc 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -3593,6 +3593,8 @@ async def test_search_current_user_dataset_records( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), + filter=None, + sort=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -3754,6 +3756,7 @@ async def test_search_current_user_dataset_records_with_metadata_filter( filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + sort=None, user_id=owner.id, ) @@ -3832,6 +3835,7 @@ async def test_search_current_user_dataset_records_with_sort_by( 
mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), + filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort=expected_sort, @@ -4037,6 +4041,8 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), + filter=None, + sort=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4268,6 +4274,7 @@ async def test_search_current_user_dataset_records_with_response_status_filter( ) ] ), + sort=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4308,6 +4315,7 @@ async def test_search_current_user_dataset_records_with_record_vector( dataset=dataset, vector_settings=vector_settings, record=records[0], + filter=None, value=None, query=None, order=SimilarityOrder.most_similar, @@ -4352,6 +4360,7 @@ async def test_search_current_user_dataset_records_with_vector_value( record=None, value=selected_vector.value, query=None, + filter=None, order=SimilarityOrder.most_similar, max_results=10, ) @@ -4399,6 +4408,7 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query( record=None, value=selected_vector.value, query=TextQuery(q="Test query"), + filter=None, order=SimilarityOrder.most_similar, max_results=10, ) @@ -4492,6 +4502,8 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), + filter=None, + sort=None, offset=0, limit=5, user_id=owner.id, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 4f989e5399..3dc3546d29 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -38,7 +38,6 @@ @pytest.mark.asyncio class TestSuiteListDatasetRecords: - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -58,25 +57,31 @@ async def test_list_dataset_records(self, async_client: "AsyncClient", owner_aut "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"record_a": "value_a"}, "metadata": None, "external_id": record_a.external_id, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"record_b": "value_b"}, "metadata": {"unit": "test"}, "external_id": record_b.external_id, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"record_c": "value_c"}, "metadata": None, "external_id": record_c.external_id, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -188,7 +193,6 @@ async def test_list_dataset_records_with_include( assert response.status_code == 200 - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_vectors( self, 
async_client: "AsyncClient", owner_auth_header: dict ): @@ -214,6 +218,7 @@ async def test_list_dataset_records_with_include_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -221,26 +226,31 @@ async def test_list_dataset_records_with_include_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -248,7 +258,6 @@ async def test_list_dataset_records_with_include_vectors( "total": 3, } - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_specific_vectors( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -278,6 +287,7 @@ async def test_list_dataset_records_with_include_specific_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -285,26 +295,31 @@ async def test_list_dataset_records_with_include_specific_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -312,7 +327,6 @@ async def test_list_dataset_records_with_include_specific_vectors( "total": 3, } - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -331,7 +345,6 @@ async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await 
RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -350,7 +363,6 @@ async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient" response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_a.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset_and_limit( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -457,9 +469,9 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): admin = await AdminFactory.create(workspaces=[workspace]) dataset = await DatasetFactory.create(workspace=workspace) - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) + record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) + record_b = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) + record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) other_dataset = await DatasetFactory.create() await RecordFactory.create_batch(size=2, dataset=other_dataset) @@ -468,6 +480,41 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key} ) assert response.status_code == 200 + assert response.json() == { + "total": 3, + "items": [ + { + "id": str(record_a.id), + "dataset_id": str(dataset.id), + "fields": {"record_a": "value_a"}, + "metadata": None, + "external_id": record_a.external_id, + "status": "pending", + "inserted_at": record_a.inserted_at.isoformat(), + "updated_at": record_a.updated_at.isoformat(), + }, + { + "id": str(record_b.id), + "dataset_id": str(dataset.id), + "fields": {"record_b": "value_b"}, + "metadata": None, + "external_id": record_b.external_id, + "status": "pending", + "inserted_at": record_b.inserted_at.isoformat(), + "updated_at": record_b.updated_at.isoformat(), + }, + { + "id": str(record_c.id), + "dataset_id": str(dataset.id), + "fields": {"record_c": "value_c"}, + "metadata": None, + "external_id": record_c.external_id, + "status": "pending", + "inserted_at": record_c.inserted_at.isoformat(), + "updated_at": record_c.updated_at.isoformat(), + }, + ], + } async def test_list_dataset_records_as_annotator(self, async_client: "AsyncClient"): workspace = await WorkspaceFactory.create() diff --git a/argilla-server/tests/unit/contexts/search/test_search_records_query_validator.py b/argilla-server/tests/unit/contexts/search/test_search_records_query_validator.py index 3972283865..011dc3209e 100644 --- a/argilla-server/tests/unit/contexts/search/test_search_records_query_validator.py +++ b/argilla-server/tests/unit/contexts/search/test_search_records_query_validator.py @@ -17,7 +17,7 @@ import argilla_server.errors.future as errors import pytest from argilla_server.api.schemas.v1.records import SearchRecordsQuery -from argilla_server.contexts.search import SearchRecordsQueryValidator +from argilla_server.validators.search import SearchRecordsQueryValidator from sqlalchemy.ext.asyncio import AsyncSession from tests.factories import ( @@ -118,7 +118,7 @@ async def test_validate_response_filter_scope_in_filters_with_non_existent_quest } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with 
pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert ( @@ -145,7 +145,7 @@ async def test_validate_suggestion_filter_scope_in_filters_with_non_existent_que } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert ( @@ -174,7 +174,7 @@ async def test_validate_metadata_filter_scope_in_filters_with_non_existent_metad } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert ( @@ -206,7 +206,7 @@ async def test_validate_response_filter_scope_in_sort_with_non_existent_question } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert ( @@ -225,7 +225,7 @@ async def test_validate_suggestion_filter_scope_in_sort_with_non_existent_questi } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert ( @@ -244,7 +244,7 @@ async def test_validate_metadata_filter_scope_in_sort_with_non_existent_metadata } ) - with pytest.raises(errors.NotFoundError) as not_found_error: + with pytest.raises(errors.UnprocessableEntityError) as not_found_error: await SearchRecordsQueryValidator(db, query, dataset.id).validate() assert (
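
The test updates above follow from `SearchRecordsQueryValidator.validate` now catching `NotFoundError` and re-raising it as `UnprocessableEntityError`, so the public contract is a 422-style error. A condensed sketch of the new assertion shape (building `query` and `dataset_id` is elided and assumed to come from this suite's factories):

```python
# Condensed sketch: fixtures for `db`, `query`, and `dataset_id` are assumed.
import pytest

import argilla_server.errors.future as errors
from argilla_server.validators.search import SearchRecordsQueryValidator


async def assert_unknown_scope_is_rejected(db, query, dataset_id):
    # These tests previously expected errors.NotFoundError; the validator now
    # wraps it, so callers assert on UnprocessableEntityError instead.
    with pytest.raises(errors.UnprocessableEntityError):
        await SearchRecordsQueryValidator(db, query, dataset_id).validate()
```
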