Skip to content

Commit

Permalink
Fix changes methods in RethinkDbRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
shirte committed Dec 7, 2024
1 parent d5f9979 commit 4b628d9
Showing 1 changed file with 60 additions and 17 deletions.
77 changes: 60 additions & 17 deletions nerdd_backend/data/rethinkdb_repository.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
import json
from typing import Any, AsyncIterable, Dict, List, Optional
from typing import Any, AsyncIterable, Dict, List, Optional, Tuple

from rethinkdb import RethinkDB
from rethinkdb.errors import ReqlOpFailedError
Expand Down Expand Up @@ -50,14 +50,29 @@ async def initialize(self) -> None:
#
# MODULES
#
async def get_module_changes(self) -> AsyncIterable[Module]:
return (
async def get_module_changes(
self,
) -> AsyncIterable[Tuple[Optional[Module], Optional[Module]]]:
cursor = (
await self.r.db(self.database_name)
.table("modules")
.changes(include_initial=True)
.run(self.connection)
)

async for change in cursor:
if change["old_val"] is None:
old_module = None
else:
old_module = Module(**change["old_val"])

if change["new_val"] is None:
new_module = None
else:
new_module = Module(**change["new_val"])

yield old_module, new_module

async def get_all_modules(self) -> List[Module]:
cursor = (
await self.r.db(self.database_name).table("modules").run(self.connection)
Expand Down Expand Up @@ -87,18 +102,18 @@ async def create_module_table(self) -> None:
except ReqlOpFailedError:
pass

async def get_most_recent_version(self, module_name: str) -> Module:
result = (
await self.r.db(self.database_name)
.table("modules")
.get(module_name)
.run(self.connection)
)
# async def get_most_recent_version(self, module_name: str) -> Module:
# result = (
# await self.r.db(self.database_name)
# .table("modules")
# .get(module_name)
# .run(self.connection)
# )

if result is None:
raise RecordNotFoundError(Module, module_name)
# if result is None:
# raise RecordNotFoundError(Module, module_name)

return Module(**result)
# return Module(**result)

async def upsert_module(self, module: Module) -> None:
# insert the module (or update if it matches an existing name-version combo)
Expand Down Expand Up @@ -259,8 +274,10 @@ async def upsert_result(self, result: Result) -> None:
.run(self.connection)
)

async def get_job_changes(self, job_id: str) -> AsyncIterable[Job]:
return (
async def get_job_changes(
self, job_id: str
) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
cursor = (
await self.r.db(self.database_name)
.table("results")
.filter(self.r.row["job_id"] == job_id)
Expand All @@ -269,13 +286,26 @@ async def get_job_changes(self, job_id: str) -> AsyncIterable[Job]:
.run(self.connection)
)

async for change in cursor:
if change["old_val"] is None:
old_job = None
else:
old_job = Job(**change["old_val"])

if change["new_val"] is None:
new_job = None
else:
new_job = Job(**change["new_val"])

yield old_job, new_job

async def get_result_changes(
self,
job_id: str,
start_mol_id: Optional[int] = None,
end_mol_id: Optional[int] = None,
) -> AsyncIterable[Result]:
return (
) -> AsyncIterable[Tuple[Optional[Result], Optional[Result]]]:
cursor = (
await self.r.db(self.database_name)
.table("results")
.filter(
Expand All @@ -286,3 +316,16 @@ async def get_result_changes(
.changes(include_initial=False)
.run(self.connection)
)

async for change in cursor:
if change["old_val"] is None:
old_result = None
else:
old_result = Result(**change["old_val"])

if change["new_val"] is None:
new_result = None
else:
new_result = Result(**change["new_val"])

yield old_result, new_result

0 comments on commit 4b628d9

Please sign in to comment.