From 141df0624b016db61d1361201c34fe4aa8afe3d6 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Thu, 5 Sep 2024 08:42:30 +0100 Subject: [PATCH] Issue/add ml model (#152) * #147 add ml model * add migration for datamodel * add model filter + tests * PR comment, rename function * fix * add to test * read models names * fix tests * add test --- pvsite_datamodel/read/model.py | 46 ++++++++++++++++++++++- tests/read/test_read_models.py | 67 +++++++++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 3 deletions(-) diff --git a/pvsite_datamodel/read/model.py b/pvsite_datamodel/read/model.py index 0e1f6cb..9c8b118 100644 --- a/pvsite_datamodel/read/model.py +++ b/pvsite_datamodel/read/model.py @@ -1,10 +1,11 @@ """ Read functions for getting ML models. """ import logging -from typing import Optional +from datetime import datetime +from typing import List, Optional from sqlalchemy.orm import Session -from pvsite_datamodel.sqlmodels import MLModelSQL +from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, MLModelSQL, SiteSQL logger = logging.getLogger(__name__) @@ -50,3 +51,44 @@ def get_or_create_model(session: Session, name: str, version: Optional[str] = No model = models[0] return model + + +def get_models( + session: Session, + start_datetime: Optional[datetime] = None, + end_datetime: Optional[datetime] = None, + site_uuid: Optional[str] = None, +) -> List[MLModelSQL]: + """ + Get model names from forecast values. + + They are distinct on model name + By adding start and end datetimes, we only look at forecast values in that time range. + By adding site_uuid, we only look at forecast values for that site. + + :param session: database session + :param start_datetime: optional filter on start datetime + :param end_datetime: optional filter on end datetime + :param site_uuid: optional filter on site uuid + :return: list of model names + """ + query = session.query(MLModelSQL) + + query = query.distinct(MLModelSQL.name) + + if (start_datetime is not None) or (end_datetime is not None) or (site_uuid is not None): + query = query.join(ForecastValueSQL) + + if start_datetime is not None: + query = query.where(ForecastValueSQL.start_utc > start_datetime) + + if end_datetime is not None: + query = query.where(ForecastValueSQL.start_utc < end_datetime) + + if site_uuid is not None: + query = query.join(ForecastSQL) + query = query.join(SiteSQL) + query = query.where(SiteSQL.site_uuid == site_uuid) + + models: [MLModelSQL] = query.all() + return models diff --git a/tests/read/test_read_models.py b/tests/read/test_read_models.py index ec05d73..455cff5 100644 --- a/tests/read/test_read_models.py +++ b/tests/read/test_read_models.py @@ -1,7 +1,12 @@ """ test get models""" +import datetime as dt +from uuid import uuid4 -from pvsite_datamodel.read.model import get_or_create_model +import pandas as pd + +from pvsite_datamodel.read.model import get_models, get_or_create_model from pvsite_datamodel.sqlmodels import MLModelSQL +from pvsite_datamodel.write.forecast import insert_forecast_values def test_get_model(db_session): @@ -15,3 +20,63 @@ def test_get_model(db_session): _ = get_or_create_model(session=db_session, name="test_name", version="9.9.10") assert len(db_session.query(MLModelSQL).all()) == 2 + + +def test_get_models(db_session): + _ = get_or_create_model(session=db_session, name="test_name", version="9.9.9") + _ = get_or_create_model(session=db_session, name="test_name", version="9.9.10") + + models = get_models(session=db_session) + + assert len(models) == 1 + + +def test_get_models_with_datetimes(db_session, forecast_valid_input): + model = get_or_create_model(session=db_session, name="test_name", version="9.9.10") + + forecast_valid_meta_input, forecast_valid_values_input = forecast_valid_input + + df = pd.DataFrame(forecast_valid_values_input) + + insert_forecast_values( + session=db_session, + forecast_meta=forecast_valid_meta_input, + forecast_values_df=df, + ml_model_name=model.name, + ml_model_version=model.version, + ) + + now = dt.datetime.now(dt.timezone.utc) + models = get_models(session=db_session, start_datetime=now) + assert len(models) == 1 + + models = get_models(session=db_session, start_datetime=now + dt.timedelta(days=1)) + assert len(models) == 0 + + +def test_get_models_with_datetimes_with_sites(db_session, forecast_valid_input): + model = get_or_create_model(session=db_session, name="test_name", version="9.9.10") + + forecast_valid_meta_input, forecast_valid_values_input = forecast_valid_input + + df = pd.DataFrame(forecast_valid_values_input) + + insert_forecast_values( + session=db_session, + forecast_meta=forecast_valid_meta_input, + forecast_values_df=df, + ml_model_name=model.name, + ml_model_version=model.version, + ) + + models = get_models( + session=db_session, + start_datetime=dt.datetime.now(dt.timezone.utc), + site_uuid=forecast_valid_meta_input["site_uuid"], + ) + assert len(models) == 1 + + models = get_models( + session=db_session, start_datetime=dt.datetime.now(dt.timezone.utc), site_uuid=str(uuid4()) + ) + assert len(models) == 0