diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 05a8e7e583..b3f1483986 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -28,8 +28,9 @@ These are the section headers that we use: ### Removed - Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) -- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) -- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) +- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 4f68c1125f..065295612f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -19,7 +19,6 @@ from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from typing_extensions import Annotated import argilla_server.search_engine as search_engine from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized @@ -52,12 +51,11 @@ from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE -from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings +from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, SearchEngine, SearchResponses, - SortBy, UserResponseStatusFilter, get_search_engine, ) @@ -70,25 +68,6 @@ LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"} DELETE_DATASET_RECORDS_LIMIT = 100 -_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField) -_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder) -_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P(?=.*[a-z0-9])[a-z0-9_-]+)$") - -SortByQueryParamParsed = Annotated[ - Dict[str, str], - Depends( - parse_query_param( - name="sort_by", - description=( - "The field used to sort the records. 
Expected format is `field` or `field:{asc,desc}`, where `field`" - " can be 'inserted_at', 'updated_at' or the name of a metadata property" - ), - max_values_per_key=1, - group_keys_without_values=False, - ) - ), -] - parse_record_include_param = parse_query_param( name="include", help="Relationships to include in the response", model=RecordIncludeParam ) @@ -104,7 +83,6 @@ async def _filter_records_using_search_engine( offset: int, user: Optional[User] = None, include: Optional[RecordIncludeParam] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: search_responses = await _get_search_responses( db=db, @@ -113,7 +91,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - sort_by_query_param=sort_by_query_param, ) record_ids = [response.record_id for response in search_responses.items] @@ -176,7 +153,6 @@ async def _get_search_responses( offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -216,8 +192,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - sort_by = await _build_sort_by(db, dataset, sort_by_query_param) - if vector_query and vector_settings: similarity_search_params = { "dataset": dataset, @@ -239,7 +213,6 @@ async def _get_search_responses( "query": text_query, "offset": offset, "limit": limit, - "sort_by": sort_by, } if user is not None: @@ -265,43 +238,6 @@ async def _build_response_status_filter_for_search( return user_response_status_filter -async def _build_sort_by( - db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None -) -> Union[List[SortBy], None]: - if sort_by_query_param is None: - return None - - sorts_by = [] - for sort_field, sort_order in sort_by_query_param.items(): - if sort_field in _RECORD_SORT_FIELD_VALUES: - field = sort_field - elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None: - metadata_property_name = match.group("name") - metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id) - if not metadata_property: - raise UnprocessableEntityError( - f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in " - f"dataset with '{dataset.id}'." - ) - - field = metadata_property - else: - valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES) - raise UnprocessableEntityError( - f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. 
It must be either" - f" {valid_sort_fields} or `metadata.metadata-property-name`" - ) - - if sort_order is not None and sort_order not in _VALID_SORT_VALUES: - raise UnprocessableEntityError( - f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.", - ) - - sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value)) - - return sorts_by - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -315,7 +251,6 @@ async def list_dataset_records( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -332,7 +267,6 @@ async def list_dataset_records( limit=limit, offset=offset, include=include, - sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) return Records(items=records, total=total) @@ -441,7 +375,6 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -468,7 +401,6 @@ async def search_current_user_dataset_records( limit=limit, offset=offset, user=current_user, - sort_by_query_param=sort_by_query_param, ) record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { @@ -511,7 +443,6 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -530,7 +461,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - sort_by_query_param=sort_by_query_param, ) record_id_score_map = { diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 687c51bad3..db5bc87e2a 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -282,9 +282,6 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: remove them and keep filter and order - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, ) -> SearchResponses: diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 0b7606c642..5b9d5e66bc 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -199,23 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once 
the `sort_by` argument is removed from search and similarity_search methods -def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: - if order: - return order - - new_order = [] - for sort in sort_by: - if isinstance(sort.field, MetadataProperty): - scope = MetadataFilterScope(metadata_property=sort.field.name) - else: - scope = RecordFilterScope(property=sort.field) - - new_order.append(Order(scope=scope, order=sort.order)) - - return new_order - - def is_response_status_scope(scope: FilterScope) -> bool: return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None @@ -327,14 +310,14 @@ async def update_record_response(self, response: Response): es_responses = self._map_record_responses_to_es([response]) - await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}}) + await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}}) async def delete_record_response(self, response: Response): record = response.record index_name = await self._get_dataset_index(record.dataset) await self._update_document_request( - index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)} + index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)} ) async def update_record_suggestion(self, suggestion: Suggestion): @@ -344,7 +327,7 @@ async def update_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"doc": {"suggestions": es_suggestions}}, ) @@ -353,7 +336,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, ) @@ -576,19 +559,11 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: Remove these arguments - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, user_id: Optional[str] = None, ) -> SearchResponses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html - - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if sort_by: - sort = _unify_sort_by_with_order(sort_by, sort) - # END TODO index = await self._get_dataset_index(dataset) text_query = self._build_text_query(dataset, text=query) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 5229f53cf9..5e3c6653de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -319,7 +319,6 @@ async def test_with_filter( offset=0, limit=50, query=None, - sort_by=None, ) async def test_with_sort( @@ -368,7 +367,6 @@ async def test_with_sort( offset=0, limit=50, query=None, - sort_by=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 
1d9ddcf22c..f84154b4c4 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -14,7 +14,7 @@ import math import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type from unittest.mock import ANY, MagicMock from uuid import UUID, uuid4 @@ -43,6 +43,7 @@ ResponseStatusFilter, SimilarityOrder, RecordStatus, + SortOrder, ) from argilla_server.models import ( Dataset, @@ -62,13 +63,14 @@ SearchEngine, SearchResponseItem, SearchResponses, - SortBy, TextQuery, AndFilter, TermsFilter, MetadataFilterScope, RangeFilter, ResponseFilterScope, + Order, + RecordFilterScope, ) from tests.factories import ( AdminFactory, @@ -3651,7 +3653,6 @@ async def test_search_current_user_dataset_records( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -3811,31 +3812,42 @@ async def test_search_current_user_dataset_records_with_metadata_filter( filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) @pytest.mark.parametrize( - "sorts", + "sort,expected_sort", [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"}], + [Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.asc)], + ), + ( + [ + {"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}, + {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "desc"}, + ], + [ + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), + 
Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.desc), + ], + ), ], ) async def test_search_current_user_dataset_records_with_sort_by( @@ -3844,16 +3856,15 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine: SearchEngine, owner: "User", owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], + sort: List[dict], + expected_sort: List[Order], ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) + for order in expected_sort: + if isinstance(order.scope, MetadataFilterScope): + await TermsMetadataPropertyFactory.create(name=order.scope.metadata_property, dataset=dataset) mock_search_engine.search.return_value = SearchResponses( total=2, @@ -3863,15 +3874,13 @@ async def test_search_current_user_dataset_records_with_sort_by( ], ) - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": sort, } - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} - response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=query_params, headers=owner_auth_header, json=query_json, ) @@ -3883,7 +3892,7 @@ async def test_search_current_user_dataset_records_with_sort_by( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, + sort=expected_sort, user_id=owner.id, ) @@ -3893,18 +3902,17 @@ async def test_search_current_user_dataset_records_with_sort_by_with_wrong_sort_ workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "record", "property": "wrong_property"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." 
- } async def test_search_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict @@ -3912,17 +3920,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_non_existen workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "metadata", "metadata_property": "missing"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "metadata.i-do-not-exist:asc"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." + "detail": f"MetadataProperty not found filtering by name=missing, dataset_id={dataset.id}" } async def test_search_current_user_dataset_records_with_sort_by_with_invalid_field( @@ -3931,19 +3941,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_invalid_fie workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [ + {"scope": {"entity": "wrong", "property": "wrong"}, "order": "asc"}, + ], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "not-valid"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } @pytest.mark.parametrize( "includes", @@ -4085,7 +4095,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - sort_by=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4319,7 +4328,6 @@ async def test_search_current_user_dataset_records_with_response_status_filter( ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -4544,7 +4552,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( query=TextQuery(q="Hello", field="input"), offset=0, limit=5, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index db2605c3de..4f989e5399 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -17,16 +17,9 @@ import pytest from httpx import AsyncClient -from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import RecordInclude, ResponseStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace -from argilla_server.search_engine import ( - SearchEngine, - SearchResponseItem, - SearchResponses, - SortBy, -) from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -35,7 +28,6 @@ RecordFactory, ResponseFactory, SuggestionFactory, - TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, VectorFactory, @@ -453,119 +445,6 @@ async def test_list_dataset_records_with_response_status_filter( ] ) - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await 
TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - ) - - async def test_list_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 5ae8241927..c893366b58 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,14 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus +from argilla_server.enums import ( + MetadataPropertyType, + QuestionType, + ResponseStatusFilter, + SimilarityOrder, + RecordStatus, + SortOrder, +) from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( ResponseFilterScope, @@ -28,6 +35,8 @@ Filter, MetadataFilterScope, RangeFilter, + Order, + RecordFilterScope, ) from argilla_server.search_engine.commons import ( ALL_RESPONSES_STATUSES_FIELD, @@ -820,12 +829,12 @@ async def test_search_with_pagination( assert all_results.items[offset : offset + limit] == results.items @pytest.mark.parametrize( - ("sort_by"), + ("sort_order"), [ - SortBy(field="inserted_at"), - SortBy(field="updated_at"), - SortBy(field="inserted_at", order="desc"), - SortBy(field="updated_at", order="desc"), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), ], ) async def test_search_with_sort_by( @@ -833,18 +842,15 @@ async def test_search_with_sort_by( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - sort_by: SortBy, + sort_order: Order, ): def _local_sort_by(record: Record) -> Any: - if isinstance(sort_by.field, str): - return getattr(record, sort_by.field) - return record.metadata_[sort_by.field.name] + return getattr(record, sort_order.scope.property) - results = await search_engine.search(test_banking_sentiment_dataset, sort_by=[sort_by]) + results = await search_engine.search(test_banking_sentiment_dataset, sort=[sort_order]) records = test_banking_sentiment_dataset.records - if sort_by: - records = sorted(records, key=_local_sort_by, reverse=sort_by.order == "desc") + records = sorted(records, key=_local_sort_by, reverse=sort_order.order == "desc") assert [item.record_id for item in results.items] == [record.id for record in records]