Add gzip compression middleware #19

Merged · 4 commits · Jan 23, 2025

8 changes: 3 additions & 5 deletions nerdd_backend/actions/save_result_to_db.py

@@ -24,7 +24,7 @@ async def _process_message(self, message: ResultMessage) -> None:
         try:
             job = await self.repository.get_job_by_id(job_id)
         except RecordNotFoundError:
-            logger.error(f"Job with id {job_id} not found. Ignore this result.")
+            logger.error(f"Job with id {job_id} not found. Ignoring this result.")
             return
 
         # TODO: check if corresponding module has correct task type (e.g. "derivative_prediction")
@@ -55,11 +55,9 @@ async def _replace_source(source_id, repository):
         # save result
         await self.repository.create_result(Result(id=id, **message.model_dump()))
 
-        # update job
-        # TODO: there might be a RaceCondition here (no atomic transaction)
-        num_entries_processed = await self.repository.get_num_processed_entries_by_job_id(job_id)
+        # update set of processed entries in job
         await self.repository.update_job(
-            JobUpdate(id=job_id, num_entries_processed=num_entries_processed)
+            JobUpdate(id=job_id, entries_processed=[message.mol_id])
         )
 
     def _get_group_name(self):

34 changes: 19 additions & 15 deletions nerdd_backend/data/memory_repository.py

@@ -4,6 +4,7 @@
 from nerdd_link.utils import ObservableList
 
 from ..models import Job, JobInternal, JobUpdate, Module, Result, Source
+from ..util import CompressedSet
 from .exceptions import RecordAlreadyExistsError, RecordNotFoundError
 from .repository import Repository
 
@@ -78,24 +79,27 @@ async def create_job(self, job: Job) -> JobInternal:
         self.jobs.append(result)
         return result
 
-    async def update_job(self, job: JobUpdate) -> JobInternal:
+    async def update_job(self, job_update: JobUpdate) -> JobInternal:
         async with self.transaction_lock:
-            existing_job = await self.get_job_by_id(job.id)
+            existing_job = await self.get_job_by_id(job_update.id)
             modified_job = JobInternal(**existing_job.model_dump())
-            if job.status is not None:
-                modified_job.status = job.status
-            if job.num_entries_processed is not None:
-                modified_job.num_entries_processed = job.num_entries_processed
-            if job.num_entries_total is not None:
-                modified_job.num_entries_total = job.num_entries_total
-            if job.num_checkpoints_total is not None:
-                modified_job.num_checkpoints_total = job.num_checkpoints_total
-            if job.new_checkpoints_processed is not None:
-                modified_job.checkpoints_processed.extend(job.new_checkpoints_processed)
-            if job.new_output_formats is not None:
-                modified_job.output_formats.extend(job.new_output_formats)
+            if job_update.status is not None:
+                modified_job.status = job_update.status
+            if job_update.entries_processed is not None:
+                entries_processed = CompressedSet(modified_job.entries_processed)
+                for entry in job_update.entries_processed:
+                    entries_processed.add(entry)
+                modified_job.entries_processed = entries_processed.to_intervals()
+            if job_update.num_entries_total is not None:
+                modified_job.num_entries_total = job_update.num_entries_total
+            if job_update.num_checkpoints_total is not None:
+                modified_job.num_checkpoints_total = job_update.num_checkpoints_total
+            if job_update.new_checkpoints_processed is not None:
+                modified_job.checkpoints_processed.extend(job_update.new_checkpoints_processed)
+            if job_update.new_output_formats is not None:
+                modified_job.output_formats.extend(job_update.new_output_formats)
             self.jobs.update(existing_job, modified_job)
-            return await self.get_job_by_id(job.id)
+            return await self.get_job_by_id(job_update.id)
 
     async def get_job_by_id(self, id: str) -> JobInternal:
         try:
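
Note: the new nerdd_backend/util/compressed_set.py module is imported above but its contents are not among the files shown in this diff. Based on how CompressedSet is used (constructed from a stored interval list, extended with add(entry), serialized back with to_intervals(), counted with count()), a minimal sketch could look like the following; the inclusive-interval representation is an assumption, not something this diff confirms:

from typing import List, Optional, Tuple


class CompressedSet:
    """Stores a set of non-negative integers as sorted, disjoint,
    inclusive intervals, e.g. {0, 1, 2, 5} -> [(0, 2), (5, 5)]."""

    def __init__(self, intervals: Optional[List[Tuple[int, int]]] = None):
        self.intervals: List[Tuple[int, int]] = sorted(intervals or [])

    def add(self, x: int) -> None:
        # merge x with all overlapping or adjacent intervals
        new_start, new_end = x, x
        rest = []
        for start, end in self.intervals:
            if end + 1 < new_start or new_end + 1 < start:
                # disjoint and not adjacent: keep unchanged
                rest.append((start, end))
            else:
                new_start = min(start, new_start)
                new_end = max(end, new_end)
        rest.append((new_start, new_end))
        self.intervals = sorted(rest)

    def count(self) -> int:
        # number of integers covered by all intervals
        return sum(end - start + 1 for start, end in self.intervals)

    def to_intervals(self) -> List[Tuple[int, int]]:
        return list(self.intervals)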

2 changes: 1 addition & 1 deletion nerdd_backend/data/repository.py

@@ -59,7 +59,7 @@ async def create_job(self, job: Job) -> JobInternal:
         pass
 
     @abstractmethod
-    async def update_job(self, job: JobUpdate) -> JobInternal:
+    async def update_job(self, job_update: JobUpdate) -> JobInternal:
         pass
 
     @abstractmethod

46 changes: 25 additions & 21 deletions nerdd_backend/data/rethinkdb_repository.py

@@ -6,6 +6,7 @@
 from rethinkdb.errors import ReqlOpFailedError
 
 from ..models import Job, JobInternal, JobUpdate, Module, Result, Source
+from ..util import CompressedSet
 from .exceptions import RecordAlreadyExistsError, RecordNotFoundError
 from .repository import Repository
 
@@ -181,38 +182,41 @@ async def create_job(self, job: Job) -> JobInternal:
         )
         return JobInternal(**result["changes"][0]["new_val"])
 
-    async def update_job(self, job: JobUpdate) -> JobInternal:
-        update_set = {}
-        if job.status is not None:
-            update_set["status"] = job.status
-        if job.num_entries_processed is not None:
-            update_set["num_entries_processed"] = job.num_entries_processed
-        if job.num_entries_total is not None:
-            update_set["num_entries_total"] = job.num_entries_total
-        if job.num_checkpoints_total is not None:
-            update_set["num_checkpoints_total"] = job.num_checkpoints_total
-        if job.new_checkpoints_processed is not None:
-            update_set["checkpoints_processed"] = self.r.row["checkpoints_processed"].union(
-                job.new_checkpoints_processed
-            )
-        if job.new_output_formats is not None:
-            update_set["output_formats"] = self.r.row["output_formats"].union(
-                job.new_output_formats
+    async def update_job(self, job_update: JobUpdate) -> JobInternal:
+        old_job = await self.get_job_by_id(job_update.id)
+
+        new_job = old_job.model_dump()
+        if job_update.status is not None:
+            new_job["status"] = job_update.status
+        if job_update.entries_processed is not None:
+            entries_processed = CompressedSet(old_job.entries_processed)
+            for entry in job_update.entries_processed:
+                entries_processed.add(entry)
+            new_job["entries_processed"] = entries_processed.to_intervals()
+        if job_update.num_entries_total is not None:
+            new_job["num_entries_total"] = job_update.num_entries_total
+        if job_update.num_checkpoints_total is not None:
+            new_job["num_checkpoints_total"] = job_update.num_checkpoints_total
+        if job_update.new_checkpoints_processed is not None:
+            new_job["checkpoints_processed"] = (
+                old_job.checkpoints_processed + job_update.new_checkpoints_processed
            )
+        if job_update.new_output_formats is not None:
+            new_job["output_formats"] = old_job.output_formats + job_update.new_output_formats
 
         changes = (
             await self.r.db(self.database_name)
             .table("jobs")
-            .get(job.id)
-            .update(update_set, return_changes=True)
+            .get(job_update.id)
+            .replace(new_job, return_changes=True)
             .run(self.connection)
         )
 
         if changes["unchanged"] == 1:
-            return await self.get_job_by_id(job.id)
+            return old_job
 
         if len(changes["changes"]) == 0:
-            raise RecordNotFoundError(Job, job.id)
+            raise RecordNotFoundError(Job, job_update.id)
 
         return JobInternal(**changes["changes"][0]["new_val"])
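
Note on the branches above: with return_changes=True, RethinkDB's replace returns a summary document, and documents that end up unchanged are not listed under "changes". That is why the code checks "unchanged" first and only then treats an empty "changes" list as a missing record. An illustrative sketch of the two shapes the code distinguishes (values are examples, not output from this PR):

# successful replacement:
#   {"replaced": 1, "unchanged": 0, "errors": 0,
#    "changes": [{"old_val": {...}, "new_val": {...}}]}
# replacement identical to the stored document:
#   {"replaced": 0, "unchanged": 1, "errors": 0, "changes": []}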

3 changes: 3 additions & 0 deletions nerdd_backend/main.py

@@ -6,6 +6,7 @@
 import uvicorn
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.gzip import GZipMiddleware
 from nerdd_link import FileSystem, KafkaChannel, MemoryChannel, SystemMessage
 from nerdd_link.utils import async_to_sync
 from omegaconf import DictConfig, OmegaConf
@@ -157,6 +158,8 @@ async def global_lifespan(app: FastAPI):
         allow_headers=["*"],
     )
 
+    app.add_middleware(GZipMiddleware)
+
     app.include_router(jobs_router)
     app.include_router(sources_router)
     app.include_router(results_router)
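
Note: as registered above, GZipMiddleware runs with Starlette's defaults (responses below a minimum size of 500 bytes stay uncompressed, compression level 9). If tuning were desired, an explicitly configured variant might look like this; the parameter values are illustrative, not part of this PR:

from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# only compress responses larger than 1 KB, trading a little
# compression ratio (level 6 instead of the default 9) for CPU time
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=6)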

9 changes: 5 additions & 4 deletions nerdd_backend/models/job.py

@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel
 
@@ -17,8 +17,9 @@ class Job(BaseModel):
     source_id: str
     params: dict
     created_at: datetime = datetime.now(timezone.utc)
+    page_size: int = 10
     status: str
-    num_entries_processed: int = 0
+    entries_processed: List[Tuple[int, int]] = []
     num_entries_total: Optional[int] = None
 
 
@@ -35,9 +36,9 @@ class JobCreate(BaseModel):
 
 
 class JobPublic(Job):
+    num_entries_processed: int = 0
     num_pages_total: Optional[int]
     num_pages_processed: int
-    page_size: int
     output_files: List[OutputFile]
     job_url: str
     results_url: str
@@ -46,7 +47,7 @@ class JobPublic(Job):
 class JobUpdate(BaseModel):
     id: str
     status: Optional[str] = None
-    num_entries_processed: Optional[int] = None
+    entries_processed: Optional[List[int]] = None
     num_entries_total: Optional[int] = None
     num_checkpoints_total: Optional[int] = None
     # checkpoint list update

3 changes: 2 additions & 1 deletion nerdd_backend/routers/dynamic.py

@@ -46,7 +46,8 @@ def get_dynamic_router(module: Module):
 
     # all methods will be available at /module_name e.g. /cypstrate
     # the parameter tags creates a separate group in the swagger ui
-    router = APIRouter(tags=[module.id])
+    # module will be hidden if visible is set to False
+    router = APIRouter(tags=[module.id], include_in_schema=module.visible)
 
     #
     # GET /jobs

27 changes: 20 additions & 7 deletions nerdd_backend/routers/jobs.py

@@ -7,23 +7,25 @@
 
 from ..data import RecordNotFoundError
 from ..models import Job, JobCreate, JobInternal, JobPublic, OutputFile
+from ..util import CompressedSet
 
 __all__ = ["jobs_router"]
 
 jobs_router = APIRouter(prefix="/jobs")
 
 
 async def augment_job(job: JobInternal, request: Request) -> JobPublic:
-    app = request.app
-    page_size = app.state.config.page_size
+    # compute number of processed entries
+    entries_processed = CompressedSet(job.entries_processed)
+    num_entries_processed = entries_processed.count()
 
     # The number of processed pages is only valid if the computation has not finished yet. We adapt
     # this number in the if statement below.
-    num_pages_processed = job.num_entries_processed // page_size
+    num_pages_processed = num_entries_processed // job.page_size
     if job.num_entries_total is not None:
-        num_pages_total = math.ceil(job.num_entries_total / page_size)
+        num_pages_total = math.ceil(job.num_entries_total / job.page_size)
 
-        if job.num_entries_total == job.num_entries_processed:
+        if job.num_entries_total == num_entries_processed:
             num_pages_processed = num_pages_total
     else:
         num_pages_total = None
@@ -47,9 +49,9 @@ async def augment_job(job: JobInternal, request: Request) -> JobPublic:
         **job.model_dump(),
         job_url=f"{request.url.netloc}/jobs/{job.id}",
         results_url=f"{request.url.netloc}/jobs/{job.id}/results",
+        num_entries_processed=num_entries_processed,
         num_pages_processed=num_pages_processed,
         num_pages_total=num_pages_total,
-        page_size=page_size,
         output_files=output_files,
     )
 
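
To illustrate the page arithmetic in augment_job above: with page_size = 10, 42 total entries, and 25 entries processed so far (values chosen for illustration), the job reports 5 total pages but only 2 fully processed pages; once all 42 entries are processed, the partially filled last page counts as processed as well:

import math

page_size = 10
num_entries_total = 42
num_entries_processed = 25

num_pages_total = math.ceil(num_entries_total / page_size)   # 5
num_pages_processed = num_entries_processed // page_size     # 2 (only complete pages)

# when the job is finished, the partially filled last page counts too
if num_entries_total == num_entries_processed:
    num_pages_processed = num_pages_total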

@@ -65,7 +67,7 @@ async def create_job(job: JobCreate = Body(), request: Request = None):
 
     # check if module exists
     try:
-        await repository.get_module_by_id(job.job_type)
+        module = await repository.get_module_by_id(job.job_type)
     except RecordNotFoundError as e:
         all_modules = await repository.get_all_modules()
         valid_options = [module.id for module in all_modules]
@@ -77,6 +79,16 @@
             ),
         ) from e
 
+    # get page size (depending on module task)
+    task = module.task
+    if task == "atom_property_prediction":
+        page_size = app.state.config.page_size_atom_property_prediction
+    elif task == "derivative_property_prediction":
+        page_size = app.state.config.page_size_derivative_property_prediction
+    else:
+        # task == "molecular_property_prediction" or unknown task
+        page_size = app.state.config.page_size_molecular_property_prediction
+
     # check if source exists
     try:
         await repository.get_source_by_id(job.source_id)
@@ -88,6 +100,7 @@
         job_type=job.job_type,
         source_id=job.source_id,
         params=job.params,
+        page_size=page_size,
         status="created",
     )

5 changes: 3 additions & 2 deletions nerdd_backend/routers/modules.py

@@ -18,6 +18,7 @@ async def get_modules(request: Request):
     return [
         ModuleShort(**module.model_dump(), module_url=f"{request.base_url}{module.id}")
         for module in modules
+        if module.visible
     ]
 
 
@@ -28,7 +29,7 @@ async def get_module(module_id: str, request: Request):
 
     try:
         module = await repository.get_module_by_id(module_id)
-    except RecordNotFoundError:
-        raise HTTPException(status_code=404, detail="Module not found")
+    except RecordNotFoundError as e:
+        raise HTTPException(status_code=404, detail="Module not found") from e
 
     return ModulePublic(**module.model_dump(), module_url=request.url.path)

3 changes: 2 additions & 1 deletion nerdd_backend/routers/results.py

@@ -15,7 +15,6 @@ async def get_results(
 ) -> ResultSet:
     app = request.app
     repository: Repository = app.state.repository
-    page_size = app.state.config.page_size
 
     page_zero_based = page - 1
 
@@ -24,6 +23,8 @@ async def get_results(
     except RecordNotFoundError as e:
         raise HTTPException(status_code=404, detail="Job not found") from e
 
+    page_size = job.page_size
+
     # num_entries might not be available, yet
     # we assume it to be positive infinity in that case
     if job.num_entries_total is None:

7 changes: 4 additions & 3 deletions nerdd_backend/routers/websockets.py

@@ -31,14 +31,15 @@ async def get_job_ws(websocket: WebSocket, job_id: str):
 async def get_results_ws(websocket: WebSocket, job_id: str, page: int = Query()):
     app = websocket.app
     repository = app.state.repository
-    page_size = app.state.config.page_size
 
     await websocket.accept()
 
     try:
         job = await repository.get_job_by_id(job_id)
-    except RecordNotFoundError:
-        raise HTTPException(status_code=404, detail="Job not found")
+    except RecordNotFoundError as e:
+        raise HTTPException(status_code=404, detail="Job not found") from e
 
+    page_size = job.page_size
+
     # num_entries might not be available, yet
     # we assume it to be positive infinity in that case

4 changes: 3 additions & 1 deletion nerdd_backend/settings/development.yaml

@@ -5,7 +5,9 @@ defaults:
 host: localhost
 port: 8000
 
-page_size: 5
+page_size_molecular_property_prediction: 5
+page_size_atom_property_prediction: 3
+page_size_derivative_property_prediction: 2
 media_root: ./media
 
 mock_infra: true

4 changes: 3 additions & 1 deletion nerdd_backend/settings/production.yaml

@@ -5,7 +5,9 @@ defaults:
 host: 0.0.0.0
 port: 8000
 
-page_size: 100
+page_size_molecular_property_prediction: 100
+page_size_atom_property_prediction: 10
+page_size_derivative_property_prediction: 10
 media_root: /data
 
 mock_infra: false

4 changes: 3 additions & 1 deletion nerdd_backend/settings/testing.yaml

@@ -5,7 +5,9 @@ defaults:
 host: localhost
 port: 8000
 
-page_size: 5
+page_size_molecular_property_prediction: 5
+page_size_atom_property_prediction: 3
+page_size_derivative_property_prediction: 2
 media_root: ./media
 
 mock_infra: true

1 change: 1 addition & 0 deletions nerdd_backend/util/__init__.py

@@ -0,0 +1 @@
+from .compressed_set import *
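
For illustration, a round trip with the CompressedSet sketch given after the memory_repository.py section above (again assuming inclusive intervals):

s = CompressedSet([(0, 2), (5, 5)])
s.add(3)
s.add(4)
assert s.to_intervals() == [(0, 5)]
assert s.count() == 6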