Skip to content

Commit

Permalink
Creation methods in repository return the inserted database objects
Browse files Browse the repository at this point in the history
  • Loading branch information
shirte committed Dec 20, 2024
1 parent 96b1d67 commit 41a94d8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
14 changes: 9 additions & 5 deletions nerdd_backend/data/memory_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ async def get_module_changes(
async def get_all_modules(self) -> List[Module]:
return self.modules.get_items()

async def create_module(self, module: Module) -> None:
async def create_module(self, module: Module) -> Module:
assert module.id is not None
try:
await self.get_module_by_id(module.id)
raise RecordAlreadyExistsError(Module, module.id)
except RecordNotFoundError:
self.modules.append(module)
return module

async def get_module_by_id(self, id: str) -> Module:
try:
Expand All @@ -58,12 +59,14 @@ async def get_job_changes(
if (old is not None and old.id == job_id) or (new is not None and new.id == job_id):
yield (old, new)

async def create_job(self, job: Job) -> None:
async def create_job(self, job: Job) -> JobInternal:
try:
await self.get_job_by_id(job.id)
raise RecordAlreadyExistsError(Job, job.id)
except RecordNotFoundError:
self.jobs.append(JobInternal(**job.model_dump()))
result = JobInternal(**job.model_dump())
self.jobs.append(result)
return result

async def update_job(self, job: JobUpdate) -> JobInternal:
existing_job = await self.get_job_by_id(job.id)
Expand Down Expand Up @@ -94,12 +97,13 @@ async def delete_job_by_id(self, id: str) -> None:
#
# SOURCES
#
async def create_source(self, source: Source) -> None:
async def create_source(self, source: Source) -> Source:
try:
existing_source = await self.get_source_by_id(source.id)
await self.get_source_by_id(source.id)
raise RecordAlreadyExistsError(Source, source.id)
except RecordNotFoundError:
self.sources.append(source)
return source

async def get_source_by_id(self, id: str) -> Source:
try:
Expand Down
8 changes: 4 additions & 4 deletions nerdd_backend/data/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def get_module_by_id(self, module_id) -> Module:
pass

@abstractmethod
async def create_module(self, module: Module) -> None:
async def create_module(self, module: Module) -> Module:
pass

#
Expand All @@ -49,7 +49,7 @@ def get_job_changes(self, job_id: str) -> AsyncIterable[Tuple[Optional[Job], Opt
pass

@abstractmethod
async def create_job(self, job: Job) -> None:
async def create_job(self, job: Job) -> JobInternal:
pass

@abstractmethod
Expand All @@ -68,7 +68,7 @@ async def delete_job_by_id(self, job_id) -> None:
# SOURCES
#
@abstractmethod
async def create_source(self, source: Source) -> None:
async def create_source(self, source: Source) -> Source:
pass

@abstractmethod
Expand Down Expand Up @@ -96,7 +96,7 @@ async def get_results_by_job_id(
pass

@abstractmethod
async def create_result(self, result: Result) -> None:
async def create_result(self, result: Result) -> Result:
pass

@abstractmethod
Expand Down

0 comments on commit 41a94d8

Please sign in to comment.