Skip to content

Commit

Permalink
Feature/sqlite repo: setup for GeoLocationServiceSqlRepoDBTest (#68)
Browse files Browse the repository at this point in the history
Unstable, active and WIP: This is where things start taking shape for setting test coverage
`GeoLocationServiceSqlRepoDBTest` to implement issue: #26 ;

- major rewrite of test services for integration test for Sqlite Repo
DB. refactored code to use container to setup database using
start_test_database and asyncSetUp. WIP: test_fetch_facilities
- added sqlite models to support sqlite database. database.py enhances
app level infra setup using DI using rodi's Container that sets up
database and session
  • Loading branch information
codecakes authored Sep 9, 2024
1 parent 769c0e1 commit e2e7e62
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pip-install:
pip install --prefer-binary --use-pep517 --check-build-dependencies .[dev]

test:
APP_ENV=test APP_DB_ENGINE_URL="sqlite+aiosqlite://" pytest -s xcov19/tests/ -m "not integration"
APP_ENV=test APP_DB_ENGINE_URL="sqlite+aiosqlite://" pytest -s xcov19/tests/ -m "not slow and not integration and not api"

test-integration:
APP_ENV=test APP_DB_ENGINE_URL="sqlite+aiosqlite://" pytest -s xcov19/tests/ -m "integration"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ asyncio_mode = "auto"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"integration: marks tests as integration tests",
"api: mark api tests",
"unit: marks tests as unit tests",
# Add more markers as needed
]
Expand Down
41 changes: 20 additions & 21 deletions xcov19/app/database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import sys
from rodi import Container
from sqlmodel import SQLModel

from xcov19.infra.models import SQLModel
from sqlmodel import text
from xcov19.app.settings import Settings
from sqlmodel.ext.asyncio.session import AsyncSession as AsyncSessionWrapper
from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncEngine,
AsyncSession,
async_sessionmaker,
)

Expand Down Expand Up @@ -38,36 +39,33 @@ class SessionFactory:
def __init__(self, engine: AsyncEngine):
self._engine = engine

def __call__(self) -> async_sessionmaker[AsyncSession]:
def __call__(self) -> async_sessionmaker[AsyncSessionWrapper]:
return async_sessionmaker(
self._engine, class_=AsyncSession, expire_on_commit=False
self._engine, class_=AsyncSessionWrapper, expire_on_commit=False
)


async def setup_database(engine: AsyncEngine) -> None:
"""Sets up tables for database."""
async with engine.begin() as conn:
# see: https://sqlmodel.tiangolo.com/tutorial/relationship-attributes/cascade-delete-relationships/#enable-foreign-key-support-in-sqlite
await conn.execute(text("PRAGMA foreign_keys=ON"))
await conn.run_sync(SQLModel.metadata.create_all)
await conn.commit()
db_logger.info("===== Database tables setup. =====")


async def create_async_session(
AsyncSessionFactory: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
"""Create an asynchronous database session."""
async with AsyncSessionFactory() as session:
try:
yield session
finally:
await session.close()


async def start_db_session(container: Container):
@asynccontextmanager
async def start_db_session(
container: Container,
) -> AsyncGenerator[AsyncSessionWrapper, None]:
"""Starts a new database session given SessionFactory."""
# add LocalAsyncSession
local_async_session = create_async_session(
container.resolve(async_sessionmaker[AsyncSession])
async_session_factory: async_sessionmaker[AsyncSessionWrapper] = container.resolve(
async_sessionmaker[AsyncSessionWrapper]
)
container.add_instance(local_async_session, AsyncSession)
async with async_session_factory() as local_async_session:
yield local_async_session


def configure_database_session(container: Container, settings: Settings) -> Container:
Expand All @@ -82,8 +80,9 @@ def configure_database_session(container: Container, settings: Settings) -> Cont
container.add_instance(engine, AsyncEngine)

# add sessionmaker
session_factory = SessionFactory(engine)
container.add_singleton_by_factory(
SessionFactory(engine), async_sessionmaker[AsyncSession]
session_factory, async_sessionmaker[AsyncSessionWrapper]
)

db_logger.info("====== Database session configured. ======")
Expand Down
2 changes: 0 additions & 2 deletions xcov19/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from xcov19.app.database import (
configure_database_session,
setup_database,
start_db_session,
)
from xcov19.app.auth import configure_authentication
from xcov19.app.controllers import controller_router
Expand Down Expand Up @@ -46,6 +45,5 @@ async def on_start():
container: ContainerProtocol = app.services
if not isinstance(container, Container):
raise ValueError("Container is not a valid container")
await start_db_session(container)
engine = container.resolve(AsyncEngine)
await setup_database(engine)
Empty file added xcov19/infra/__init__.py
Empty file.
138 changes: 138 additions & 0 deletions xcov19/infra/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Database Models and Delete Behavior Design Principles
1. Query-Patient-Location Relationship:
- Every Query must have both a Patient and a Location associated with it.
- A Patient can have multiple Queries.
- A Location can be associated with multiple Queries.
2. Delete Restrictions:
- Patient and Location records cannot be deleted if there are any Queries referencing them.
- This is enforced by the "RESTRICT" ondelete option in the Query model's foreign keys.
3. Orphan Deletion:
- A Patient or Location should be deleted only when there are no more Queries referencing it.
- This is handled by custom event listeners that check for remaining Queries after a Query deletion.
4. Cascading Behavior:
- There is no automatic cascading delete from Patient or Location to Query.
- Queries must be explicitly deleted before their associated Patient or Location can be removed.
5. Transaction Handling:
- Delete operations and subsequent orphan checks should occur within the same transaction.
- Event listeners use the existing database connection to ensure consistency with the main transaction.
6. Error Handling:
- Errors during the orphan deletion process should not silently fail.
- Exceptions in event listeners are logged and re-raised to ensure proper transaction rollback.
7. Data Integrity:
- Database-level constraints (foreign keys, unique constraints) are used in conjunction with SQLAlchemy model definitions to ensure data integrity.
These principles aim to maintain referential integrity while allowing for the cleanup of orphaned Patient and Location records when appropriate.
"""

from __future__ import annotations

from typing import List
from sqlmodel import SQLModel, Field, Relationship
from sqlalchemy import Column, Text, Float, Index
from sqlalchemy.orm import relationship, Mapped
import uuid
from sqlalchemy.dialects.sqlite import TEXT


class Patient(SQLModel, table=True):
patient_id: str = Field(
sa_column=Column(
TEXT, unique=True, primary_key=True, default=str(uuid.uuid4())
),
allow_mutation=False,
)
queries: Mapped[List["Query"]] = Relationship(
# back_populates="patient",
passive_deletes="all",
cascade_delete=True,
sa_relationship=relationship(back_populates="patient"),
)


class Query(SQLModel, table=True):
"""Every Query must have both a Patient and a Location."""

query_id: str = Field(
sa_column=Column(
TEXT, unique=True, primary_key=True, default=str(uuid.uuid4())
),
allow_mutation=False,
)
query: str = Field(allow_mutation=False, sa_column=Column(Text))
# Restrict deleting Patient record when there is atleast 1 query referencing it
patient_id: str = Field(foreign_key="patient.patient_id", ondelete="RESTRICT")
# Restrict deleting Location record when there is atleast 1 query referencing it
location_id: str = Field(foreign_key="location.location_id", ondelete="RESTRICT")
location: Location = Relationship(back_populates="queries")
patient: Patient = Relationship(back_populates="queries")


class Location(SQLModel, table=True):
__table_args__ = (
Index("ix_location_composite_lat_lng", "latitude", "longitude", unique=True),
)
location_id: str = Field(
sa_column=Column(
TEXT, unique=True, primary_key=True, default=str(uuid.uuid4())
),
allow_mutation=False,
)
latitude: float = Field(sa_column=Column(Float))
longitude: float = Field(sa_column=Column(Float))
queries: Mapped[List["Query"]] = Relationship(
# back_populates="location",
cascade_delete=True,
passive_deletes=True,
sa_relationship=relationship(back_populates="location"),
)


# TODO: Define Provider SQL model fields
# class Provider(SQLModel, table=True):
# # TODO: Compare with Github issue, domain model and noccodb
# ...


# TODO: Add Model events for database ops during testing
# @event.listens_for(Query, "after_delete")
# def delete_dangling_location(mapper: Mapper, connection: Engine, target: Query):
# """Deletes orphan Location when no related queries exist."""
# local_session = sessionmaker(connection)
# with local_session() as session:
# stmt = (
# select(func.count())
# .select_from(Query)
# .where(Query.location_id == target.location_id)
# )
# if (
# num_queries := session.execute(stmt).scalar_one_or_none()
# ) and num_queries <= 1:
# location: Location = session.get(Location, target.location_id)
# session.delete(location)
# session.flush()


# @event.listens_for(Query, "after_delete")
# def delete_dangling_patient(mapper: Mapper, connection: Engine, target: Query):
# """Deletes orphan Patient records when no related queries exist."""
# local_session = sessionmaker(connection)
# with local_session() as session:
# stmt = (
# select(func.count())
# .select_from(Query)
# .where(Query.patient_id == target.patient_id)
# )
# if (
# num_queries := session.execute(stmt).scalar_one_or_none()
# ) and num_queries <= 1:
# patient: Patient = session.get(Patient, target.patient_id)
# session.delete(patient)
# session.flush()
Empty file added xcov19/tests/data/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions xcov19/tests/data/seed_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Dummy data to seed to database models.
Mapped to SQLModel.
dummy GeoLocation:
lat=0
lng=0
cust_id=test_cust_id
query_id=test_query_id
"""

from sqlalchemy import ScalarResult
from sqlmodel import select
from xcov19.infra.models import Patient, Query, Location
from sqlmodel.ext.asyncio.session import AsyncSession as AsyncSessionWrapper


async def seed_data(session: AsyncSessionWrapper):
"""
Now you can do:
res = await self._session.exec(select(Query))
query = res.first()
print("query", query)
res = await self._session.exec(select(Patient).where(Patient.queries.any(Query.query_id == query.query_id)))
print("patient", res.first())
res = await self._session.exec(select(Location).where(Location.queries.any(Query.query_id == query.query_id)))
print("location", res.first())
"""
query = Query(
query="""
Runny nose and high fever suddenly lasting for few hours.
Started yesterday.
"""
) # type: ignore

patient = Patient(queries=[query]) # type: ignore

patient_location = Location(latitude=0, longitude=0, queries=[query]) # type: ignore
session.add_all([patient_location, patient])
await session.commit()
query_result: ScalarResult = await session.exec(select(Query))
if not query_result.first():
raise RuntimeError("Database seeding failed")
18 changes: 16 additions & 2 deletions xcov19/tests/start_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from collections.abc import AsyncGenerator
from xcov19.app.main import app
from blacksheep import Application
from contextlib import asynccontextmanager
from rodi import Container, ContainerProtocol
from xcov19.app.database import configure_database_session, setup_database
from xcov19.app.settings import load_settings
from sqlalchemy.ext.asyncio import AsyncEngine


async def start_server() -> AsyncGenerator[Application, None]:
@asynccontextmanager
async def start_server(app: Application) -> AsyncGenerator[Application, None]:
"""Start a test server for automated testing."""
try:
await app.start()
yield app
finally:
if app.started:
await app.stop()


async def start_test_database(container: ContainerProtocol) -> None:
"""Database setup for integration tests."""
if not isinstance(container, Container):
raise RuntimeError("container not of type Container.")
configure_database_session(container, load_settings())
engine = container.resolve(AsyncEngine)
await setup_database(engine)
38 changes: 24 additions & 14 deletions xcov19/tests/test_services.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections.abc import Callable
from contextlib import AsyncExitStack
from typing import List
import pytest
import unittest

from rodi import ContainerProtocol
from xcov19.tests.start_server import start_server
from rodi import Container, ContainerProtocol
from xcov19.app.database import start_db_session
from xcov19.tests.data.seed_db import seed_data
from xcov19.tests.start_server import start_test_database
from xcov19.domain.models.provider import (
Contact,
FacilityEstablishment,
Expand All @@ -24,7 +27,8 @@

import random

from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel.ext.asyncio.session import AsyncSession as AsyncSessionWrapper


RANDOM_SEED = random.seed(1)

Expand Down Expand Up @@ -185,7 +189,7 @@ async def test_fetch_facilities_no_results(self):
self.assertIsNone(result)


@pytest.mark.skip(reason="WIP")
# @pytest.mark.skip(reason="WIP")
@pytest.mark.integration
@pytest.mark.usefixtures("dummy_reverse_geo_lookup_svc", "dummy_geolocation_query_json")
class GeoLocationServiceSqlRepoDBTest(unittest.IsolatedAsyncioTestCase):
Expand All @@ -198,23 +202,29 @@ class GeoLocationServiceSqlRepoDBTest(unittest.IsolatedAsyncioTestCase):
"""

async def asyncSetUp(self) -> None:
app = await anext(start_server())
self._container: ContainerProtocol = app.services
self._seed_db(self._container.resolve(AsyncSession))
self._stack = AsyncExitStack()
container: ContainerProtocol = Container()
await start_test_database(container)
self._session = await self._stack.enter_async_context(
start_db_session(container)
)
if not isinstance(self._session, AsyncSessionWrapper):
raise RuntimeError(f"{self._session} is not a AsyncSessionWrapper value.")
await seed_data(self._session)
await super().asyncSetUp()

def _seed_db(self, session: AsyncSession) -> None:
# TODO: add data to sqlite tables based on dummy_geolocation_query_json
# and add providers data.
...
async def asyncTearDown(self) -> None:
print("async closing test server db session closing.")
await self._session.commit()
await self._stack.aclose()
print("async test server closing.")
await super().asyncTearDown()

def _patient_query_lookup_svc_using_repo(
self, address: Address, query: LocationQueryJSON
) -> Callable[[Address, LocationQueryJSON], List[FacilitiesResult]]: ...

async def test_fetch_facilities(
self, dummy_reverse_geo_lookup_svc, dummy_geolocation_query_json
):
async def test_fetch_facilities(self):
# TODO Implement test_fetch_facilities like this:
# providers = await GeolocationQueryService.fetch_facilities(
# dummy_reverse_geo_lookup_svc,
Expand Down

0 comments on commit e2e7e62

Please sign in to comment.