Skip to content

Commit

Permalink
Merge pull request #14 from shirte/main
Browse files Browse the repository at this point in the history
Add extra models
  • Loading branch information
shirte authored Dec 14, 2024
2 parents 6dfc46d + d17881c commit dafefdc
Show file tree
Hide file tree
Showing 23 changed files with 234 additions and 220 deletions.
3 changes: 2 additions & 1 deletion nerdd_backend/actions/save_module_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from nerdd_link import Action, Channel, ModuleMessage

from ..data import Module, Repository
from ..data import Repository
from ..models import Module

__all__ = ["SaveModuleToDb"]

Expand Down
3 changes: 2 additions & 1 deletion nerdd_backend/actions/save_result_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from nerdd_link import Action, Channel, ResultMessage

from ..data import Repository, Result
from ..data import Repository
from ..models import Result

__all__ = ["SaveResultToDb"]

Expand Down
3 changes: 2 additions & 1 deletion nerdd_backend/actions/update_job_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def _process_message(self, message: LogMessage) -> None:
job = await self.repository.get_job_by_id(message.job_id)

# update job size
job.num_entries_total = message.size
job.num_entries_total = message.num_entries
job.num_checkpoints_total = message.num_checkpoints

await self.repository.upsert_job(job)

Expand Down
4 changes: 0 additions & 4 deletions nerdd_backend/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from .exceptions import *
from .job import *
from .memory_repository import *
from .module import *
from .repository import *
from .result import *
from .rethinkdb_repository import *
from .source import *
16 changes: 0 additions & 16 deletions nerdd_backend/data/job.py

This file was deleted.

5 changes: 1 addition & 4 deletions nerdd_backend/data/memory_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from nerdd_link.utils import ObservableList

from ..models import Job, Module, Result, Source
from .exceptions import RecordNotFoundError
from .job import Job
from .module import Module
from .repository import Repository
from .result import Result
from .source import Source

__all__ = ["MemoryRepository"]

Expand Down
18 changes: 0 additions & 18 deletions nerdd_backend/data/module.py

This file was deleted.

9 changes: 2 additions & 7 deletions nerdd_backend/data/repository.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABC, abstractmethod
from typing import AsyncIterable, List, Optional, Tuple

from .job import Job
from .module import Module
from .result import Result
from .source import Source
from ..models import Job, Module, Result, Source

__all__ = ["Repository"]

Expand Down Expand Up @@ -48,9 +45,7 @@ async def upsert_module(self, module: Module) -> None:
# JOBS
#
@abstractmethod
def get_job_changes(
self, job_id: str
) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
def get_job_changes(self, job_id: str) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
pass

@abstractmethod
Expand Down
11 changes: 0 additions & 11 deletions nerdd_backend/data/result.py

This file was deleted.

40 changes: 7 additions & 33 deletions nerdd_backend/data/rethinkdb_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
from rethinkdb import RethinkDB
from rethinkdb.errors import ReqlOpFailedError

from ..models import Job, Module, Result, Source
from .exceptions import RecordNotFoundError
from .job import Job
from .module import Module
from .repository import Repository
from .result import Result
from .source import Source

__all__ = ["RethinkDbRepository"]

Expand Down Expand Up @@ -74,17 +71,12 @@ async def get_module_changes(
yield old_module, new_module

async def get_all_modules(self) -> List[Module]:
cursor = (
await self.r.db(self.database_name).table("modules").run(self.connection)
)
cursor = await self.r.db(self.database_name).table("modules").run(self.connection)
return [Module(**item) async for item in cursor]

async def get_module_by_id(self, module_id: str) -> Module:
result = (
await self.r.db(self.database_name)
.table("modules")
.get(module_id)
.run(self.connection)
await self.r.db(self.database_name).table("modules").get(module_id).run(self.connection)
)

if result is None:
Expand All @@ -94,11 +86,7 @@ async def get_module_by_id(self, module_id: str) -> Module:

async def create_module_table(self) -> None:
try:
await (
self.r.db(self.database_name)
.table_create("modules")
.run(self.connection)
)
await self.r.db(self.database_name).table_create("modules").run(self.connection)
except ReqlOpFailedError:
pass

Expand Down Expand Up @@ -146,26 +134,15 @@ async def upsert_job(self, job: Job) -> None:
)

async def get_job_by_id(self, job_id: str) -> Job:
result = (
await self.r.db(self.database_name)
.table("jobs")
.get(job_id)
.run(self.connection)
)
result = await self.r.db(self.database_name).table("jobs").get(job_id).run(self.connection)

if result is None:
raise RecordNotFoundError(Job, job_id)

return Job(**result)

async def delete_job_by_id(self, job_id: str) -> None:
await (
self.r.db(self.database_name)
.table("jobs")
.get(job_id)
.delete()
.run(self.connection)
)
await self.r.db(self.database_name).table("jobs").get(job_id).delete().run(self.connection)

#
# SOURCES
Expand All @@ -190,10 +167,7 @@ async def upsert_source(self, source: Source) -> None:

async def get_source_by_id(self, source_id: str) -> Source:
result = (
await self.r.db(self.database_name)
.table("sources")
.get(source_id)
.run(self.connection)
await self.r.db(self.database_name).table("sources").get(source_id).run(self.connection)
)

if result is None:
Expand Down
4 changes: 4 additions & 0 deletions nerdd_backend/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .job import *
from .module import *
from .result import *
from .source import *
32 changes: 32 additions & 0 deletions nerdd_backend/models/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel

__all__ = ["Job", "JobCreate", "JobPublic"]


class Job(BaseModel):
id: str
job_type: str
source_id: str
params: dict
created_at: datetime = datetime.now()
status: str
num_entries_total: Optional[int] = None
num_checkpoints_total: Optional[int] = None


class JobCreate(BaseModel):
job_type: str
source_id: str
params: Dict[str, Any]


class JobPublic(Job):
num_entries_processed: int
num_pages_total: Optional[int]
num_pages_processed: int
page_size: int
job_url: str
results_url: str
36 changes: 36 additions & 0 deletions nerdd_backend/models/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

from nerdd_module.config import Module as NerddModule
from pydantic import BaseModel, computed_field

__all__ = ["Module", "ModulePublic", "ModuleShort"]


class Module(NerddModule):
@computed_field
@property
def id(self) -> str:
# TODO: incorporate versioning
# compute the primary key from name and version
# if "version" in module.keys():
# version = module["version"]
# else:
# version = "1.0.0"
# name = module["name"]
return self.name


class ModulePublic(Module):
module_url: str


class ModuleShort(BaseModel):
id: str
rank: Optional[int] = None
name: Optional[str] = None
version: Optional[str] = None
visible_name: Optional[str] = None
logo: Optional[str] = None
logo_title: Optional[str] = None
logo_caption: Optional[str] = None
module_url: str
31 changes: 31 additions & 0 deletions nerdd_backend/models/result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Optional

from pydantic import BaseModel, ConfigDict

from .job import JobPublic

__all__ = ["Result", "Pagination", "ResultSet"]


class Result(BaseModel):
id: str
job_id: str
mol_id: int

model_config = ConfigDict(extra="allow")


class Pagination(BaseModel):
page: int # 1-based!
page_size: int
is_incomplete: bool
first_mol_id_on_page: int
last_mol_id_on_page: int
previous_url: Optional[str]
next_url: Optional[str]


class ResultSet(BaseModel):
data: List[Result]
job: JobPublic
pagination: Pagination
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

__all__ = ["Source"]
__all__ = ["Source", "SourcePublic"]


class Source(BaseModel):
Expand All @@ -12,3 +12,7 @@ class Source(BaseModel):
# the filename that was provided by the user
filename: Optional[str] = None
created_at: datetime = datetime.now()


class SourcePublic(Source):
pass
19 changes: 6 additions & 13 deletions nerdd_backend/routers/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, UploadFile
from pydantic import create_model, model_validator

from ..data import Module
from .jobs import CreateJobRequest, create_job, delete_job, get_job
from ..models import JobCreate, Module
from .jobs import create_job, delete_job, get_job
from .results import get_results
from .sources import put_multiple_sources
from .websockets import get_job_ws, get_results_ws
Expand Down Expand Up @@ -63,9 +63,7 @@ def get_dynamic_router(module: Module):
)
QueryModelPost = create_model(
"QueryModelForm",
__validators__={
"validate_to_json": model_validator(mode="before")(validate_to_json)
},
__validators__={"validate_to_json": model_validator(mode="before")(validate_to_json)},
inputs=(List[str], []),
sources=(List[str], []),
**field_definitions,
Expand All @@ -78,7 +76,6 @@ async def _create_job(
params: dict,
request: Request = None,
):
app = request.app
if "job_type" in params and params["job_type"] != module.name:
return HTTPException(
status_code=400,
Expand All @@ -88,7 +85,7 @@ async def _create_job(
result_source = await put_multiple_sources(inputs, sources, files, request)

return await create_job(
request_data=CreateJobRequest(
job=JobCreate(
job_type=module.name,
source_id=result_source.id,
params={k: v for k, v in params.items() if k in field_definitions},
Expand Down Expand Up @@ -154,11 +151,7 @@ async def create_complex_job(
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}")(get_job_ws)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/")(get_job_ws)

router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results")(
get_results_ws
)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results/")(
get_results_ws
)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results")(get_results_ws)
router.websocket(f"/websocket/{module.name}" "/jobs/{job_id}/results/")(get_results_ws)

return router
Loading

0 comments on commit dafefdc

Please sign in to comment.