Skip to content

Commit

Permalink
[BREAKING - REFACTOR] argilla-server: remove sort_by query param (#…
Browse files Browse the repository at this point in the history
…5166)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR removes support for the `sort_by` query param from the list/search
records endpoints.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- Refactor (change restructuring the codebase without changing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 8, 2024
1 parent 0404465 commit ba417dc
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 293 deletions.
5 changes: 3 additions & 2 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ These are the section headers that we use:
### Removed

- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from fastapi import APIRouter, Depends, Query, Security, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from typing_extensions import Annotated

import argilla_server.search_engine as search_engine
from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized
Expand Down Expand Up @@ -52,12 +51,11 @@
from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
SearchEngine,
SearchResponses,
SortBy,
UserResponseStatusFilter,
get_search_engine,
)
Expand All @@ -70,25 +68,6 @@
LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"}
DELETE_DATASET_RECORDS_LIMIT = 100

_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField)
_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder)
_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P<name>(?=.*[a-z0-9])[a-z0-9_-]+)$")

SortByQueryParamParsed = Annotated[
Dict[str, str],
Depends(
parse_query_param(
name="sort_by",
description=(
"The field used to sort the records. Expected format is `field` or `field:{asc,desc}`, where `field`"
" can be 'inserted_at', 'updated_at' or the name of a metadata property"
),
max_values_per_key=1,
group_keys_without_values=False,
)
),
]

parse_record_include_param = parse_query_param(
name="include", help="Relationships to include in the response", model=RecordIncludeParam
)
Expand All @@ -104,7 +83,6 @@ async def _filter_records_using_search_engine(
offset: int,
user: Optional[User] = None,
include: Optional[RecordIncludeParam] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> Tuple[List[Record], int]:
search_responses = await _get_search_responses(
db=db,
Expand All @@ -113,7 +91,6 @@ async def _filter_records_using_search_engine(
limit=limit,
offset=offset,
user=user,
sort_by_query_param=sort_by_query_param,
)

record_ids = [response.record_id for response in search_responses.items]
Expand Down Expand Up @@ -176,7 +153,6 @@ async def _get_search_responses(
offset: int,
search_records_query: Optional[SearchRecordsQuery] = None,
user: Optional[User] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> "SearchResponses":
search_records_query = search_records_query or SearchRecordsQuery()

Expand Down Expand Up @@ -216,8 +192,6 @@ async def _get_search_responses(
if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id):
raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.")

sort_by = await _build_sort_by(db, dataset, sort_by_query_param)

if vector_query and vector_settings:
similarity_search_params = {
"dataset": dataset,
Expand All @@ -239,7 +213,6 @@ async def _get_search_responses(
"query": text_query,
"offset": offset,
"limit": limit,
"sort_by": sort_by,
}

if user is not None:
Expand All @@ -265,43 +238,6 @@ async def _build_response_status_filter_for_search(
return user_response_status_filter


async def _build_sort_by(
    db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None
) -> Union[List[SortBy], None]:
    """Convert the parsed `sort_by` query param into a list of `SortBy` items.

    Every key of `sort_by_query_param` must either be one of
    `_RECORD_SORT_FIELD_VALUES` or match `_METADATA_PROPERTY_SORT_BY_REGEX`
    (`metadata.<name>`), in which case the metadata property is looked up by
    name for the given dataset. Every value, when present, must be one of
    `_VALID_SORT_VALUES`; a falsy value falls back to `SortOrder.asc.value`.

    Args:
        db: Session used to look up metadata properties.
        dataset: Dataset whose id scopes the metadata property lookup.
        sort_by_query_param: Mapping of sort field to sort order, or None.

    Returns:
        None when `sort_by_query_param` is None, otherwise one `SortBy` per
        entry, in the mapping's iteration order.

    Raises:
        UnprocessableEntityError: If a sort field is not recognised, a
            referenced metadata property is not found, or a sort order is
            not valid.
    """
    if sort_by_query_param is None:
        return None

    built = []
    for requested_field, requested_order in sort_by_query_param.items():
        if requested_field in _RECORD_SORT_FIELD_VALUES:
            # Plain record field: passed through as its string name.
            target = requested_field
        else:
            matched = _METADATA_PROPERTY_SORT_BY_REGEX.match(requested_field)
            if matched is None:
                allowed = ", ".join(f"'{value}'" for value in _RECORD_SORT_FIELD_VALUES)
                raise UnprocessableEntityError(
                    f"Provided sort field in 'sort_by' query param '{requested_field}' is not valid. It must be either"
                    f" {allowed} or `metadata.metadata-property-name`"
                )

            property_name = matched.group("name")
            target = await MetadataProperty.get_by(db, name=property_name, dataset_id=dataset.id)
            if not target:
                raise UnprocessableEntityError(
                    f"Provided metadata property in 'sort_by' query param '{property_name}' not found in "
                    f"dataset with '{dataset.id}'."
                )

        # The order is validated only after the field itself has been resolved.
        order_is_invalid = requested_order is not None and requested_order not in _VALID_SORT_VALUES
        if order_is_invalid:
            raise UnprocessableEntityError(
                f"Provided sort order in 'sort_by' query param '{requested_order}' for field '{requested_field}' is not valid.",
            )

        built.append(SortBy(field=target, order=requested_order or SortOrder.asc.value))

    return built


async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID):
try:
await search.validate_search_records_query(db, query, dataset_id)
Expand All @@ -315,7 +251,6 @@ async def list_dataset_records(
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -332,7 +267,6 @@ async def list_dataset_records(
limit=limit,
offset=offset,
include=include,
sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY,
)

return Records(items=records, total=total)
Expand Down Expand Up @@ -441,7 +375,6 @@ async def search_current_user_dataset_records(
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -468,7 +401,6 @@ async def search_current_user_dataset_records(
limit=limit,
offset=offset,
user=current_user,
sort_by_query_param=sort_by_query_param,
)

record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = {
Expand Down Expand Up @@ -511,7 +443,6 @@ async def search_dataset_records(
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -530,7 +461,6 @@ async def search_dataset_records(
search_records_query=body,
limit=limit,
offset=offset,
sort_by_query_param=sort_by_query_param,
)

record_id_score_map = {
Expand Down
3 changes: 0 additions & 3 deletions argilla-server/src/argilla_server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ async def search(
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: remove them and keep filter and order
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
limit: int = 100,
) -> SearchResponses:
Expand Down
33 changes: 4 additions & 29 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str:
return str(vector_settings.id)


# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods
def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]:
    """Return `order` if it is non-empty, otherwise derive orders from `sort_by`.

    Each `SortBy` item becomes an `Order` with the same `order` value. Its
    scope is a `MetadataFilterScope` (keyed by the property's name) when the
    item's field is a `MetadataProperty`, and a `RecordFilterScope` otherwise.

    Args:
        sort_by: Sort items to convert; only read when `order` is falsy.
        order: Already-built orders; returned unchanged when truthy.

    Returns:
        `order` itself when truthy, else a new list with one `Order` per
        `sort_by` item.
    """
    if order:
        return order

    def _scope_for(field):
        # Metadata properties are addressed by name; anything else is a record property.
        if isinstance(field, MetadataProperty):
            return MetadataFilterScope(metadata_property=field.name)
        return RecordFilterScope(property=field)

    return [Order(scope=_scope_for(item.field), order=item.order) for item in sort_by]


def is_response_status_scope(scope: FilterScope) -> bool:
    """Tell whether `scope` targets the response `status` property with no question set."""
    if not isinstance(scope, ResponseFilterScope):
        return False

    return scope.property == "status" and scope.question is None

Expand Down Expand Up @@ -327,14 +310,14 @@ async def update_record_response(self, response: Response):

es_responses = self._map_record_responses_to_es([response])

await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}})
await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}})

async def delete_record_response(self, response: Response):
record = response.record
index_name = await self._get_dataset_index(record.dataset)

await self._update_document_request(
index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)}
index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)}
)

async def update_record_suggestion(self, suggestion: Suggestion):
Expand All @@ -344,7 +327,7 @@ async def update_record_suggestion(self, suggestion: Suggestion):

await self._update_document_request(
index_name,
id=suggestion.record_id,
id=str(suggestion.record_id),
body={"doc": {"suggestions": es_suggestions}},
)

Expand All @@ -353,7 +336,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion):

await self._update_document_request(
index_name,
id=suggestion.record_id,
id=str(suggestion.record_id),
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
)

Expand Down Expand Up @@ -576,19 +559,11 @@ async def search(
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: Remove these arguments
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
limit: int = 100,
user_id: Optional[str] = None,
) -> SearchResponses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

# TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept
if sort_by:
sort = _unify_sort_by_with_order(sort_by, sort)
# END TODO
index = await self._get_dataset_index(dataset)

text_query = self._build_text_query(dataset, text=query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ async def test_with_filter(
offset=0,
limit=50,
query=None,
sort_by=None,
)

async def test_with_sort(
Expand Down Expand Up @@ -368,7 +367,6 @@ async def test_with_sort(
offset=0,
limit=50,
query=None,
sort_by=None,
)

async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict):
Expand Down
Loading

0 comments on commit ba417dc

Please sign in to comment.