[REFACTOR] argilla-server: remove deprecated records endpoint #5206

Merged
9 changes: 7 additions & 2 deletions argilla-server/CHANGELOG.md
@@ -24,13 +24,18 @@ These are the section headers that we use:
### Changed

- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126))
- [breaking] Change `GET /datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
- [breaking] Change `GET /me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140))
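
The first entry under `### Changed` above describes a schema-level change: deleting a user now also removes that user's rows from the `responses` table at the database level. A hypothetical SQLAlchemy sketch of that kind of cascade follows; model and column names are assumptions, not the actual argilla-server definition or migration:

```python
# Hypothetical sketch of a DB-level cascade; names and columns are assumptions,
# not the actual argilla-server model or Alembic migration.
import uuid

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Response(Base):
    __tablename__ = "responses"

    id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
    # ondelete="CASCADE" lets the database delete a user's responses
    # automatically when the referenced user row is deleted.
    user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
```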

### Fixed

- Fixed SQLite connection settings not working correctly due to an outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149))

### Removed

- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206))
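
For clients still calling the removed endpoints, a minimal migration sketch follows. The HTTP methods and status codes (`POST …/records/bulk` → 201, `PUT …/records/bulk` → 200) come from the test changes in this PR; the base URL, auth header, and payload fields are assumptions added to keep the example self-contained:

```python
# Client-side migration sketch only, not argilla-server code.
# Base URL, API key header, and payload fields are assumptions; check them
# against the bulk endpoints' request/response schemas before relying on this.
import httpx

API = "http://localhost:6900/api/v1"
HEADERS = {"X-Argilla-Api-Key": "argilla.apikey"}  # placeholder credentials
DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset id

# Removed: POST /api/v1/datasets/:dataset_id/records (204)
# Use:     POST /api/v1/datasets/:dataset_id/records/bulk (201)
create = httpx.post(
    f"{API}/datasets/{DATASET_ID}/records/bulk",
    headers=HEADERS,
    json={"items": [{"fields": {"text": "example record"}}]},
)
assert create.status_code == 201

# Removed: PATCH /api/v1/datasets/:dataset_id/records (204)
# Use:     PUT /api/v1/datasets/:dataset_id/records/bulk (200)
update = httpx.put(
    f"{API}/datasets/{DATASET_ID}/records/bulk",
    headers=HEADERS,
    json={"items": [{"id": "00000000-0000-0000-0000-000000000001", "metadata": {"split": "train"}}]},
)
assert update.status_code == 200
```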

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

### Changed
@@ -34,8 +34,6 @@
RecordFilterScope,
RecordIncludeParam,
Records,
RecordsCreate,
RecordsUpdate,
SearchRecord,
SearchRecordsQuery,
SearchRecordsResult,
@@ -424,71 +422,6 @@ async def list_dataset_records(
return Records(items=records, total=total)


@router.post(
"/datasets/{dataset_id}/records",
status_code=status.HTTP_204_NO_CONTENT,
deprecated=True,
description="Deprecated in favor of POST /datasets/{dataset_id}/records/bulk",
)
async def create_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
records_create: RecordsCreate,
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(
db,
dataset_id,
options=[
selectinload(Dataset.fields),
selectinload(Dataset.questions),
selectinload(Dataset.metadata_properties),
selectinload(Dataset.vectors_settings),
],
)

await authorize(current_user, DatasetPolicy.create_records(dataset))

await datasets.create_records(db, search_engine, dataset, records_create)

telemetry_client.track_data(action="DatasetRecordsCreated", data={"records": len(records_create.items)})


@router.patch(
"/datasets/{dataset_id}/records",
status_code=status.HTTP_204_NO_CONTENT,
deprecated=True,
description="Deprecated in favor of PUT /datasets/{dataset_id}/records/bulk",
)
async def update_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
records_update: RecordsUpdate,
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(
db,
dataset_id,
options=[
selectinload(Dataset.fields),
selectinload(Dataset.questions),
selectinload(Dataset.metadata_properties),
],
)

await authorize(current_user, DatasetPolicy.update_records(dataset))

await datasets.update_records(db, search_engine, dataset, records_update)

telemetry_client.track_data(action="DatasetRecordsUpdated", data={"records": len(records_update.items)})


@router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
async def delete_dataset_records(
*,
155 changes: 7 additions & 148 deletions argilla-server/src/argilla_server/contexts/datasets.py
@@ -33,7 +33,7 @@

import sqlalchemy
from fastapi.encoders import jsonable_encoder
from sqlalchemy import Select, and_, case, func, select
from sqlalchemy import Select, and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, selectinload

@@ -42,8 +42,6 @@
from argilla_server.api.schemas.v1.records import (
RecordCreate,
RecordIncludeParam,
RecordsCreate,
RecordsUpdate,
RecordUpdateWithId,
)
from argilla_server.api.schemas.v1.responses import (
@@ -60,7 +58,7 @@
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.contexts import accounts, distribution
from argilla_server.enums import DatasetStatus, RecordInclude, UserRole, RecordStatus
from argilla_server.enums import DatasetStatus, UserRole, RecordStatus
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Dataset,
@@ -74,7 +72,6 @@
User,
Vector,
VectorSettings,
Workspace,
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
@@ -87,9 +84,6 @@
from argilla_server.validators.suggestions import SuggestionCreateValidator

if TYPE_CHECKING:
from argilla_server.api.schemas.v1.datasets import (
DatasetUpdate,
)
from argilla_server.api.schemas.v1.fields import FieldUpdate
from argilla_server.api.schemas.v1.records import RecordUpdate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
@@ -231,7 +225,8 @@ async def create_metadata_property(
) -> MetadataProperty:
if await MetadataProperty.get_by(db, name=metadata_property_create.name, dataset_id=dataset.id):
raise NotUniqueError(
f"Metadata property with name `{metadata_property_create.name}` already exists for dataset with id `{dataset.id}`"
f"Metadata property with name `{metadata_property_create.name}` already exists "
f"for dataset with id `{dataset.id}`"
)

async with db.begin_nested():
@@ -292,7 +287,8 @@ async def create_vector_settings(

if await VectorSettings.get_by(db, name=vector_settings_create.name, dataset_id=dataset.id):
raise NotUniqueError(
f"Vector settings with name `{vector_settings_create.name}` already exists for dataset with id `{dataset.id}`"
f"Vector settings with name `{vector_settings_create.name}` already exists "
f"for dataset with id `{dataset.id}`"
)

async with db.begin_nested():
@@ -403,7 +399,7 @@ async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id:
.filter(
Record.dataset_id == dataset_id,
Record.status == RecordStatus.pending,
Response.id == None,
Response.id == None, # noqa
),
),
)
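
A note on the `# noqa` added above: SQLAlchemy overloads `==` on mapped columns to build SQL expressions, so `Response.id == None` is what renders an `IS NULL` clause, while a Pythonic `is None` would be evaluated eagerly and always be false; the comment only silences flake8's E711 warning. A tiny illustrative sketch (assuming `argilla_server` is importable; `.is_(None)` is a lint-clean equivalent, not what this PR uses):

```python
# Illustrative only; the import path mirrors this diff, the prints are made up.
from argilla_server.models import Response

print(Response.id == None)    # noqa: E711 -> renders as "responses.id IS NULL"
print(Response.id.is_(None))  # lint-clean equivalent producing the same SQL
```
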
@@ -549,57 +545,6 @@ async def _build_record(
)


async def create_records(
db: AsyncSession, search_engine: SearchEngine, dataset: Dataset, records_create: RecordsCreate
):
if not dataset.is_ready:
raise UnprocessableEntityError("Records cannot be created for a non published dataset")

records = []

caches = {
"users_ids_cache": set(),
"questions_cache": {},
"metadata_properties_cache": {},
"vectors_settings_cache": {},
}

for record_i, record_create in enumerate(records_create.items):
try:
record = await _build_record(db, dataset, record_create, caches)

record.responses = await _build_record_responses(
db, record, record_create.responses, caches["users_ids_cache"]
)

record.suggestions = await _build_record_suggestions(
db, record, record_create.suggestions, caches["questions_cache"]
)

record.vectors = await _build_record_vectors(
db,
dataset,
record_create.vectors,
build_vector_func=lambda value, vector_settings_id: Vector(
value=value, vector_settings_id=vector_settings_id
),
cache=caches["vectors_settings_cache"],
)

except (UnprocessableEntityError, ValueError) as e:
raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e

records.append(record)

async with db.begin_nested():
db.add_all(records)
await db.flush(records)
await _preload_records_relationships_before_index(db, records)
await search_engine.index_records(dataset, records)

await db.commit()


async def _load_users_from_responses(responses: Union[Response, Iterable[Response]]) -> None:
if isinstance(responses, Response):
responses = [responses]
@@ -808,92 +753,6 @@ async def preload_records_relationships_before_validate(db: AsyncSession, record
)


async def update_records(
db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_update: "RecordsUpdate"
) -> None:
records_ids = [record_update.id for record_update in records_update.items]

if len(records_ids) != len(set(records_ids)):
raise UnprocessableEntityError("Found duplicate records IDs")

existing_records_ids = await _exists_records_with_ids(db, dataset_id=dataset.id, records_ids=records_ids)
non_existing_records_ids = set(records_ids) - set(existing_records_ids)

if len(non_existing_records_ids) > 0:
sorted_non_existing_records_ids = sorted(non_existing_records_ids, key=lambda x: records_ids.index(x))
records_str = ", ".join([str(record_id) for record_id in sorted_non_existing_records_ids])
raise UnprocessableEntityError(f"Found records that do not exist: {records_str}")

# Lists to store the records that will be updated in the database or in the search engine
records_update_objects: List[Dict[str, Any]] = []
records_search_engine_update: List[UUID] = []
records_delete_suggestions: List[UUID] = []

# Cache dictionaries to avoid querying the database multiple times
caches = {
"metadata_properties": {},
"questions": {},
"vector_settings": {},
}

existing_records = await get_records_by_ids(db, records_ids=records_ids, dataset_id=dataset.id)

suggestions = []
upsert_vectors = []
for record_i, (record_update, record) in enumerate(zip(records_update.items, existing_records)):
try:
params, record_suggestions, record_vectors, needs_search_engine_update, caches = await _build_record_update(
db, record, record_update, caches
)

if record_suggestions is not None:
suggestions.extend(record_suggestions)
records_delete_suggestions.append(record_update.id)

upsert_vectors.extend(record_vectors)

if needs_search_engine_update:
records_search_engine_update.append(record_update.id)

# Only update the record if there are params to update
if len(params) > 1:
records_update_objects.append(params)
except (UnprocessableEntityError, ValueError) as e:
raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e

async with db.begin_nested():
if records_delete_suggestions:
params = [Suggestion.record_id.in_(records_delete_suggestions)]
await Suggestion.delete_many(db, params=params, autocommit=False)

if suggestions:
db.add_all(suggestions)

if upsert_vectors:
await Vector.upsert_many(
db,
objects=upsert_vectors,
constraints=[Vector.record_id, Vector.vector_settings_id],
autocommit=False,
)

if records_update_objects:
await Record.update_many(db, records_update_objects, autocommit=False)

if records_search_engine_update:
records = await get_records_by_ids(
db,
dataset_id=dataset.id,
records_ids=records_search_engine_update,
include=RecordIncludeParam(keys=[RecordInclude.vectors], vectors=None),
)
await dataset.awaitable_attrs.vectors_settings
await _preload_records_relationships_before_index(db, records)
await search_engine.index_records(dataset, records)

await db.commit()


async def delete_records(
db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_ids: List[UUID]
) -> None:
@@ -34,9 +34,9 @@


@pytest.mark.asyncio
class TestCreateDatasetRecords:
class TestCreateDatasetRecordsInBulk:
def url(self, dataset_id: UUID) -> str:
return f"/api/v1/datasets/{dataset_id}/records"
return f"/api/v1/datasets/{dataset_id}/records/bulk"

async def test_create_dataset_records(
self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict
@@ -209,7 +209,7 @@ async def test_create_dataset_records(
},
)

assert response.status_code == 204
assert response.status_code == 201

assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1
assert (await db.execute(select(func.count(Response.id)))).scalar_one() == 1
@@ -35,9 +35,9 @@


@pytest.mark.asyncio
class TestUpdateDatasetRecords:
class TestUpdateDatasetRecordsInBulk:
def url(self, dataset_id: UUID) -> str:
return f"/api/v1/datasets/{dataset_id}/records"
return f"/api/v1/datasets/{dataset_id}/records/bulk"

async def test_update_dataset_records(
self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict
@@ -121,7 +121,7 @@ async def test_update_dataset_records(
dataset=dataset,
)

response = await async_client.patch(
response = await async_client.put(
self.url(dataset.id),
headers=owner_auth_header,
json={
@@ -180,7 +180,7 @@ async def test_update_dataset_records(
},
)

assert response.status_code == 204
assert response.status_code == 200

assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1
assert (await db.execute(select(func.count(Suggestion.id)))).scalar_one() == 6