Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt to sqlachemy 2 #172

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions openstef_dbc/data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import requests
import sqlalchemy

from sqlalchemy import text, bindparam
from openstef_dbc import Singleton
from openstef_dbc.ktp_api import KtpApi
from openstef_dbc.log import logging
Expand Down Expand Up @@ -266,7 +266,8 @@ def exec_sql_query(self, query: str, params: dict = None):
with self.sql_engine.connect() as connection:
if params is None:
params = {}
cursor = connection.execute(query, **params)
cursor = connection.execute(text(query).bindparams(**params))
# cursor = connection.execute(text(query).bindparams([bindparam(p, expanding=True) for p in params]))
if cursor.cursor is not None:
return pd.DataFrame(cursor.fetchall())
except sqlalchemy.exc.OperationalError as e:
Expand All @@ -288,8 +289,8 @@ def exec_sql_query(self, query: str, params: dict = None):
def exec_sql_write(self, statement: str, params: dict = None) -> None:
try:
with self.sql_engine.connect() as connection:
response = connection.execute(statement, params=params)

response = connection.execute(text(statement).bindparams(**params))
connection.commit()
self.logger.info(
f"Added {response.rowcount} new systems to the systems table in the {self.sql_db_type} database"
)
Expand Down
2 changes: 1 addition & 1 deletion openstef_dbc/services/model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def get_power_curve(cls, turbine_type: str) -> dict:
""" "This function retrieves the power curve coefficients from the genericpowercurves table,
using the turbine type as input."""
bind_params = {"turbine_type": turbine_type}
query = "SELECT * FROM genericpowercurves WHERE name = %(turbine_type)s"
query = "SELECT * FROM genericpowercurves WHERE name = :turbine_type"

result = _DataInterface.get_instance().exec_sql_query(query, bind_params)

Expand Down
4 changes: 2 additions & 2 deletions openstef_dbc/services/prediction_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def get_pids_for_api_key(self, api_key: str) -> list[int]:
LEFT JOIN `customers` as cu ON cak.cid = cu.id
LEFT JOIN `customers_predictions` as cp ON cu.id = cp.customer_id
LEFT JOIN `predictions` as p ON p.id = cp.prediction_id
WHERE cak.api_key = %(apiKey)s
WHERE cak.apiKey = :apiKey
"""
result = _DataInterface.get_instance().exec_sql_query(query, bind_params)
if isinstance(result, pd.DataFrame) and result.empty:
Expand Down Expand Up @@ -264,7 +264,7 @@ def get_ean_for_pid(self, pid: int) -> list[str]:
SELECT
p.ean
FROM `predictions` as p
WHERE p.id = %(pid)s
WHERE p.id = :pid
"""
result = _DataInterface.get_instance().exec_sql_query(query, bind_params)
if isinstance(result, pd.DataFrame) and result.empty:
Expand Down
8 changes: 4 additions & 4 deletions openstef_dbc/services/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ def get_energy_split_coefs(self, pj: dict, mean: bool = False) -> dict:
bind_params = {"pid": pj["id"], "dstart": start_date.isoformat()}
query = (
"SELECT ec.coef_name, AVG(ec.coef_value) FROM energy_split_coefs as ec "
"WHERE ec.pid = %(pid)s AND ec.created > %(dstart)s GROUP BY ec.coef_name "
"WHERE ec.pid = :pid AND ec.created > :dstart GROUP BY ec.coef_name "
)
# Retrieve latest coefficients otherwise
else:
bind_params = {"pid": pj["id"]}
query = (
"SELECT ec.coef_name,ec.coef_value FROM energy_split_coefs as ec WHERE ec.pid = %(pid)s "
"SELECT ec.coef_name,ec.coef_value FROM energy_split_coefs as ec WHERE ec.pid = :pid "
"AND ec.created = (SELECT max(energy_split_coefs.created) from energy_split_coefs "
"WHERE energy_split_coefs.pid = %(pid)s)"
"WHERE energy_split_coefs.pid = :pid)"
)
# Execute query
result = _DataInterface.get_instance().exec_sql_query(query, bind_params)

# Make output dict
if result is not None:
if not result.empty:
result = result.set_index("coef_name")
if mean:
result = result.to_dict()["AVG(ec.coef_value)"]
Expand Down
20 changes: 10 additions & 10 deletions openstef_dbc/services/systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def get_systems_near_location(
"radius": str(radius),
}
query = """
SELECT `sid`, `lat`, `lon`,`region`, ( 6371 * acos( cos( radians(%(lat)s) ) \
* cos( radians( lat ) ) * cos( radians( lon ) - radians(%(lon)s) ) + sin( radians(%(lat)s) ) \
SELECT `sid`, `lat`, `lon`,`region`, ( 6371 * acos( cos( radians(:lat) ) \
* cos( radians( lat ) ) * cos( radians( lon ) - radians(:lon) ) + sin( radians(:lat) ) \
* sin( radians( lat ) ) ) ) AS `distance` \
FROM `systems`
WHERE `qual` > '%(quality)s'
WHERE `qual` > ':quality'
"""

# Extend query
if freq is not None:
bind_params["freq"] = str(freq)
query += """ AND `freq` <= %(freq)s"""
query += """ AND `freq` <= :freq"""
if lag_systems is not None:
query += """ AND `lagSystems` <= %(quality)s"""
query += """ AND `lagSystems` <= :quality"""

# Limit radius to given input radius
query += """ HAVING `distance` < %(radius)s ORDER BY `distance`;"""
query += """ HAVING `distance` < :radius ORDER BY `distance`;"""

result = _DataInterface.get_instance().exec_sql_query(query, bind_params)
return result
Expand All @@ -82,7 +82,7 @@ def get_systems_by_pid(
SELECT * from systems
INNER JOIN predictions_systems
ON predictions_systems.system_id=systems.sid
WHERE predictions_systems.prediction_id=%(pid)s
WHERE predictions_systems.prediction_id=:pid
"""

systems = _DataInterface.get_instance().exec_sql_query(
Expand Down Expand Up @@ -111,12 +111,12 @@ def get_random_pv_systems(

if limit is not None:
bind_params["limit"] = limit
limit_query = f"LIMIT %(limit)s"
limit_query = f"LIMIT :limit"

query = f"""
SELECT sid, qual, freq, lag
FROM systems
WHERE left(sid, 3) = 'pv_' AND autoupdate = %(autoupdate)s
WHERE left(sid, 3) = 'pv_' AND autoupdate = :autoupdate
ORDER BY RAND() {limit_query}
"""

Expand All @@ -138,7 +138,7 @@ def get_api_key_for_system(self, sid: str) -> str:
sa.apiKey
FROM `systems` as s
LEFT JOIN `systemsApiKeys` as sa ON s.measurements_customer_api_key_id = sa.id
WHERE s.sid = %(system)s;
WHERE s.sid = :system;
"""

result = _DataInterface.get_instance().exec_sql_query(query, bind_params)
Expand Down
4 changes: 2 additions & 2 deletions openstef_dbc/services/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_weather_forecast_locations(
query = """
SELECT input_city as city, lat, lon, country
FROM weatherforecastlocations
WHERE country = %(country)s AND active = %(active)s
WHERE country = :country AND active = :active
"""
result = _DataInterface.get_instance().exec_sql_query(query, bind_params)

Expand Down Expand Up @@ -136,7 +136,7 @@ def _get_coordinates_of_location(self, location_name: str) -> Tuple[float, float

# Query corresponding (lat, lon) from SQL database
binding_params = {"city": location_name}
query = "SELECT lat, lon from NameToLatLon where regionInput = %(city)s"
query = "SELECT lat, lon from NameToLatLon where regionInput = :city"
location = _DataInterface.get_instance().exec_sql_query(query, binding_params)

# If not found
Expand Down
5 changes: 3 additions & 2 deletions openstef_dbc/services/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ def __init__(self):

def write_location(self, location_name: str, location: Tuple[float, float]) -> None:
bind_params = {
"table_name": "NameToLatLon",
"loc": location_name,
"lat": location[0],
"lon": location[1],
}

statement = "INSERT INTO %(table_name)s (regionInput, lat,lon) VALUES (%(loc)s, %(lat)s, %(lon)s)"
statement = (
"INSERT INTO NameToLatLon (regionInput, lat,lon) VALUES (:loc, :lat, :lon)"
)

_DataInterface.get_instance().exec_sql_write(statement, params=bind_params)

Expand Down
128 changes: 128 additions & 0 deletions tests/integration/test_database_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-FileCopyrightText: 2017-2022 Contributors to the OpenSTEF project <korte.termijn.prognoses@alliander.com>
#
# SPDX-License-Identifier: MPL-2.0

from datetime import datetime
import pytz

from pandas import Timestamp
import pandas as pd
import numpy as np
import unittest

from openstef_dbc.data_interface import _DataInterface
from openstef_dbc.database import DataBase
from openstef.data_classes.prediction_job import PredictionJobDataClass
from openstef_dbc.services.prediction_job import PredictionJobRetriever
from openstef_dbc.services.systems import Systems
from openstef_dbc.services.model_input import ModelInput
from openstef_dbc.services.splitting import Splitting
from openstef_dbc.services.weather import Weather
from openstef_dbc.services.write import Write
from tests.integration.mock_influx_db_admin import MockInfluxDBAdmin

from tests.integration.settings import Settings

UTC = pytz.timezone("UTC")


class TestDataBaseConnexion(unittest.TestCase):
def setUp(self) -> None:
# Initialize settings
config = Settings()
self.di = _DataInterface(config)

# Initialize database object
self.database = DataBase(config)

def test_sql_db_available(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

assert self.di.check_sql_available() == True

def test_get_prediction_jobs(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

pj_retriever = PredictionJobRetriever()

response = pj_retriever.get_prediction_jobs()
assert isinstance(response, list)
assert isinstance(response[0], PredictionJobDataClass)

def test_get_pids_for_api_key(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

pj_retriever = PredictionJobRetriever()

response = pj_retriever.get_pids_for_api_key("random_api_key")
assert isinstance(response, list)

def test_get_ean_for_pid(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

pj_retriever = PredictionJobRetriever()

response = pj_retriever.get_ean_for_pid(1)
assert isinstance(response, list)

def test_add_quantiles_to_prediction_jobs(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

pj_retriever = PredictionJobRetriever()
pjs = pj_retriever.get_prediction_jobs()

response = pj_retriever._add_quantiles_to_prediction_jobs(pjs)[0]
assert isinstance(response, PredictionJobDataClass)
assert hasattr(response, "quantiles")

def test_get_systems_near_location(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blackfmt] reported by reviewdog 🐶

Suggested change

system = Systems()
response = system.get_systems_near_location(location=[0.0, 0.0])
assert isinstance(response, pd.DataFrame)

def test_get_systems_by_pid(self):
system = Systems()
response = system.get_systems_by_pid(pid=1)
assert isinstance(response, pd.DataFrame)

def test_get_random_pv_systems(self):
system = Systems()
response = system.get_random_pv_systems()
assert isinstance(response, pd.DataFrame)

def test_get_api_key_for_system(self):
system = Systems()
response = system.get_api_key_for_system(sid="1")
assert isinstance(response, str)

def test_get_power_curve(self):
modelinput = ModelInput()
response = modelinput.get_power_curve(turbine_type="Enercon E101")
assert isinstance(response, dict)
for key in [
"name",
"cut_in",
"cut_off",
"kind",
"manufacturer",
"peak_capacity",
"rated_power",
"slope_center",
"steepness",
]:
assert key in response

def test_get_energy_split_coefs(self):
splitting = Splitting()
pj_retriever = PredictionJobRetriever()
pj = pj_retriever.get_prediction_jobs()[0]
response = splitting.get_energy_split_coefs(pj=pj)

assert isinstance(response, dict)

def test_get_weather_forecast_locations(self):
weather = Weather()
response = weather.get_weather_forecast_locations()

assert isinstance(response, list)

def test_get_coordinates_of_location(self):
weather = Weather()
response = weather._get_coordinates_of_location(location_name="Leeuwarden")
assert isinstance(response, tuple)
Loading