Skip to content

Commit

Permalink
Move models into submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
shirte committed Dec 11, 2024
1 parent 6dfc46d commit ad176da
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 71 deletions.
3 changes: 2 additions & 1 deletion nerdd_backend/actions/save_module_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from nerdd_link import Action, Channel, ModuleMessage

from ..data import Module, Repository
from ..data import Repository
from ..models import Module

__all__ = ["SaveModuleToDb"]

Expand Down
3 changes: 2 additions & 1 deletion nerdd_backend/actions/save_result_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from nerdd_link import Action, Channel, ResultMessage

from ..data import Repository, Result
from ..data import Repository
from ..models import Result

__all__ = ["SaveResultToDb"]

Expand Down
4 changes: 0 additions & 4 deletions nerdd_backend/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from .exceptions import *
from .job import *
from .memory_repository import *
from .module import *
from .repository import *
from .result import *
from .rethinkdb_repository import *
from .source import *
5 changes: 1 addition & 4 deletions nerdd_backend/data/memory_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from nerdd_link.utils import ObservableList

from ..models import Job, Module, Result, Source
from .exceptions import RecordNotFoundError
from .job import Job
from .module import Module
from .repository import Repository
from .result import Result
from .source import Source

__all__ = ["MemoryRepository"]

Expand Down
9 changes: 2 additions & 7 deletions nerdd_backend/data/repository.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABC, abstractmethod
from typing import AsyncIterable, List, Optional, Tuple

from .job import Job
from .module import Module
from .result import Result
from .source import Source
from ..models import Job, Module, Result, Source

__all__ = ["Repository"]

Expand Down Expand Up @@ -48,9 +45,7 @@ async def upsert_module(self, module: Module) -> None:
# JOBS
#
@abstractmethod
def get_job_changes(
self, job_id: str
) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
def get_job_changes(self, job_id: str) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
pass

@abstractmethod
Expand Down
40 changes: 7 additions & 33 deletions nerdd_backend/data/rethinkdb_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
from rethinkdb import RethinkDB
from rethinkdb.errors import ReqlOpFailedError

from ..models import Job, Module, Result, Source
from .exceptions import RecordNotFoundError
from .job import Job
from .module import Module
from .repository import Repository
from .result import Result
from .source import Source

__all__ = ["RethinkDbRepository"]

Expand Down Expand Up @@ -74,17 +71,12 @@ async def get_module_changes(
yield old_module, new_module

async def get_all_modules(self) -> List[Module]:
cursor = (
await self.r.db(self.database_name).table("modules").run(self.connection)
)
cursor = await self.r.db(self.database_name).table("modules").run(self.connection)
return [Module(**item) async for item in cursor]

async def get_module_by_id(self, module_id: str) -> Module:
result = (
await self.r.db(self.database_name)
.table("modules")
.get(module_id)
.run(self.connection)
await self.r.db(self.database_name).table("modules").get(module_id).run(self.connection)
)

if result is None:
Expand All @@ -94,11 +86,7 @@ async def get_module_by_id(self, module_id: str) -> Module:

async def create_module_table(self) -> None:
try:
await (
self.r.db(self.database_name)
.table_create("modules")
.run(self.connection)
)
await self.r.db(self.database_name).table_create("modules").run(self.connection)
except ReqlOpFailedError:
pass

Expand Down Expand Up @@ -146,26 +134,15 @@ async def upsert_job(self, job: Job) -> None:
)

async def get_job_by_id(self, job_id: str) -> Job:
result = (
await self.r.db(self.database_name)
.table("jobs")
.get(job_id)
.run(self.connection)
)
result = await self.r.db(self.database_name).table("jobs").get(job_id).run(self.connection)

if result is None:
raise RecordNotFoundError(Job, job_id)

return Job(**result)

async def delete_job_by_id(self, job_id: str) -> None:
await (
self.r.db(self.database_name)
.table("jobs")
.get(job_id)
.delete()
.run(self.connection)
)
await self.r.db(self.database_name).table("jobs").get(job_id).delete().run(self.connection)

#
# SOURCES
Expand All @@ -190,10 +167,7 @@ async def upsert_source(self, source: Source) -> None:

async def get_source_by_id(self, source_id: str) -> Source:
result = (
await self.r.db(self.database_name)
.table("sources")
.get(source_id)
.run(self.connection)
await self.r.db(self.database_name).table("sources").get(source_id).run(self.connection)
)

if result is None:
Expand Down
4 changes: 4 additions & 0 deletions nerdd_backend/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .job import *
from .module import *
from .result import *
from .source import *
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 4 additions & 10 deletions nerdd_backend/routers/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, UploadFile
from pydantic import create_model, model_validator

from ..data import Module
from ..models import Module
from .jobs import CreateJobRequest, create_job, delete_job, get_job
from .results import get_results
from .sources import put_multiple_sources
Expand Down Expand Up @@ -63,9 +63,7 @@ def get_dynamic_router(module: Module):
)
QueryModelPost = create_model(
"QueryModelForm",
__validators__={
"validate_to_json": model_validator(mode="before")(validate_to_json)
},
__validators__={"validate_to_json": model_validator(mode="before")(validate_to_json)},
inputs=(List[str], []),
sources=(List[str], []),
**field_definitions,
Expand Down Expand Up @@ -154,11 +152,7 @@ async def create_complex_job(
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}")(get_job_ws)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/")(get_job_ws)

router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results")(
get_results_ws
)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results/")(
get_results_ws
)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results")(get_results_ws)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results/")(get_results_ws)

return router
15 changes: 5 additions & 10 deletions nerdd_backend/routers/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from fastapi import APIRouter, HTTPException, Request, UploadFile
from fastapi.encoders import jsonable_encoder

from ..data import RecordNotFoundError, Repository, Source
from ..data import RecordNotFoundError, Repository
from ..models import Source

sources_router = APIRouter(prefix="/sources")

Expand All @@ -33,9 +34,7 @@ async def _put_input(input: str):
file = UploadFile(file_stream)
return await put_source(request, file=file)

sources_from_inputs = await asyncio.gather(
*[_put_input(input) for input in inputs]
)
sources_from_inputs = await asyncio.gather(*[_put_input(input) for input in inputs])
all_sources += sources_from_inputs

for source_id in sources:
Expand All @@ -44,17 +43,13 @@ async def _put_input(input: str):
all_sources.append(source)

# create one json file referencing all sources
sources_from_files = await asyncio.gather(
*[put_source(request, file=file) for file in files]
)
sources_from_files = await asyncio.gather(*[put_source(request, file=file) for file in files])
all_sources += sources_from_files

all_sources_objects = [source.model_dump() for source in all_sources]

# create a merged file with all sources
file_stream = BytesIO(
json.dumps(jsonable_encoder(all_sources_objects)).encode("utf-8")
)
file_stream = BytesIO(json.dumps(jsonable_encoder(all_sources_objects)).encode("utf-8"))
file = UploadFile(file_stream, filename="input.json")
result_source = await put_source(request, file=file, format="json")

Expand Down
4 changes: 3 additions & 1 deletion tests/steps/repository.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest_asyncio
from nerdd_backend.data import MemoryRepository, Module
from nerdd_module.tests import MolWeightModel
from pytest_bdd import given

from nerdd_backend.data import MemoryRepository
from nerdd_backend.models import Module

from .async_step import async_step


Expand Down

0 comments on commit ad176da

Please sign in to comment.