From 05a3431e6a6e9f84c7dc16a5dbc92f49cd8d55a5 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Thu, 12 Dec 2024 09:07:21 +0100 Subject: [PATCH] Adapt sources router --- nerdd_backend/routers/sources.py | 100 ++++++++++++++++--------------- 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/nerdd_backend/routers/sources.py b/nerdd_backend/routers/sources.py index d81cbe9..41c6200 100644 --- a/nerdd_backend/routers/sources.py +++ b/nerdd_backend/routers/sources.py @@ -10,55 +10,16 @@ from fastapi.encoders import jsonable_encoder from ..data import RecordNotFoundError, Repository -from ..models import Source +from ..models import Source, SourcePublic -sources_router = APIRouter(prefix="/sources") - - -async def put_multiple_sources( - inputs: List[str], - sources: List[str], - files: List[UploadFile], - request: Request, -): - app = request.app - repository: Repository = app.state.repository - - all_sources = [] - - # create source from inputs list - if len(inputs) > 0: - - async def _put_input(input: str): - file_stream = BytesIO(input.encode("utf-8")) - file = UploadFile(file_stream) - return await put_source(request, file=file) +__all__ = ["sources_router", "put_multiple_sources"] - sources_from_inputs = await asyncio.gather(*[_put_input(input) for input in inputs]) - all_sources += sources_from_inputs - - for source_id in sources: - source = await repository.get_source_by_id(source_id) - if source is not None: - all_sources.append(source) - - # create one json file referencing all sources - sources_from_files = await asyncio.gather(*[put_source(request, file=file) for file in files]) - all_sources += sources_from_files - - all_sources_objects = [source.model_dump() for source in all_sources] - - # create a merged file with all sources - file_stream = BytesIO(json.dumps(jsonable_encoder(all_sources_objects)).encode("utf-8")) - file = UploadFile(file_stream, filename="input.json") - result_source = await put_source(request, file=file, format="json") - - return result_source +sources_router = APIRouter(prefix="/sources") @sources_router.put("", include_in_schema=False) @sources_router.put("/") -async def put_source(request: Request, file: UploadFile, format: Optional[str] = None): +async def put_source(file: UploadFile, format: Optional[str] = None, request: Request = None): app = request.app repository: Repository = app.state.repository media_root = app.state.config.media_root @@ -67,6 +28,7 @@ async def put_source(request: Request, file: UploadFile, format: Optional[str] = uuid = uuid4() # create path to new file + # TODO: use FileSystem path = os.path.join(media_root, "sources", str(uuid)) os.makedirs(os.path.dirname(path), exist_ok=True) @@ -83,11 +45,11 @@ async def put_source(request: Request, file: UploadFile, format: Optional[str] = ) await repository.upsert_source(source) - return source + return SourcePublic(**source.model_dump()) @sources_router.get("/{uuid}") -async def get_source(request: Request, uuid: str): +async def get_source(uuid: str, request: Request): app = request.app repository: Repository = app.state.repository try: @@ -95,11 +57,11 @@ async def get_source(request: Request, uuid: str): except RecordNotFoundError as e: raise HTTPException(status_code=404, detail="Source not found") from e - return source + return SourcePublic(**source.model_dump()) @sources_router.delete("/{uuid}") -async def delete_source(request: Request, uuid: str): +async def delete_source(uuid: str, request: Request): app = request.app repository: Repository = app.state.repository media_root = app.state.config.media_root @@ -110,8 +72,52 @@ async def delete_source(request: Request, uuid: str): raise HTTPException(status_code=404, detail="Source not found") from e # delete file from disk + # TODO: use FileSystem path = os.path.join(media_root, "sources", str(uuid)) os.remove(path) # delete source from database await repository.delete_source_by_id(uuid) + + return {"message": "Source deleted successfully"} + + +async def put_multiple_sources( + inputs: List[str], + sources: List[str], + files: List[UploadFile], + request: Request, +): + app = request.app + repository: Repository = app.state.repository + + all_sources = [] + + # create source from inputs list + if len(inputs) > 0: + + async def _put_input(input: str): + file_stream = BytesIO(input.encode("utf-8")) + file = UploadFile(file_stream) + return await put_source(file=file, request=request) + + sources_from_inputs = await asyncio.gather(*[_put_input(input) for input in inputs]) + all_sources += sources_from_inputs + + for source_id in sources: + source = await repository.get_source_by_id(source_id) + if source is not None: + all_sources.append(source) + + # create one json file referencing all sources + sources_from_files = await asyncio.gather(*[put_source(request, file=file) for file in files]) + all_sources += sources_from_files + + all_sources_objects = [source.model_dump() for source in all_sources] + + # create a merged file with all sources + file_stream = BytesIO(json.dumps(jsonable_encoder(all_sources_objects)).encode("utf-8")) + file = UploadFile(file_stream, filename="input.json") + result_source = await put_source(file=file, format="json", request=request) + + return result_source