Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING - REFACTOR] argilla-server: remove sort_by query param #5166

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ These are the section headers that we use:
### Removed

- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156))
- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from fastapi import APIRouter, Depends, Query, Security, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from typing_extensions import Annotated

import argilla_server.search_engine as search_engine
from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized
Expand Down Expand Up @@ -52,12 +51,11 @@
from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
SearchEngine,
SearchResponses,
SortBy,
UserResponseStatusFilter,
get_search_engine,
)
Expand All @@ -70,25 +68,6 @@
LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"}
DELETE_DATASET_RECORDS_LIMIT = 100

_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField)
_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder)
_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P<name>(?=.*[a-z0-9])[a-z0-9_-]+)$")

SortByQueryParamParsed = Annotated[
Dict[str, str],
Depends(
parse_query_param(
name="sort_by",
description=(
"The field used to sort the records. Expected format is `field` or `field:{asc,desc}`, where `field`"
" can be 'inserted_at', 'updated_at' or the name of a metadata property"
),
max_values_per_key=1,
group_keys_without_values=False,
)
),
]

parse_record_include_param = parse_query_param(
name="include", help="Relationships to include in the response", model=RecordIncludeParam
)
Expand All @@ -104,7 +83,6 @@ async def _filter_records_using_search_engine(
offset: int,
user: Optional[User] = None,
include: Optional[RecordIncludeParam] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> Tuple[List[Record], int]:
search_responses = await _get_search_responses(
db=db,
Expand All @@ -113,7 +91,6 @@ async def _filter_records_using_search_engine(
limit=limit,
offset=offset,
user=user,
sort_by_query_param=sort_by_query_param,
)

record_ids = [response.record_id for response in search_responses.items]
Expand Down Expand Up @@ -176,7 +153,6 @@ async def _get_search_responses(
offset: int,
search_records_query: Optional[SearchRecordsQuery] = None,
user: Optional[User] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> "SearchResponses":
search_records_query = search_records_query or SearchRecordsQuery()

Expand Down Expand Up @@ -216,8 +192,6 @@ async def _get_search_responses(
if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id):
raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.")

sort_by = await _build_sort_by(db, dataset, sort_by_query_param)

if vector_query and vector_settings:
similarity_search_params = {
"dataset": dataset,
Expand All @@ -239,7 +213,6 @@ async def _get_search_responses(
"query": text_query,
"offset": offset,
"limit": limit,
"sort_by": sort_by,
}

if user is not None:
Expand All @@ -265,43 +238,6 @@ async def _build_response_status_filter_for_search(
return user_response_status_filter


async def _build_sort_by(
db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None
) -> Union[List[SortBy], None]:
if sort_by_query_param is None:
return None

sorts_by = []
for sort_field, sort_order in sort_by_query_param.items():
if sort_field in _RECORD_SORT_FIELD_VALUES:
field = sort_field
elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None:
metadata_property_name = match.group("name")
metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id)
if not metadata_property:
raise UnprocessableEntityError(
f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in "
f"dataset with '{dataset.id}'."
)

field = metadata_property
else:
valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES)
raise UnprocessableEntityError(
f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. It must be either"
f" {valid_sort_fields} or `metadata.metadata-property-name`"
)

if sort_order is not None and sort_order not in _VALID_SORT_VALUES:
raise UnprocessableEntityError(
f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.",
)

sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value))

return sorts_by


async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID):
try:
await search.validate_search_records_query(db, query, dataset_id)
Expand All @@ -315,7 +251,6 @@ async def list_dataset_records(
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -332,7 +267,6 @@ async def list_dataset_records(
limit=limit,
offset=offset,
include=include,
sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY,
)

return Records(items=records, total=total)
Expand Down Expand Up @@ -441,7 +375,6 @@ async def search_current_user_dataset_records(
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -468,7 +401,6 @@ async def search_current_user_dataset_records(
limit=limit,
offset=offset,
user=current_user,
sort_by_query_param=sort_by_query_param,
)

record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = {
Expand Down Expand Up @@ -511,7 +443,6 @@ async def search_dataset_records(
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
Expand All @@ -530,7 +461,6 @@ async def search_dataset_records(
search_records_query=body,
limit=limit,
offset=offset,
sort_by_query_param=sort_by_query_param,
)

record_id_score_map = {
Expand Down
3 changes: 0 additions & 3 deletions argilla-server/src/argilla_server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ async def search(
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: remove them and keep filter and order
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
limit: int = 100,
) -> SearchResponses:
Expand Down
33 changes: 4 additions & 29 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str:
return str(vector_settings.id)


# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods
def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]:
if order:
return order

new_order = []
for sort in sort_by:
if isinstance(sort.field, MetadataProperty):
scope = MetadataFilterScope(metadata_property=sort.field.name)
else:
scope = RecordFilterScope(property=sort.field)

new_order.append(Order(scope=scope, order=sort.order))

return new_order


def is_response_status_scope(scope: FilterScope) -> bool:
return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None

Expand Down Expand Up @@ -327,14 +310,14 @@ async def update_record_response(self, response: Response):

es_responses = self._map_record_responses_to_es([response])

await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}})
await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}})

async def delete_record_response(self, response: Response):
record = response.record
index_name = await self._get_dataset_index(record.dataset)

await self._update_document_request(
index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)}
index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)}
)

async def update_record_suggestion(self, suggestion: Suggestion):
Expand All @@ -344,7 +327,7 @@ async def update_record_suggestion(self, suggestion: Suggestion):

await self._update_document_request(
index_name,
id=suggestion.record_id,
id=str(suggestion.record_id),
body={"doc": {"suggestions": es_suggestions}},
)

Expand All @@ -353,7 +336,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion):

await self._update_document_request(
index_name,
id=suggestion.record_id,
id=str(suggestion.record_id),
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
)

Expand Down Expand Up @@ -576,19 +559,11 @@ async def search(
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: Remove these arguments
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
limit: int = 100,
user_id: Optional[str] = None,
) -> SearchResponses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

# TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept
if sort_by:
sort = _unify_sort_by_with_order(sort_by, sort)
# END TODO
index = await self._get_dataset_index(dataset)

text_query = self._build_text_query(dataset, text=query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ async def test_with_filter(
offset=0,
limit=50,
query=None,
sort_by=None,
)

async def test_with_sort(
Expand Down Expand Up @@ -368,7 +367,6 @@ async def test_with_sort(
offset=0,
limit=50,
query=None,
sort_by=None,
)

async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict):
Expand Down
Loading
Loading