Skip to content

Commit

Permalink
✨ Add endpoints to manage quality check requests
Browse files Browse the repository at this point in the history
  • Loading branch information
agmangas committed May 6, 2024
1 parent fc70e18 commit f65def7
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 166 deletions.
39 changes: 16 additions & 23 deletions moderate_api/entities/asset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from typing import Dict, List, Optional, Union

from pydantic import validator
from sqlalchemy import Column, Index, Text, cast, or_, select
from sqlalchemy import Column, Index, Text, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Field, Relationship, SQLModel

from moderate_api.db import AsyncSessionDep
from moderate_api.entities.crud import find_by_json_key, update_json_key
from moderate_api.object_storage import S3ClientDep


Expand Down Expand Up @@ -133,30 +133,23 @@ async def get_s3object_size_mib(s3_object: UploadedS3Object, s3: S3ClientDep) ->
async def find_s3object_pending_quality_check(
session: AsyncSessionDep,
) -> List[UploadedS3Object]:
stmt = select(UploadedS3Object).filter(
UploadedS3Object.meta[S3ObjectWellKnownMetaKeys.PENDING_QUALITY_CHECK.value]
== cast(True, JSONB)
return await find_by_json_key(
sql_model=UploadedS3Object,
session=session,
json_column="meta",
json_key=S3ObjectWellKnownMetaKeys.PENDING_QUALITY_CHECK.value,
json_value=True,
)

result = await session.execute(stmt)
s3objects: List[UploadedS3Object] = result.scalars().all()

return s3objects


async def update_s3object_quality_check_flag(
ids: List[int], session: AsyncSessionDep, value: bool
):
ids = ids if isinstance(ids, list) else [ids]
stmt = select(UploadedS3Object).where(UploadedS3Object.id.in_(ids))
result = await session.execute(stmt)
s3objects = result.scalars().all()

for s3object in s3objects:
meta = s3object.meta or {}
meta.update({S3ObjectWellKnownMetaKeys.PENDING_QUALITY_CHECK.value: value})
s3object.meta = meta
flag_modified(s3object, "meta")
session.add(s3object)

await session.commit()
return await update_json_key(
sql_model=UploadedS3Object,
session=session,
primary_keys=ids,
json_column="meta",
json_key=S3ObjectWellKnownMetaKeys.PENDING_QUALITY_CHECK.value,
json_value=value,
)
50 changes: 50 additions & 0 deletions moderate_api/entities/asset/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
AssetUpdate,
UploadedS3Object,
find_s3object_by_key_or_id,
find_s3object_pending_quality_check,
update_s3object_quality_check_flag,
)
from moderate_api.entities.crud import (
CrudFiltersQuery,
Expand Down Expand Up @@ -111,6 +113,54 @@ async def get_asset_presigned_urls(
return ret


class ObjectPendingQuality(BaseModel):
key: str
asset_id: int
id: int


@router.get(
"/object/quality-check", response_model=List[ObjectPendingQuality], tags=[_TAG]
)
async def get_asset_objects_pending_quality(
*,
user: UserDep,
session: AsyncSessionDep,
):
"""Retrieves the list of asset objects that are pending quality check."""

if not user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

s3objs = await find_s3object_pending_quality_check(session=session)

return [ObjectPendingQuality(**full_object.model_dump()) for full_object in s3objs]


class AssetObjectFlagQualityRequest(BaseModel):
asset_object_id: Union[List[int], int]
pending_quality_check: bool


@router.post("/object/quality-check", response_model=List[int], tags=[_TAG])
async def flag_asset_objects_quality_check(
*,
user: UserDep,
session: AsyncSessionDep,
body: AssetObjectFlagQualityRequest,
):
"""Update the quality check flag for a list of asset objects."""

if not user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

await update_s3object_quality_check_flag(
ids=body.asset_object_id, session=session, value=body.pending_quality_check
)

return body.asset_object_id


async def _download_asset(
*,
user: OptionalUserDep,
Expand Down
40 changes: 40 additions & 0 deletions moderate_api/entities/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from fastapi import HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import Text, asc, cast, desc, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import BinaryExpression, UnaryExpression
from sqlmodel import SQLModel, select
from sqlalchemy.orm.attributes import flag_modified

from moderate_api.authz import User
from moderate_api.enums import Actions, Entities
Expand Down Expand Up @@ -431,6 +433,44 @@ async def delete_one(
return {"ok": True, "id": entity_id}


async def find_by_json_key(
sql_model: Type[SQLModel],
session: AsyncSession,
json_column: str,
json_key: str,
json_value: Any,
) -> List[SQLModel]:
stmt = select(sql_model).filter(
getattr(sql_model, json_column)[json_key] == cast(json_value, JSONB)
)

result = await session.execute(stmt)
return result.scalars().all()


async def update_json_key(
sql_model: Type[SQLModel],
session: AsyncSession,
primary_keys: Union[List[int], int],
json_column: str,
json_key: str,
json_value: Any,
):
ids = primary_keys if isinstance(primary_keys, list) else [primary_keys]
stmt = select(sql_model).where(sql_model.id.in_(ids))
result = await session.execute(stmt)
rows = result.scalars().all()

for row in rows:
jsonobj = getattr(row, json_column) or {}
jsonobj.update({json_key: json_value})
setattr(row, json_column, jsonobj)
flag_modified(row, json_column)
session.add(row)

await session.commit()


_example_crud_filters = json.dumps(
[
["the_date", "lte", arrow.utcnow().naive.isoformat()],
Expand Down
83 changes: 81 additions & 2 deletions tests/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
import httpx
import pytest
from fastapi.testclient import TestClient
from sqlmodel import select

from moderate_api.entities.asset.models import AssetCreate
from moderate_api.db import with_session
from moderate_api.entities.asset.models import (
AssetCreate,
UploadedS3Object,
find_s3object_pending_quality_check,
update_s3object_quality_check_flag,
)
from moderate_api.main import app
from tests.utils import create_asset
from tests.utils import create_asset, upload_test_files

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,3 +47,75 @@ async def test_auto_uuid(access_token):
_logger.debug("Response:\n%s", pprint.pformat(resp_json))
assert response.raise_for_status()
assert resp_json["uuid"]


@pytest.mark.asyncio
async def test_asset_object_quality_check(access_token):
asset_id = upload_test_files(access_token, num_files=4)

async with with_session() as session:
stmt = select(UploadedS3Object).where(UploadedS3Object.asset_id == asset_id)
result = await session.execute(stmt)
s3objects = result.scalars().all()
s3obj_ids = [obj.id for obj in s3objects]

pending = await find_s3object_pending_quality_check(session=session)
assert not pending or len(pending) == 0

await update_s3object_quality_check_flag(
session=session, ids=s3obj_ids[:-1], value=True
)

pending = await find_s3object_pending_quality_check(session=session)
assert len(pending) == (len(s3objects) - 1)

await update_s3object_quality_check_flag(
session=session, ids=s3obj_ids[0], value=False
)

pending = await find_s3object_pending_quality_check(session=session)
assert len(pending) == (len(s3objects) - 2)


@pytest.mark.parametrize(
"access_token",
[{"is_admin": True}],
indirect=True,
)
@pytest.mark.asyncio
async def test_asset_object_quality_check_endpoints(access_token):
asset_id = upload_test_files(access_token, num_files=4)

async with with_session() as session:
stmt = select(UploadedS3Object).where(UploadedS3Object.asset_id == asset_id)
result = await session.execute(stmt)
s3objects = result.scalars().all()
s3obj_ids = [obj.id for obj in s3objects]

headers = {"Authorization": f"Bearer {access_token}"}

with TestClient(app) as client:
resp_get_before = client.get("/asset/object/quality-check", headers=headers)
assert resp_get_before.raise_for_status()
resp_json = resp_get_before.json()
_logger.info("Response:\n%s", pprint.pformat(resp_json))
assert len(resp_json) == 0

num_flagged = 2

resp_post = client.post(
"/asset/object/quality-check",
headers=headers,
json={
"asset_object_id": s3obj_ids[:num_flagged],
"pending_quality_check": True,
},
)

assert resp_post.raise_for_status()

resp_get_after = client.get("/asset/object/quality-check", headers=headers)
assert resp_get_after.raise_for_status()
resp_json = resp_get_after.json()
_logger.info("Response:\n%s", pprint.pformat(resp_json))
assert len(resp_json) == num_flagged
Loading

0 comments on commit f65def7

Please sign in to comment.