Skip to content

Commit

Permalink
Update to sqlalchemy select API and FastAPI SessionDep
Browse files Browse the repository at this point in the history
  • Loading branch information
tomalrussell committed Dec 2, 2024
1 parent 358aa8f commit 3c1618e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 49 deletions.
9 changes: 0 additions & 9 deletions containers/backend/backend/app/dependencies.py

This file was deleted.

8 changes: 4 additions & 4 deletions containers/backend/backend/app/internal/attribute_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def add_damages_expected_value_query(
q = q.group_by(models.Feature.id)
value = functions.sum(value)

return q.add_column(value.label("value"))
return q.add_columns(value.label("value"))


# def add_damages_rp_value_query(fq: Query, dimesions: schemas.ReturnPeriodDamagesDimensions, field: str):
Expand Down Expand Up @@ -68,7 +68,7 @@ def add_adaptation_value_query(
else:
value = getattr(models.AdaptationCostBenefit, field)

return q.add_column(value.label("value"))
return q.add_columns(value.label("value"))


@dataclass
Expand Down Expand Up @@ -112,7 +112,7 @@ def parse_dimensions(field_group: str, dimensions: Json):
data_group_config = DATA_GROUP_CONFIGS.get(field_group)

if data_group_config is not None:
return data_group_config.dimensions_schema.parse_obj(dimensions)
return data_group_config.dimensions_schema.model_validate_json(dimensions)
else:
raise ValidationError(f"Invalid field group: {field_group}")

Expand All @@ -124,7 +124,7 @@ def parse_parameters(field_group: str, field: str, parameters: Json):
field_params_schema = data_group_config.field_parameters_schemas

if field_params_schema is not None and field in field_params_schema:
return field_params_schema[field].parse_obj(parameters)
return field_params_schema[field].model_validate_json(parameters)
else:
return None
else:
Expand Down
11 changes: 6 additions & 5 deletions containers/backend/backend/app/routers/attributes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Any
from fastapi import APIRouter, Body, Depends
from sqlalchemy.orm import Session
from sqlalchemy import select

from backend.app.dependencies import get_db
from backend.app.internal.attribute_access import (
add_value_query,
parse_dimensions,
parse_parameters,
)

from backend.db import models
from backend.db.database import SessionDep

from backend.app import schemas

Expand All @@ -21,20 +21,21 @@ def read_attributes(
layer: str,
field_group: str,
field: str,
session: SessionDep,
field_dimensions: schemas.DataDimensions = Depends(parse_dimensions),
field_params: schemas.DataParameters = Depends(parse_parameters),
ids: list[int] = Body(...),
db: Session = Depends(get_db),
):
base_query = (
db.query(models.Feature.id)
select(models.Feature.id)
.select_from(models.Feature)
.filter(models.Feature.layer == layer, models.Feature.id.in_(ids))
)
query = add_value_query(
base_query, field_group, field_dimensions, field, field_params
)
results = session.execute(query).all()

lookup = dict(query.all())
lookup = dict(results)

return {id: lookup.get(id, None) for id in ids}
15 changes: 7 additions & 8 deletions containers/backend/backend/app/routers/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,27 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import desc
from sqlalchemy import desc, select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from geoalchemy2 import functions

from backend.app import schemas
from backend.app.dependencies import get_db
from backend.app.internal.attribute_access import (
add_value_query,
parse_dimensions,
parse_parameters,
)
from backend.db import models
from backend.db.database import SessionDep


router = APIRouter(tags=["features"])


@router.get("/{feature_id}", response_model=schemas.FeatureOut)
def read_feature(feature_id: int, db: Session = Depends(get_db)):
def read_feature(feature_id: int, session: SessionDep):
try:
feature = db.query(models.Feature).filter(models.Feature.id == feature_id).one()
feature = session.get(models.Feature, feature_id)
except NoResultFound:
raise HTTPException(
status_code=404,
Expand Down Expand Up @@ -53,15 +52,15 @@ def get_layer_spec(
def read_sorted_features(
field_group: str,
field: str,
session: SessionDep,
field_dimensions: schemas.DataDimensions = Depends(parse_dimensions),
field_params: schemas.DataParameters = Depends(parse_parameters),
layer_spec: schemas.LayerSpec = Depends(get_layer_spec),
page_params: Params = Depends(),
db: Session = Depends(get_db),
):
filled_layer_spec = {k: v for k, v in layer_spec.dict().items() if v is not None}
base_query = (
db.query(
select(
models.Feature.id.label("id"),
models.Feature.string_id.label("string_id"),
models.Feature.layer.label("layer"),
Expand All @@ -76,4 +75,4 @@ def read_sorted_features(
base_query, field_group, field_dimensions, field, field_params
).order_by(desc("value"))

return paginate(q, page_params)
return paginate(session, q, page_params)
31 changes: 10 additions & 21 deletions containers/backend/backend/app/routers/tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
import json
import inspect

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, HTTPException
from fastapi.logger import logger
from starlette.responses import StreamingResponse
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from sqlalchemy import select
from terracotta.exceptions import DatasetNotFoundError
from terracotta import get_driver


from backend.app import schemas
from backend.app.dependencies import get_db
from backend.db import models
from backend.db.database import SessionDep
from backend.app.internal.helpers import build_driver_path, handle_exception
from backend.app.exceptions import (
SourceDBDoesNotExistException,
Expand Down Expand Up @@ -52,7 +52,7 @@ def _get_singleband_image(
::args database str DB under-which the requested data has been loaded
"""
from app.internal.tiles.singleband import singleband
from backend.app.internal.tiles.singleband import singleband

# Collect TC Driver path for terracotta db
driver_path = build_driver_path(database)
Expand Down Expand Up @@ -136,9 +136,7 @@ def _source_options(source_db: str) -> List[dict[str, str]]:


@router.get("/sources", response_model=List[schemas.TileSourceMeta])
def get_all_tile_source_meta(
db: Session = Depends(get_db),
) -> List[schemas.TileSourceMeta]:
def get_all_tile_source_meta(session: SessionDep) -> List[schemas.TileSourceMeta]:
"""
Retrieve metadata about all the tile sources available
"""
Expand All @@ -147,7 +145,7 @@ def get_all_tile_source_meta(
inspect.stack()[0][3],
)
try:
res = db.query(models.RasterTileSource).all()
res = session.execute(select(models.RasterTileSource)).all()
all_meta = []
for row in res:
meta = schemas.TileSourceMeta.model_validate(row)
Expand All @@ -160,8 +158,7 @@ def get_all_tile_source_meta(

@router.get("/sources/{source_id}", response_model=schemas.TileSourceMeta)
def get_tile_source_meta(
source_id: int,
db: Session = Depends(get_db),
source_id: int, session: SessionDep
) -> List[schemas.TileSourceMeta]:
"""
Retrieve metadata about a single tile source
Expand All @@ -171,11 +168,7 @@ def get_tile_source_meta(
inspect.stack()[0][3],
)
try:
res = (
db.query(models.RasterTileSource)
.filter(models.RasterTileSource.id == source_id)
.one()
)
res = session.get(models.RasterTileSource, source_id)
meta = schemas.TileSourceMeta.model_validate(res)
return meta
except NoResultFound:
Expand All @@ -188,7 +181,7 @@ def get_tile_source_meta(
@router.get("/sources/{source_id}/domains", response_model=schemas.TileSourceDomains)
def get_tile_source_domains(
source_id: int,
db: Session = Depends(get_db),
session: SessionDep,
) -> schemas.TileSourceDomains:
"""
Retrieve all combinations available for the source domain
Expand All @@ -198,11 +191,7 @@ def get_tile_source_domains(
inspect.stack()[0][3],
)
try:
res = (
db.query(models.RasterTileSource)
.filter(models.RasterTileSource.id == source_id)
.one()
)
res = session.get(models.RasterTileSource, source_id)
domains = _source_options(_tile_db_from_domain(res.domain))
meta = schemas.TileSourceDomains(domains=domains)
logger.debug(f"{source_id=} {res.domain=} {domains=} {meta=}")
Expand Down
13 changes: 11 additions & 2 deletions containers/backend/backend/db/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from typing import Annotated

from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session


# pass empty connection string to use PG* environment variables (see https://www.postgresql.org/docs/current/libpq-envars.html)
engine = create_engine("postgresql+psycopg2://", future=True, pool_pre_ping=True)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def get_session():
with Session(engine) as session:
yield session


SessionDep = Annotated[Session, Depends(get_session)]

Base = declarative_base()

0 comments on commit 3c1618e

Please sign in to comment.