Skip to content

Commit

Permalink
Implement changefeed methods in MemoryRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
shirte committed Dec 2, 2024
1 parent 8c7fa46 commit ff4e9d9
Showing 1 changed file with 60 additions and 56 deletions.
116 changes: 60 additions & 56 deletions nerdd_backend/data/memory_repository.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import AsyncIterable, List, Optional
from typing import AsyncIterable, List, Optional, Tuple

from nerdd_link.tests import ObservableList

from .exceptions import RecordNotFoundError
from .job import Job
Expand All @@ -12,101 +14,95 @@

class MemoryRepository(Repository):
def __init__(self) -> None:
self.jobs: List[Job] = []
self.modules: List[Module] = []
self.sources: List[Source] = []
self.results: List[Result] = []
pass

#
# INITIALIZATION
#
async def initialize(self) -> None:
pass
self.jobs = ObservableList()
self.modules = ObservableList()
self.sources = ObservableList()
self.results = ObservableList()

#
# MODULES
#
async def get_module_changes(self) -> AsyncIterable:
yield NotImplementedError()
async def get_module_changes(
self,
) -> AsyncIterable[Tuple[Optional[Module], Optional[Module]]]:
async for change in self.modules.changes():
yield change

async def get_all_modules(self) -> List[Module]:
return self.modules
return self.modules.get_items()

async def upsert_module(self, module: Module) -> None:
assert module.id is not None
existing = any(
existing_module.id == module.id for existing_module in self.modules
)
if existing:
existing_module = module
self.modules = [
Module(**module, **existing_module)
if existing_module.id == module.id
else module
for module in self.modules
]
else:
try:
existing_module = await self.get_module_by_id(module.id)
self.modules.update(existing_module, module)
except RecordNotFoundError:
self.modules.append(module)

async def get_module_by_id(self, id: str) -> Module:
try:
return next((module for module in self.modules if module.id == id))
return next(
(module for module in self.modules.get_items() if module.id == id)
)
except StopIteration as e:
raise RecordNotFoundError(Module, id) from e

#
# JOBS
#
async def get_job_changes(self, job_id: str) -> AsyncIterable:
yield NotImplementedError()
async def get_job_changes(
self, job_id: str
) -> AsyncIterable[Tuple[Optional[Job], Optional[Job]]]:
async for old, new in self.jobs.changes():
if (old is not None and old.id == job_id) or (
new is not None and new.id == job_id
):
yield (old, new)

async def upsert_job(self, job: Job) -> None:
try:
existing_job = await self.get_job_by_id(job.id)
self.jobs.update(existing_job, job)
except RecordNotFoundError:
existing_job = None

if existing_job:
self.jobs = [
existing_job if existing_job.id == job.id else job for job in self.jobs
]
else:
self.jobs.append(job)

async def get_job_by_id(self, id: str) -> Job:
try:
return next((job for job in self.jobs if job.id == id))
return next((job for job in self.jobs.get_items() if job.id == id))
except StopIteration as e:
raise RecordNotFoundError(Job, id) from e

async def delete_job_by_id(self, id: str) -> None:
self.jobs = [job for job in self.jobs if job.id != id]
job = await self.get_job_by_id(id)
self.jobs.remove(job)

#
# SOURCES
#
async def upsert_source(self, source: Source) -> None:
try:
existing_source = await self.get_source_by_id(source.id)
self.sources.update(existing_source, source)
except RecordNotFoundError:
existing_source = None

if existing_source:
self.sources = [
existing_source if existing_source.id == source.id else source
for source in self.sources
]
else:
self.sources.append(source)

async def get_source_by_id(self, id: str) -> Source:
try:
return next((source for source in self.sources if source.id == id))
return next(
(source for source in self.sources.get_items() if source.id == id)
)
except StopIteration as e:
raise RecordNotFoundError(Source, id) from e

async def delete_source_by_id(self, id: str) -> None:
self.sources = [source for source in self.sources if source.id != id]
source = await self.get_source_by_id(id)
self.sources.remove(source)

#
# RESULTS
Expand All @@ -116,12 +112,25 @@ async def get_result_changes(
job_id: str,
start_mol_id: Optional[int] = None,
end_mol_id: Optional[int] = None,
) -> AsyncIterable:
yield NotImplementedError()
) -> AsyncIterable[Tuple[Optional[Result], Optional[Result]]]:
async for change in self.results.changes():
old, new = change
if (
old is not None
and old.job_id == job_id
and start_mol_id <= old.mol_id <= end_mol_id
) or (
new is not None
and new.job_id == job_id
and start_mol_id <= new.mol_id <= end_mol_id
):
yield change

async def get_result_by_id(self, id: str) -> Result:
try:
return next((result for result in self.results if result.id == id))
return next(
(result for result in self.results.get_items() if result.id == id)
)
except StopIteration as e:
raise RecordNotFoundError(Result, id) from e

Expand All @@ -133,26 +142,21 @@ async def get_results_by_job_id(
) -> List[Result]:
return [
result
for result in self.results
for result in self.results.get_items()
if result.job_id == job_id and start_mol_id <= result.mol_id <= end_mol_id
]

async def upsert_result(self, result: Result) -> None:
try:
existing_result = await self.get_result_by_id(result.id)
self.results.update(existing_result, result)
except RecordNotFoundError:
existing_result = None

if existing_result:
self.results = [
existing_result if existing_result.id == result.id else result
for result in self.results
]
else:
self.results.append(result)

async def get_all_results_by_job_id(self, job_id: str) -> List[Result]:
return [result for result in self.results if result.job_id == job_id]
return [
result for result in self.results.get_items() if result.job_id == job_id
]

async def get_num_processed_entries_by_job_id(self, job_id: str) -> int:
return len(await self.get_all_results_by_job_id(job_id))

0 comments on commit ff4e9d9

Please sign in to comment.