Skip to content

Commit

Permalink
Issue/add ml model (#152)
Browse files Browse the repository at this point in the history
* #147 add ml model

* add migration for datamodel

* add model filter + tests

* PR comment, rename function

* fix

* add to test

* read models names

* fix tests

* add test
  • Loading branch information
peterdudfield authored Sep 5, 2024
1 parent adf8892 commit 141df06
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 3 deletions.
46 changes: 44 additions & 2 deletions pvsite_datamodel/read/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
""" Read functions for getting ML models. """
import logging
from typing import Optional
from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import MLModelSQL
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, MLModelSQL, SiteSQL

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,3 +51,44 @@ def get_or_create_model(session: Session, name: str, version: Optional[str] = No
model = models[0]

return model


def get_models(
    session: Session,
    start_datetime: Optional[datetime] = None,
    end_datetime: Optional[datetime] = None,
    site_uuid: Optional[str] = None,
) -> List[MLModelSQL]:
    """
    Get ML models, distinct on model name.

    With no filters, the model table is queried on its own. As soon as any
    filter is given, the query is inner-joined to the forecast values, so
    only models that have matching forecast values are returned.

    NOTE(review): no ordering is applied, so when several versions share a
    name, which row represents that name is presumably left to the
    database - confirm before relying on the returned version.

    :param session: database session
    :param start_datetime: optional filter, keep forecast values whose
        start_utc is strictly after this datetime
    :param end_datetime: optional filter, keep forecast values whose
        start_utc is strictly before this datetime
    :param site_uuid: optional filter, keep forecast values that belong to
        this site
    :return: list of MLModelSQL objects, one per distinct model name
    """
    query = session.query(MLModelSQL).distinct(MLModelSQL.name)

    # Every filter is expressed against forecast values, so the join is
    # needed once if at least one filter has been supplied.
    any_filter_given = any(
        value is not None for value in (start_datetime, end_datetime, site_uuid)
    )
    if any_filter_given:
        query = query.join(ForecastValueSQL)

    if start_datetime is not None:
        query = query.where(ForecastValueSQL.start_utc > start_datetime)

    if end_datetime is not None:
        query = query.where(ForecastValueSQL.start_utc < end_datetime)

    if site_uuid is not None:
        # The site is reached from forecast values through the forecast table.
        query = query.join(ForecastSQL)
        query = query.join(SiteSQL)
        query = query.where(SiteSQL.site_uuid == site_uuid)

    models: List[MLModelSQL] = query.all()
    return models
67 changes: 66 additions & 1 deletion tests/read/test_read_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
""" test get models"""
import datetime as dt
from uuid import uuid4

from pvsite_datamodel.read.model import get_or_create_model
import pandas as pd

from pvsite_datamodel.read.model import get_models, get_or_create_model
from pvsite_datamodel.sqlmodels import MLModelSQL
from pvsite_datamodel.write.forecast import insert_forecast_values


def test_get_model(db_session):
Expand All @@ -15,3 +20,63 @@ def test_get_model(db_session):

_ = get_or_create_model(session=db_session, name="test_name", version="9.9.10")
assert len(db_session.query(MLModelSQL).all()) == 2


def test_get_models(db_session):
    """Two versions stored under one model name give a single result."""
    for version in ("9.9.9", "9.9.10"):
        get_or_create_model(session=db_session, name="test_name", version=version)

    found = get_models(session=db_session)

    assert len(found) == 1


def test_get_models_with_datetimes(db_session, forecast_valid_input):
    """Check get_models with a start_datetime filter after inserting forecast values."""
    ml_model = get_or_create_model(session=db_session, name="test_name", version="9.9.10")

    meta, values = forecast_valid_input
    values_df = pd.DataFrame(values)

    insert_forecast_values(
        session=db_session,
        forecast_meta=meta,
        forecast_values_df=values_df,
        ml_model_name=ml_model.name,
        ml_model_version=ml_model.version,
    )

    now = dt.datetime.now(dt.timezone.utc)

    # Starting from now, the model is expected to be found.
    from_now = get_models(session=db_session, start_datetime=now)
    assert len(from_now) == 1

    # Starting one day from now, no model is expected.
    tomorrow = now + dt.timedelta(days=1)
    from_tomorrow = get_models(session=db_session, start_datetime=tomorrow)
    assert len(from_tomorrow) == 0


def test_get_models_with_datetimes_with_sites(db_session, forecast_valid_input):
    """Check get_models with start_datetime combined with a site_uuid filter."""
    ml_model = get_or_create_model(session=db_session, name="test_name", version="9.9.10")

    meta, values = forecast_valid_input
    values_df = pd.DataFrame(values)

    insert_forecast_values(
        session=db_session,
        forecast_meta=meta,
        forecast_values_df=values_df,
        ml_model_name=ml_model.name,
        ml_model_version=ml_model.version,
    )

    # The site the forecast was written for: the model is expected.
    for_known_site = get_models(
        session=db_session,
        start_datetime=dt.datetime.now(dt.timezone.utc),
        site_uuid=meta["site_uuid"],
    )
    assert len(for_known_site) == 1

    # A freshly generated site uuid: no model is expected.
    for_unknown_site = get_models(
        session=db_session,
        start_datetime=dt.datetime.now(dt.timezone.utc),
        site_uuid=str(uuid4()),
    )
    assert len(for_unknown_site) == 0

0 comments on commit 141df06

Please sign in to comment.