diff --git a/ensysmod/api/api.py b/ensysmod/api/api.py index 86ee061..20febbe 100644 --- a/ensysmod/api/api.py +++ b/ensysmod/api/api.py @@ -10,6 +10,8 @@ energy_sinks, energy_sources, energy_storages, + energy_transmission_distances, + energy_transmission_losses, energy_transmissions, regions, ts_capacity_fix, @@ -31,6 +33,8 @@ api_router.include_router(energy_sources.router, prefix="/sources", tags=["Energy Sources"]) api_router.include_router(energy_storages.router, prefix="/storages", tags=["Energy Storages"]) api_router.include_router(energy_transmissions.router, prefix="/transmissions", tags=["Energy Transmissions"]) +api_router.include_router(energy_transmission_distances.router, prefix="/distances", tags=["Energy Transmission Distances"]) +api_router.include_router(energy_transmission_losses.router, prefix="/losses", tags=["Energy Transmission Losses"]) api_router.include_router(energy_models.router, prefix="/models", tags=["Energy Models"]) api_router.include_router(ts_capacity_fix.router, prefix="/fix-capacities", tags=["Fix Capacities"]) diff --git a/ensysmod/api/endpoints/energy_transmission_distances.py b/ensysmod/api/endpoints/energy_transmission_distances.py new file mode 100644 index 0000000..9d9ea26 --- /dev/null +++ b/ensysmod/api/endpoints/energy_transmission_distances.py @@ -0,0 +1,148 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from ensysmod import crud, model +from ensysmod.api import deps, permissions +from ensysmod.schemas import ( + EnergyTransmissionDistance, + EnergyTransmissionDistanceCreate, + EnergyTransmissionDistanceUpdate, +) + +router = APIRouter() + + +@router.get("/", response_model=List[EnergyTransmissionDistance]) +def get_all_transmission_distances( + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100, +) -> List[EnergyTransmissionDistance]: + """ + Retrieve all transmission distances. + """ + return crud.energy_transmission_distance.get_multi(db=db, skip=skip, limit=limit) + + +@router.get("/{distance_id}", response_model=EnergyTransmissionDistance) +def get_transmission_distance( + distance_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Retrieve a transmission distance. + """ + # TODO Check if user has permission for dataset and EnergyTransmissionDistance + return crud.energy_transmission_distance.get(db=db, id=distance_id) + + +@router.get("/component/{component_id}", response_model=List[EnergyTransmissionDistance]) +def get_transmission_distances_by_component( + component_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +) -> Optional[List[EnergyTransmissionDistance]]: + """ + Retrieve all transmission distances for a given component. + """ + # TODO Check if user has permission for dataset and EnergyTransmissionDistance + return crud.energy_transmission_distance.get_by_component(db=db, component_id=component_id) + + +@router.post("/", response_model=EnergyTransmissionDistance) +def create_transmission_distance( + request: EnergyTransmissionDistanceCreate, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Create a new transmission distance. + """ + dataset = crud.dataset.get(db=db, id=request.ref_dataset) + if dataset is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Dataset {request.ref_dataset} not found!") + + permissions.check_modification_permission(db, user=current, dataset_id=dataset.id) + + component = crud.energy_component.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.component) + if component is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"Component {request.component} not found in dataset {dataset.id}!" + ) + + region_from = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.region_from) + if region_from is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Region {request.region_from} not found in dataset {dataset.id}!") + + region_to = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.region_to) + if region_to is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Region {request.region_to} not found in dataset {dataset.id}!") + + distance_entry = crud.energy_transmission_distance.get_by_component_and_region_ids( + db=db, + component_id=component.id, + region_from_id=region_from.id, + region_to_id=region_to.id, + ) + if distance_entry is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"EnergyTransmissionDistance for component {component.name} (id {component.id}) from region {region_from.name} (id {region_from.id}) to region {region_to.name} (id {region_to.id}) already exists with id {distance_entry.id}!", # noqa: E501 + ) + + return crud.energy_transmission_distance.create(db=db, obj_in=request) + + +@router.put("/{distance_id}", response_model=EnergyTransmissionDistance) +def update_transmission_distance( + distance_id: int, + request: EnergyTransmissionDistanceUpdate, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Update a transmission distance. + """ + distance = crud.energy_transmission_distance.get(db=db, id=distance_id) + if distance is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionDistance {distance_id} not found!") + + permissions.check_modification_permission(db, user=current, dataset_id=distance.transmission.component.ref_dataset) + return crud.energy_transmission_distance.update(db=db, db_obj=distance, obj_in=request) + + +@router.delete("/{distance_id}", response_model=EnergyTransmissionDistance) +def remove_transmission_distance( + distance_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Delete a transmission distance. + """ + distance = crud.energy_transmission_distance.get(db=db, id=distance_id) + if distance is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionDistance {distance_id} not found!") + permissions.check_modification_permission(db, user=current, dataset_id=distance.transmission.component.ref_dataset) + return crud.energy_transmission_distance.remove(db=db, id=distance_id) + + +@router.delete("/component/{component_id}", response_model=List[EnergyTransmissionDistance]) +def remove_transmission_distances_by_component( + component_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Delete all transmission distances for a given component. + """ + distances = crud.energy_transmission_distance.get_by_component(db=db, component_id=component_id) + if distances is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionDistance for component {component_id} not found!") + + # TODO Check if user has permission for dataset and EnergyTransmissionDistance + return crud.energy_transmission_distance.remove_by_component(db=db, component_id=component_id) diff --git a/ensysmod/api/endpoints/energy_transmission_losses.py b/ensysmod/api/endpoints/energy_transmission_losses.py new file mode 100644 index 0000000..2abae2b --- /dev/null +++ b/ensysmod/api/endpoints/energy_transmission_losses.py @@ -0,0 +1,148 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from ensysmod import crud, model +from ensysmod.api import deps, permissions +from ensysmod.schemas import ( + EnergyTransmissionLoss, + EnergyTransmissionLossCreate, + EnergyTransmissionLossUpdate, +) + +router = APIRouter() + + +@router.get("/", response_model=List[EnergyTransmissionLoss]) +def get_all_transmission_losses( + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100, +) -> List[EnergyTransmissionLoss]: + """ + Retrieve all transmission losses. + """ + return crud.energy_transmission_loss.get_multi(db=db, skip=skip, limit=limit) + + +@router.get("/{loss_id}", response_model=EnergyTransmissionLoss) +def get_transmission_loss( + loss_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Retrieve a transmission loss. + """ + # TODO Check if user has permission for dataset and EnergyTransmissionLoss + return crud.energy_transmission_loss.get(db=db, id=loss_id) + + +@router.get("/component/{component_id}", response_model=List[EnergyTransmissionLoss]) +def get_transmission_losses_by_component( + component_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +) -> Optional[List[EnergyTransmissionLoss]]: + """ + Retrieve all transmission losses for a given component. + """ + # TODO Check if user has permission for dataset and EnergyTransmissionLoss + return crud.energy_transmission_loss.get_by_component(db=db, component_id=component_id) + + +@router.post("/", response_model=EnergyTransmissionLoss) +def create_transmission_loss( + request: EnergyTransmissionLossCreate, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Create a new transmission loss. + """ + dataset = crud.dataset.get(db=db, id=request.ref_dataset) + if dataset is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Dataset {request.ref_dataset} not found!") + + permissions.check_modification_permission(db, user=current, dataset_id=dataset.id) + + component = crud.energy_component.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.component) + if component is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"Component {request.component} not found in dataset {dataset.id}!" + ) + + region_from = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.region_from) + if region_from is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Region {request.region_from} not found in dataset {dataset.id}!") + + region_to = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset.id, name=request.region_to) + if region_to is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Region {request.region_to} not found in dataset {dataset.id}!") + + loss_entry = crud.energy_transmission_loss.get_by_component_and_region_ids( + db=db, + component_id=component.id, + region_from_id=region_from.id, + region_to_id=region_to.id, + ) + if loss_entry is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"EnergyTransmissionLoss for component {component.name} (id {component.id}) from region {region_from.name} (id {region_from.id}) to region {region_to.name} (id {region_to.id}) already exists with id {loss_entry.id}!", # noqa: E501 + ) + + return crud.energy_transmission_loss.create(db=db, obj_in=request) + + +@router.put("/{loss_id}", response_model=EnergyTransmissionLoss) +def update_transmission_loss( + loss_id: int, + request: EnergyTransmissionLossUpdate, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Update a transmission loss. + """ + loss = crud.energy_transmission_loss.get(db=db, id=loss_id) + if loss is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionLoss {loss_id} not found!") + + permissions.check_modification_permission(db, user=current, dataset_id=loss.transmission.component.ref_dataset) + return crud.energy_transmission_loss.update(db=db, db_obj=loss, obj_in=request) + + +@router.delete("/{loss_id}", response_model=EnergyTransmissionLoss) +def remove_transmission_loss( + loss_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Delete a transmission loss. + """ + loss = crud.energy_transmission_loss.get(db=db, id=loss_id) + if loss is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionLoss {loss_id} not found!") + permissions.check_modification_permission(db, user=current, dataset_id=loss.transmission.component.ref_dataset) + return crud.energy_transmission_loss.remove(db=db, id=loss_id) + + +@router.delete("/component/{component_id}", response_model=List[EnergyTransmissionLoss]) +def remove_transmission_losses_by_component( + component_id: int, + db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), +): + """ + Delete all transmission losses for a given component. + """ + losses = crud.energy_transmission_loss.get_by_component(db=db, component_id=component_id) + if losses is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"EnergyTransmissionLoss for component {component_id} not found!") + + # TODO Check if user has permission for dataset and EnergyTransmissionLoss + return crud.energy_transmission_loss.remove_by_component(db=db, component_id=component_id) diff --git a/ensysmod/core/file_download.py b/ensysmod/core/file_download.py index 7fb2721..094469b 100644 --- a/ensysmod/core/file_download.py +++ b/ensysmod/core/file_download.py @@ -1,8 +1,8 @@ +import json import os import zipfile -from typing import Any, Type, Set, Dict, List +from typing import Any, Dict, List, Set, Type -import json from pydantic import BaseModel from sqlalchemy.orm import Session @@ -115,14 +115,12 @@ def dump_energy_components(db: Session, dataset_id: int, temp_folder: str, crud_ # dump excel files dump_excel_file(db, obj.ref_component, region_ids, crud.capacity_fix, f"{obj_folder}/capacityFix.xlsx") dump_excel_file(db, obj.ref_component, region_ids, crud.capacity_max, f"{obj_folder}/capacityMax.xlsx") - dump_excel_file(db, obj.ref_component, region_ids, crud.operation_rate_fix, - f"{obj_folder}/operationRateFix.xlsx") - dump_excel_file(db, obj.ref_component, region_ids, crud.operation_rate_max, - f"{obj_folder}/operationRateMax.xlsx") + dump_excel_file(db, obj.ref_component, region_ids, crud.operation_rate_fix, f"{obj_folder}/operationRateFix.xlsx") + dump_excel_file(db, obj.ref_component, region_ids, crud.operation_rate_max, f"{obj_folder}/operationRateMax.xlsx") if file_name == "transmission": - crud.energy_transmission_distance.get_dataframe(db, obj.ref_component, region_ids) \ - .to_excel(f"{obj_folder}/distances.xlsx") + crud.energy_transmission_distance.get_dataframe(db, obj.ref_component, region_ids).to_excel(f"{obj_folder}/distances.xlsx") + crud.energy_transmission_loss.get_dataframe(db, obj.ref_component, region_ids).to_excel(f"{obj_folder}/losses.xlsx") def dump_json(file: str, fields: Set[str], obj: Any): diff --git a/ensysmod/core/file_upload.py b/ensysmod/core/file_upload.py index 7a9672b..67c8a83 100644 --- a/ensysmod/core/file_upload.py +++ b/ensysmod/core/file_upload.py @@ -1,6 +1,6 @@ import json from tempfile import TemporaryFile -from typing import List, Dict, Any, Type +from typing import Any, Dict, List, Type from zipfile import ZipFile import pandas as pd @@ -11,7 +11,7 @@ from ensysmod.crud.base_depends_component import CRUDBaseDependsComponent from ensysmod.crud.base_depends_dataset import CRUDBaseDependsDataset from ensysmod.crud.base_depends_timeseries import CRUDBaseDependsTimeSeries -from ensysmod.schemas import ZipArchiveUploadResult, FileStatus, FileUploadResult +from ensysmod.schemas import FileStatus, FileUploadResult, ZipArchiveUploadResult def create_or_update_named_entity(crud_repo: CRUDBaseDependsDataset, db: Session, request: Any): @@ -190,7 +190,6 @@ def process_sub_folder_files(zip_archive: ZipFile, sub_folder_name: str, db: Ses # check if operationRateFix.xlsx exists in sub_folder_name if sub_folder_name + "operationRateFix.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_excel_file(zip_archive.open(sub_folder_name + "operationRateFix.xlsx"), db, dataset_id, component_name, "fix_operation_rates", @@ -199,7 +198,6 @@ def process_sub_folder_files(zip_archive: ZipFile, sub_folder_name: str, db: Ses # check if operationRateMax.xlsx exists in sub_folder_name if sub_folder_name + "operationRateMax.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_excel_file(zip_archive.open(sub_folder_name + "operationRateMax.xlsx"), db, dataset_id, component_name, "max_operation_rates", @@ -208,7 +206,6 @@ def process_sub_folder_files(zip_archive: ZipFile, sub_folder_name: str, db: Ses # check if capacityFix.xlsx exists in sub_folder_name if sub_folder_name + "capacityFix.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_excel_file(zip_archive.open(sub_folder_name + "capacityFix.xlsx"), db, dataset_id, component_name, "fix_capacities", @@ -217,7 +214,6 @@ def process_sub_folder_files(zip_archive: ZipFile, sub_folder_name: str, db: Ses # check if capacityMax.xlsx exists in sub_folder_name if sub_folder_name + "capacityMax.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_excel_file(zip_archive.open(sub_folder_name + "capacityMax.xlsx"), db, dataset_id, component_name, "max_capacities", @@ -232,7 +228,6 @@ def process_sub_folder_matrix_files(zip_archive: ZipFile, sub_folder_name: str, # check if capacityFix.xlsx exists in sub_folder_name if sub_folder_name + "capacityFix.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_matrix_excel_file(zip_archive.open(sub_folder_name + "capacityFix.xlsx"), db, dataset_id, component_name, "fix_capacities", @@ -241,7 +236,6 @@ def process_sub_folder_matrix_files(zip_archive: ZipFile, sub_folder_name: str, # check if capacityMax.xlsx exists in sub_folder_name if sub_folder_name + "capacityMax.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_matrix_excel_file(zip_archive.open(sub_folder_name + "capacityMax.xlsx"), db, dataset_id, component_name, "max_capacities", @@ -250,13 +244,21 @@ def process_sub_folder_matrix_files(zip_archive: ZipFile, sub_folder_name: str, # check if distances.xlsx exists in sub_folder_name if sub_folder_name + "distances.xlsx" in zip_archive.namelist(): - # process operationRateFix.xlsx file_results.append(process_matrix_excel_file(zip_archive.open(sub_folder_name + "distances.xlsx"), db, dataset_id, component_name, "distance", crud_repo=crud.energy_transmission_distance, create_model=schemas.EnergyTransmissionDistanceCreate, as_list=False, region_key="region_from")) + + # check if losses.xlsx exists in sub_folder_name + if sub_folder_name + "losses.xlsx" in zip_archive.namelist(): + file_results.append(process_matrix_excel_file(zip_archive.open(sub_folder_name + "losses.xlsx"), + db, dataset_id, + component_name, "loss", + crud_repo=crud.energy_transmission_loss, + create_model=schemas.EnergyTransmissionLossCreate, + as_list=False, region_key="region_from")) return file_results diff --git a/ensysmod/core/fine_esm.py b/ensysmod/core/fine_esm.py index e9a47f8..b7428cd 100644 --- a/ensysmod/core/fine_esm.py +++ b/ensysmod/core/fine_esm.py @@ -143,8 +143,8 @@ def add_transmission(esM: EnergySystemModel, db: Session, transmission: EnergyTr region_ids: List[int], custom_parameters: List[EnergyModelOverride]) -> None: esm_transmission = component_to_dict(db, transmission.component, region_ids) esm_transmission["commodity"] = transmission.commodity.name - esm_transmission["distances"] = crud.energy_transmission_distance.get_dataframe(db, transmission.ref_component, - region_ids=region_ids) + esm_transmission["distances"] = crud.energy_transmission_distance.get_dataframe(db, transmission.ref_component, region_ids=region_ids) + esm_transmission["losses"] = crud.energy_transmission_loss.get_dataframe(db, transmission.ref_component, region_ids=region_ids) esm_transmission = override_parameters(esm_transmission, custom_parameters) esM.add(Transmission(esM=esM, **esm_transmission)) diff --git a/ensysmod/crud/__init__.py b/ensysmod/crud/__init__.py index 9065f87..cc5a104 100644 --- a/ensysmod/crud/__init__.py +++ b/ensysmod/crud/__init__.py @@ -15,6 +15,7 @@ from .energy_storage import energy_storage from .energy_transmission import energy_transmission from .energy_transmission_distance import energy_transmission_distance +from .energy_transmission_loss import energy_transmission_loss from .region import region from .ts_capacity_fix import capacity_fix from .ts_capacity_max import capacity_max diff --git a/ensysmod/crud/energy_transmission.py b/ensysmod/crud/energy_transmission.py index bd81012..d17fe0e 100644 --- a/ensysmod/crud/energy_transmission.py +++ b/ensysmod/crud/energy_transmission.py @@ -21,13 +21,6 @@ def create(self, db: Session, *, obj_in: EnergyTransmissionCreate) -> EnergyTran obj_in_dict['ref_commodity'] = commodity.id db_obj = super().create(db=db, obj_in=obj_in_dict) - # also create distances - if obj_in.distances is not None: - for distance_create in obj_in.distances: - distance_create.ref_dataset = obj_in.ref_dataset - distance_create.ref_component = db_obj.component.id - crud.energy_transmission_distance.create(db, obj_in=distance_create) - return db_obj diff --git a/ensysmod/crud/energy_transmission_distance.py b/ensysmod/crud/energy_transmission_distance.py index 0350efa..0b42e58 100644 --- a/ensysmod/crud/energy_transmission_distance.py +++ b/ensysmod/crud/energy_transmission_distance.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import pandas as pd from sqlalchemy.orm import Session @@ -6,25 +6,63 @@ from ensysmod import crud from ensysmod.crud.base import CRUDBase from ensysmod.model import EnergyTransmissionDistance -from ensysmod.schemas import EnergyTransmissionDistanceCreate, EnergyTransmissionDistanceUpdate +from ensysmod.schemas import ( + EnergyTransmissionDistanceCreate, + EnergyTransmissionDistanceUpdate, +) # noinspection PyMethodMayBeStatic,PyArgumentList -class CRUDEnergyTransmissionDistance(CRUDBase[EnergyTransmissionDistance, - EnergyTransmissionDistanceCreate, - EnergyTransmissionDistanceUpdate]): +class CRUDEnergyTransmissionDistance(CRUDBase[EnergyTransmissionDistance, EnergyTransmissionDistanceCreate, EnergyTransmissionDistanceUpdate]): """ CRUD operations for EnergyTransmissionDistance """ - def remove_by_component(self, db: Session, component_id: int): + def get_by_component(self, db: Session, component_id: int) -> Optional[List[EnergyTransmissionDistance]]: """ - Removes all EnergyTransmissionDistance entries for a given component. + Get all EnergyTransmissionDistance entries for a given component. + """ + return db.query(self.model).filter(self.model.ref_component == component_id).all() - :param db: Database session - :param component_id: ID of the component + def get_by_component_and_region_ids( + self, db: Session, component_id: int, region_from_id: int, region_to_id: int + ) -> Optional[EnergyTransmissionDistance]: + """ + Get a EnergyTransmissionDistance entry for a given component id and its two region ids. + """ + return ( + db.query(self.model) + .filter(self.model.ref_component == component_id) + .filter(self.model.ref_region_from == region_from_id) + .filter(self.model.ref_region_to == region_to_id) + .first() + ) + + def get_by_dataset_id_component_region_names( + self, db: Session, dataset_id: int, component_name: str, region_from_name: str, region_to_name: str + ) -> Optional[EnergyTransmissionDistance]: + """ + Get a EnergyTransmissionDistance entry for a given dataset id, component name and its two region names. + """ + component = crud.energy_component.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=component_name) + region_from = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=region_from_name) + region_to = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=region_to_name) + return ( + db.query(self.model) + .filter(self.model.ref_component == component.id) + .filter(self.model.ref_region_from == region_from.id) + .filter(self.model.ref_region_to == region_to.id) + .first() + ) + + def remove_by_component(self, db: Session, component_id: int) -> Optional[List[EnergyTransmissionDistance]]: + """ + Removes all EnergyTransmissionDistance entries for a given component. """ - db.query(EnergyTransmissionDistance).filter(EnergyTransmissionDistance.ref_component == component_id).delete() + obj = db.query(self.model).filter(self.model.ref_component == component_id).all() + db.query(self.model).filter(self.model.ref_component == component_id).delete() + db.commit() + return obj def create(self, db: Session, obj_in: EnergyTransmissionDistanceCreate) -> EnergyTransmissionDistance: """ @@ -34,53 +72,40 @@ def create(self, db: Session, obj_in: EnergyTransmissionDistanceCreate) -> Energ :param obj_in: Input data :return: New energy transmission distance entry """ - - if obj_in.ref_component is None and obj_in.component is None: - raise ValueError("Component must be specified. Provide reference id or component name.") - - if obj_in.ref_region_from is None and obj_in.region_from is None: - raise ValueError("Region from must be specified. Provide reference id or region name.") - - if obj_in.ref_component is not None: - transmission = crud.energy_transmission.get(db, obj_in.ref_component) - else: - transmission = crud.energy_transmission.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, - name=obj_in.component) - + transmission = crud.energy_transmission.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.component) if transmission is None or transmission.component.ref_dataset != obj_in.ref_dataset: raise ValueError("Component not found or from different dataset.") - obj_in.ref_component = transmission.ref_component - - if obj_in.ref_region_from is not None: - region_from = crud.region.get(db, obj_in.ref_region_from) - else: - region_from = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, - name=obj_in.region_from) + region_from = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.region_from) if region_from is None or region_from.ref_dataset != obj_in.ref_dataset: - raise ValueError("Region from not found or from different dataset.") - obj_in.ref_region_from = region_from.id - - if obj_in.ref_region_to is not None: - region_to = crud.region.get(db, obj_in.ref_region_to) - else: - region_to = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.region_to) + raise ValueError("Origin region not found or from different dataset.") + region_to = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.region_to) if region_to is None or region_to.ref_dataset != obj_in.ref_dataset: - raise ValueError("Region to not found or from different dataset.") - obj_in.ref_region_to = region_to.id - - return super().create(db=db, obj_in=obj_in) + raise ValueError("Target region not found or from different dataset.") + + db_obj = EnergyTransmissionDistance( + distance=obj_in.distance, + ref_component=transmission.ref_component, + ref_region_from=region_from.id, + ref_region_to=region_to.id, + ) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj def get_dataframe(self, db: Session, component_id: int, region_ids: List[int]) -> pd.DataFrame: """ Returns the distances for the provided regions as matrix. """ - data = db.query(self.model) \ - .filter(self.model.ref_component == component_id) \ - .filter(self.model.ref_region_from.in_(region_ids)) \ - .filter(self.model.ref_region_to.in_(region_ids)) \ + data = ( + db.query(self.model) + .filter(self.model.ref_component == component_id) + .filter(self.model.ref_region_from.in_(region_ids)) + .filter(self.model.ref_region_to.in_(region_ids)) .all() + ) region_names = [crud.region.get(db, id=r_id).name for r_id in region_ids] df = pd.DataFrame(0.0, index=region_names, columns=region_names) diff --git a/ensysmod/crud/energy_transmission_loss.py b/ensysmod/crud/energy_transmission_loss.py new file mode 100644 index 0000000..fb1c34a --- /dev/null +++ b/ensysmod/crud/energy_transmission_loss.py @@ -0,0 +1,115 @@ +from typing import List, Optional + +import pandas as pd +from sqlalchemy.orm import Session + +from ensysmod import crud +from ensysmod.crud.base import CRUDBase +from ensysmod.model import EnergyTransmissionLoss +from ensysmod.schemas import EnergyTransmissionLossCreate, EnergyTransmissionLossUpdate + + +# noinspection PyMethodMayBeStatic,PyArgumentList +class CRUDEnergyTransmissionLoss(CRUDBase[EnergyTransmissionLoss, EnergyTransmissionLossCreate, EnergyTransmissionLossUpdate]): + """ + CRUD operations for EnergyTransmissionLoss + """ + + def get_by_component(self, db: Session, component_id: int) -> Optional[List[EnergyTransmissionLoss]]: + """ + Get all EnergyTransmissionLoss entries for a given component. + """ + return db.query(self.model).filter(self.model.ref_component == component_id).all() + + def get_by_component_and_region_ids( + self, db: Session, component_id: int, region_from_id: int, region_to_id: int + ) -> Optional[EnergyTransmissionLoss]: + """ + Get a EnergyTransmissionLoss entry for a given component id and its two region ids. + """ + return ( + db.query(self.model) + .filter(self.model.ref_component == component_id) + .filter(self.model.ref_region_from == region_from_id) + .filter(self.model.ref_region_to == region_to_id) + .first() + ) + + def get_by_dataset_id_component_region_names( + self, db: Session, dataset_id: int, component_name: str, region_from_name: str, region_to_name: str + ) -> Optional[EnergyTransmissionLoss]: + """ + Get a EnergyTransmissionLoss entry for a given dataset id, component name and its two region names. + """ + component = crud.energy_component.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=component_name) + region_from = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=region_from_name) + region_to = crud.region.get_by_dataset_and_name(db=db, dataset_id=dataset_id, name=region_to_name) + return ( + db.query(self.model) + .filter(self.model.ref_component == component.id) + .filter(self.model.ref_region_from == region_from.id) + .filter(self.model.ref_region_to == region_to.id) + .first() + ) + + def remove_by_component(self, db: Session, component_id: int) -> Optional[List[EnergyTransmissionLoss]]: + """ + Removes all EnergyTransmissionLoss entries for a given component. + """ + obj = db.query(self.model).filter(self.model.ref_component == component_id).all() + db.query(self.model).filter(self.model.ref_component == component_id).delete() + db.commit() + return obj + + def create(self, db: Session, obj_in: EnergyTransmissionLossCreate) -> EnergyTransmissionLoss: + """ + Creates a new energy transmission loss entry between two regions. + + :param db: Database session + :param obj_in: Input data + :return: New energy transmission loss entry + """ + + transmission = crud.energy_transmission.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.component) + if transmission is None or transmission.component.ref_dataset != obj_in.ref_dataset: + raise ValueError("Component not found or from different dataset.") + + region_from = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.region_from) + if region_from is None or region_from.ref_dataset != obj_in.ref_dataset: + raise ValueError("Origin region not found or from different dataset.") + + region_to = crud.region.get_by_dataset_and_name(db, dataset_id=obj_in.ref_dataset, name=obj_in.region_to) + if region_to is None or region_to.ref_dataset != obj_in.ref_dataset: + raise ValueError("Target region not found or from different dataset.") + + db_obj = EnergyTransmissionLoss( + loss=obj_in.loss, + ref_component=transmission.ref_component, + ref_region_from=region_from.id, + ref_region_to=region_to.id, + ) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def get_dataframe(self, db: Session, component_id: int, region_ids: List[int]) -> pd.DataFrame: + """ + Returns the losses for the provided regions as matrix. + """ + data = ( + db.query(self.model) + .filter(self.model.ref_component == component_id) + .filter(self.model.ref_region_from.in_(region_ids)) + .filter(self.model.ref_region_to.in_(region_ids)) + .all() + ) + + region_names = [crud.region.get(db, id=r_id).name for r_id in region_ids] + df = pd.DataFrame(0.0, index=region_names, columns=region_names) + for d in data: + df[d.region_to.name][d.region_from.name] = d.loss + return df + + +energy_transmission_loss = CRUDEnergyTransmissionLoss(EnergyTransmissionLoss) diff --git a/ensysmod/model/__init__.py b/ensysmod/model/__init__.py index 4947c5a..d1e13d2 100644 --- a/ensysmod/model/__init__.py +++ b/ensysmod/model/__init__.py @@ -23,6 +23,7 @@ from .energy_storage import EnergyStorage from .energy_transmission import EnergyTransmission from .energy_transmission_distance import EnergyTransmissionDistance +from .energy_transmission_loss import EnergyTransmissionLoss from .region import Region from .ts_capacity_fix import CapacityFix from .ts_capacity_max import CapacityMax diff --git a/ensysmod/model/energy_transmission.py b/ensysmod/model/energy_transmission.py index 563862c..b2f25e2 100644 --- a/ensysmod/model/energy_transmission.py +++ b/ensysmod/model/energy_transmission.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, ForeignKey, Float +from sqlalchemy import Column, ForeignKey, Integer from sqlalchemy.orm import relationship from ensysmod.database.base_class import Base @@ -8,15 +8,15 @@ class EnergyTransmission(Base): """ EnergyTransmission table definition - See https://vsa-fine.readthedocs.io/en/latest/storageClassDoc.html + See https://vsa-fine.readthedocs.io/en/master/sourceCodeDocumentation/components/transmissionClassDoc.html """ + ref_component = Column(Integer, ForeignKey("energy_component.id"), index=True, nullable=False, primary_key=True) ref_commodity = Column(Integer, ForeignKey("energy_commodity.id"), index=True, nullable=False) - loss_per_unit = Column(Float, nullable=True) - # Relationships component = relationship("EnergyComponent") commodity = relationship("EnergyCommodity", back_populates="energy_transmissions") distances = relationship("EnergyTransmissionDistance", back_populates="transmission") + losses = relationship("EnergyTransmissionLoss", back_populates="transmission") diff --git a/ensysmod/model/energy_transmission_distance.py b/ensysmod/model/energy_transmission_distance.py index d3a7fb5..343d520 100644 --- a/ensysmod/model/energy_transmission_distance.py +++ b/ensysmod/model/energy_transmission_distance.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, ForeignKey, UniqueConstraint, Float +from sqlalchemy import Column, Float, ForeignKey, Integer, UniqueConstraint from sqlalchemy.orm import relationship from ensysmod.database.base_class import Base @@ -6,10 +6,11 @@ class EnergyTransmissionDistance(Base): """ - EnergyTransmission table definition + EnergyTransmissionDistance table definition - See https://vsa-fine.readthedocs.io/en/latest/storageClassDoc.html + See https://vsa-fine.readthedocs.io/en/master/sourceCodeDocumentation/components/transmissionClassDoc.html """ + id = Column(Integer, primary_key=True) ref_component = Column(Integer, ForeignKey("energy_transmission.ref_component"), index=True, nullable=False) ref_region_from = Column(Integer, ForeignKey("region.id"), index=True, nullable=False) @@ -23,7 +24,4 @@ class EnergyTransmissionDistance(Base): region_to = relationship("Region", foreign_keys=[ref_region_to]) # table constraints - __table_args__ = ( - UniqueConstraint("ref_component", "ref_region_from", "ref_region_to", - name="_transmission_distances_regions_uc"), - ) + __table_args__ = (UniqueConstraint("ref_component", "ref_region_from", "ref_region_to", name="_transmission_distances_regions_uc"),) diff --git a/ensysmod/model/energy_transmission_loss.py b/ensysmod/model/energy_transmission_loss.py new file mode 100644 index 0000000..8c5d265 --- /dev/null +++ b/ensysmod/model/energy_transmission_loss.py @@ -0,0 +1,27 @@ +from sqlalchemy import Column, Float, ForeignKey, Integer, UniqueConstraint +from sqlalchemy.orm import relationship + +from ensysmod.database.base_class import Base + + +class EnergyTransmissionLoss(Base): + """ + EnergyTransmissionLoss table definition + + See https://vsa-fine.readthedocs.io/en/master/sourceCodeDocumentation/components/transmissionClassDoc.html + """ + + id = Column(Integer, primary_key=True) + ref_component = Column(Integer, ForeignKey("energy_transmission.ref_component"), index=True, nullable=False) + ref_region_from = Column(Integer, ForeignKey("region.id"), index=True, nullable=False) + ref_region_to = Column(Integer, ForeignKey("region.id"), index=True, nullable=False) + + loss = Column(Float, nullable=True) + + # Relationships + transmission = relationship("EnergyTransmission", back_populates="losses") + region_from = relationship("Region", foreign_keys=[ref_region_from]) + region_to = relationship("Region", foreign_keys=[ref_region_to]) + + # table constraints + __table_args__ = (UniqueConstraint("ref_component", "ref_region_from", "ref_region_to", name="_transmission_distances_regions_uc"),) diff --git a/ensysmod/schemas/__init__.py b/ensysmod/schemas/__init__.py index 05a3363..575537b 100644 --- a/ensysmod/schemas/__init__.py +++ b/ensysmod/schemas/__init__.py @@ -51,6 +51,11 @@ EnergyTransmissionDistanceCreate, EnergyTransmissionDistanceUpdate, ) +from .energy_transmission_loss import ( + EnergyTransmissionLoss, + EnergyTransmissionLossCreate, + EnergyTransmissionLossUpdate, +) from .file_upload import FileStatus, FileUploadResult, ZipArchiveUploadResult from .region import Region, RegionCreate, RegionUpdate from .token import Token, TokenPayload diff --git a/ensysmod/schemas/energy_transmission.py b/ensysmod/schemas/energy_transmission.py index 129d830..ee79897 100644 --- a/ensysmod/schemas/energy_transmission.py +++ b/ensysmod/schemas/energy_transmission.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field from pydantic.class_validators import validator @@ -10,10 +10,6 @@ EnergyComponentCreate, EnergyComponentUpdate, ) -from ensysmod.schemas.energy_transmission_distance import ( - EnergyTransmissionDistance, - EnergyTransmissionDistanceCreate, -) from ensysmod.utils import validators @@ -21,51 +17,29 @@ class EnergyTransmissionBase(BaseModel): """ Shared attributes for an energy transmission. Used as a base class for all schemas. """ + type = EnergyComponentType.TRANSMISSION - loss_per_unit: Optional[float] = Field(None, - description="Loss per length unit of energy transmission.", - example=0.002) # validators _valid_type = validator("type", allow_reuse=True)(validators.validate_energy_component_type) - _valid_loss_per_unit = validator("loss_per_unit", allow_reuse=True)(validators.validate_loss_per_unit) class EnergyTransmissionCreate(EnergyTransmissionBase, EnergyComponentCreate): """ Attributes to receive via API on creation of an energy transmission. """ - commodity: str = Field(..., - description="Commodity of energy transmission.", - example="electricity") - distances: Optional[List[EnergyTransmissionDistanceCreate]] \ - = Field(None, - description="Distances of energy transmission in the length unit provided with the dataset.") + + commodity: str = Field(..., description="Commodity of energy transmission.", example="electricity") # validators - _valid_distances = validator("distances", allow_reuse=True)(validators.validate_distances) _valid_commodity = validator("commodity", allow_reuse=True)(validators.validate_commodity) - class Config: - schema_extra = { - "example": { - "loss_per_unit": 0.002, - "commodity": "electricity", - "distances": [ - { - "region_from": "germany", - "region_to": "france", - "distance": 135.4 - } - ] - } - } - class EnergyTransmissionUpdate(EnergyTransmissionBase, EnergyComponentUpdate): """ Attributes to receive via API on update of an energy transmission. """ + commodity: Optional[str] = None # validators @@ -76,9 +50,9 @@ class EnergyTransmission(EnergyTransmissionBase): """ Attributes to return via API for an energy transmission. """ + component: EnergyComponent commodity: EnergyCommodity - distances: List[EnergyTransmissionDistance] class Config: orm_mode = True diff --git a/ensysmod/schemas/energy_transmission_distance.py b/ensysmod/schemas/energy_transmission_distance.py index 572eed7..fb3c232 100644 --- a/ensysmod/schemas/energy_transmission_distance.py +++ b/ensysmod/schemas/energy_transmission_distance.py @@ -1,8 +1,6 @@ -from typing import Optional - from pydantic import BaseModel, Field, validator -from pydantic.class_validators import root_validator +from ensysmod.schemas.energy_transmission import EnergyTransmission from ensysmod.schemas.region import Region from ensysmod.utils import validators @@ -11,6 +9,7 @@ class EnergyTransmissionDistanceBase(BaseModel): """ Shared attributes for an energy transmission distance. Used as a base class for all schemas. """ + distance: float = Field(..., description="Distance between two regions in unit of dataset.", example=133.4) # validators @@ -21,41 +20,30 @@ class EnergyTransmissionDistanceCreate(EnergyTransmissionDistanceBase): """ Attributes to receive via API on creation of an energy transmission distance. """ - ref_dataset: Optional[int] = Field(None, description="Reference dataset ID. The current dataset will be used.") - - ref_component: Optional[int] = Field(None, description="Reference component ID. " - "The current component will be used.") - component: Optional[str] = Field(None, description="Component name. If no ref_component is provided, the name is " - "used to find the component.") - ref_region_from: Optional[int] = Field(None, description="Reference region ID.") - region_from: Optional[str] = Field(None, description="Region name. If no ref_region_from is provided, the name is " - "used to find the region.") - - ref_region_to: Optional[int] = Field(None, description="Reference region ID.") - region_to: Optional[str] = Field(None, description="Region name. If no ref_region_to is provided, the name is " - "used to find the region.") + ref_dataset: int = Field(..., description="The ID of the referenced dataset.") + component: str = Field(..., description="The name of the transmission component.") + region_from: str = Field(..., description="The name of the origin region.") + region_to: str = Field(..., description="The name of the target region.") # validators - _valid_ref_dataset = validator("ref_dataset", allow_reuse=True)(validators.validate_ref_dataset_optional) - - # _valid_ref_component = root_validator(allow_reuse=True)(validators.validate_component_or_ref) - _valid_ref_region_from = root_validator(allow_reuse=True)(validators.validate_region_from_or_ref) - _valid_ref_region_to = root_validator(allow_reuse=True)(validators.validate_region_to_or_ref) + _valid_ref_dataset = validator("ref_dataset", allow_reuse=True)(validators.validate_ref_dataset_required) class EnergyTransmissionDistanceUpdate(EnergyTransmissionDistanceBase): """ Attributes to receive via API on update of an energy transmission distance. """ - distance: Optional[float] = None + pass class EnergyTransmissionDistance(EnergyTransmissionDistanceBase): """ Attributes to return via API for an energy transmission distance. """ + id: int + transmission: EnergyTransmission region_from: Region region_to: Region diff --git a/ensysmod/schemas/energy_transmission_loss.py b/ensysmod/schemas/energy_transmission_loss.py new file mode 100644 index 0000000..bd29902 --- /dev/null +++ b/ensysmod/schemas/energy_transmission_loss.py @@ -0,0 +1,51 @@ +from pydantic import BaseModel, Field, validator + +from ensysmod.schemas.energy_transmission import EnergyTransmission +from ensysmod.schemas.region import Region +from ensysmod.utils import validators + + +class EnergyTransmissionLossBase(BaseModel): + """ + Shared attributes for an energy transmission loss. Used as a base class for all schemas. + """ + + loss: float = Field(..., description="Relative loss per length unit of energy transmission.", example=0.00003) + + # validators + _valid_distance = validator("loss", allow_reuse=True)(validators.validate_loss) + + +class EnergyTransmissionLossCreate(EnergyTransmissionLossBase): + """ + Attributes to receive via API on creation of an energy transmission loss. + """ + + ref_dataset: int = Field(..., description="The ID of the referenced dataset.") + component: str = Field(..., description="The name of the transmission component.") + region_from: str = Field(..., description="The name of the origin region.") + region_to: str = Field(..., description="The name of the target region.") + + # validators + _valid_ref_dataset = validator("ref_dataset", allow_reuse=True)(validators.validate_ref_dataset_required) + + +class EnergyTransmissionLossUpdate(EnergyTransmissionLossBase): + """ + Attributes to receive via API on update of an energy transmission loss. + """ + pass + + +class EnergyTransmissionLoss(EnergyTransmissionLossBase): + """ + Attributes to return via API for an energy transmission loss. + """ + + id: int + transmission: EnergyTransmission + region_from: Region + region_to: Region + + class Config: + orm_mode = True diff --git a/ensysmod/utils/validators.py b/ensysmod/utils/validators.py index 42ceef6..15d045a 100644 --- a/ensysmod/utils/validators.py +++ b/ensysmod/utils/validators.py @@ -1,6 +1,5 @@ from typing import Any, List, Optional -from pydantic import root_validator from pydantic.errors import MissingError from ensysmod.model import EnergyComponentType @@ -463,81 +462,17 @@ def validate_distance(distance: float) -> float: return distance -@root_validator -def validate_component_or_ref(cls, values): - component, ref_component = values.get('component'), values.get('ref_component') - - if component is None and ref_component is None: - raise ValueError("Either component or ref_component must be provided.") - - validate_ref_component_optional(ref_component) - - if component is not None and len(component) > 100: - raise ValueError("The component must not be longer than 100 characters.") - - return values - - -@root_validator -def validate_region_to_or_ref(cls, values): - region_to, ref_region_to = values.get('region_to'), values.get('ref_region_to') - - if region_to is None and ref_region_to is None: - raise ValueError("Either region_to or ref_region_to must be provided.") - - if region_to is not None and len(region_to) > 100: - raise ValueError("The region_to must not be longer than 100 characters.") - - if ref_region_to is not None and ref_region_to <= 0: - raise ValueError("Reference to the region_to must be positive.") - - return values - - -@root_validator -def validate_region_from_or_ref(cls, values): - region_from, ref_region_from = values.get('region_from'), values.get('ref_region_from') - - if region_from is None and ref_region_from is None: - raise ValueError("Either region_from or ref_region_from must be provided.") - - if region_from is not None and len(region_from) > 100: - raise ValueError("The region_from must not be longer than 100 characters.") - - if ref_region_from is not None and ref_region_from <= 0: - raise ValueError("Reference to the region_from must be positive.") - - return values - - -def validate_loss_per_unit(loss_per_unit: float) -> Optional[float]: - """ - Validates the loss per unit of an object. - - :param loss_per_unit: The loss per unit of the object. - :return: The validated loss per unit. +def validate_loss(loss: float) -> float: """ - if loss_per_unit is None: - return None - if loss_per_unit < 0 or loss_per_unit > 1: - raise ValueError("The loss per unit must be zero or positive.") - - return loss_per_unit - + Validates the loss of an object. -def validate_distances(distances: List[Any]) -> List[Any]: + :param loss: The loss of the object. + :return: The validated loss. """ - Validates the distances of an object. - - :param distances: The distances of the object. - :return: The validated distances. - """ - if distances is None: - raise MissingError() - if len(distances) == 0: - raise ValueError("List of distances must not be empty.") + if loss < 0 or loss > 1: + raise ValueError("The loss must be between zero and one.") - return distances + return loss def validate_fix_capacities(fix_capacities: List[float]) -> List[float]: diff --git a/tests/api/test_energy_transmission_distances.py b/tests/api/test_energy_transmission_distances.py new file mode 100644 index 0000000..647ec52 --- /dev/null +++ b/tests/api/test_energy_transmission_distances.py @@ -0,0 +1,195 @@ +from typing import Dict + +from fastapi import status +from fastapi.encoders import jsonable_encoder +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from ensysmod.schemas import EnergyTransmissionDistanceUpdate +from tests.utils import data_generator +from tests.utils.assertions import assert_transmission_distance +from tests.utils.data_generator.energy_transmissions import create_transmission_scenario +from tests.utils.utils import clear_database + + +def test_create_transmission_distance(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission distance. + """ + create_request = data_generator.fixed_transmission_distance_create(db) + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + + created_distance = response.json() + assert created_distance["distance"] == create_request.distance + assert created_distance["transmission"]["component"]["name"] == create_request.component + assert created_distance["region_from"]["name"] == create_request.region_from + assert created_distance["region_to"]["name"] == create_request.region_to + + +def test_create_existing_transmission_distance(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an existing transmission distance. + """ + clear_database(db) + create_request = data_generator.fixed_transmission_distance_create(db) + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_409_CONFLICT + + +def test_create_transmission_distance_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission distance with unknown dataset. + """ + create_request = data_generator.fixed_transmission_distance_create(db) + create_request.ref_dataset = 123456 + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_create_transmission_distance_unknown_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission distance with unknown component. + """ + create_request = data_generator.fixed_transmission_distance_create(db) + create_request.component = "Unknown Component" + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_create_transmission_distance_unknown_regions(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission distance with unknown regions. + """ + create_request = data_generator.fixed_transmission_distance_create(db) + create_request.region_from = "Unknown Region" + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + create_request = data_generator.fixed_transmission_distance_create(db) + create_request.region_to = "Unknown Region" + response = client.post("/distances/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_get_all_transmission_distances(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all transmission distances. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + response = client.get("/distances/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + distance_list = response.json() + assert len(distance_list) == 4 + assert_transmission_distance(check_entry=distance_list[0], expected_entry=scenario["distances"][0]) + assert_transmission_distance(check_entry=distance_list[1], expected_entry=scenario["distances"][1]) + assert_transmission_distance(check_entry=distance_list[2], expected_entry=scenario["distances"][2]) + assert_transmission_distance(check_entry=distance_list[3], expected_entry=scenario["distances"][3]) + + +def test_get_transmission_distance(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving a transmission distance. + """ + existing_distance = data_generator.fixed_existing_transmission_distance(db) + response = client.get(f"/distances/{existing_distance.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_distance = response.json() + assert_transmission_distance(check_entry=retrieved_distance, expected_entry=existing_distance) + + +def test_get_transmission_distances_by_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all transmission distances of a component. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + component_id = scenario["transmissions"][0].component.id + response = client.get(f"/distances/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + distance_list = response.json() + assert len(distance_list) == 2 + assert_transmission_distance(check_entry=distance_list[0], expected_entry=scenario["distances"][0]) + assert_transmission_distance(check_entry=distance_list[1], expected_entry=scenario["distances"][1]) + + component_id = scenario["transmissions"][1].component.id + response = client.get(f"/distances/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + distance_list = response.json() + assert len(distance_list) == 2 + assert_transmission_distance(check_entry=distance_list[0], expected_entry=scenario["distances"][2]) + assert_transmission_distance(check_entry=distance_list[1], expected_entry=scenario["distances"][3]) + + +def test_update_transmission_distance(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test updating a transmission distance. + """ + existing_distance = data_generator.fixed_existing_transmission_distance(db) + print(existing_distance.distance) + + update_request = EnergyTransmissionDistanceUpdate(**jsonable_encoder(existing_distance)) + update_request.distance = 1234 + + response = client.put( + f"/distances/{existing_distance.id}", + headers=normal_user_headers, + data=update_request.json(), + ) + assert response.status_code == status.HTTP_200_OK + + updated_distance = response.json() + assert updated_distance["distance"] == update_request.distance + + +def test_remove_transmission_distance(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test deleting a transmission distance. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + # delete the first distance entry + response = client.delete(f"/distances/{scenario['distances'][0].id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # check that the database only has the remaining distance entries + get_response = client.get("/distances/", headers=normal_user_headers) + assert get_response.status_code == status.HTTP_200_OK + + distance_list = get_response.json() + assert len(distance_list) == 3 + assert_transmission_distance(check_entry=distance_list[0], expected_entry=scenario["distances"][1]) + assert_transmission_distance(check_entry=distance_list[1], expected_entry=scenario["distances"][2]) + assert_transmission_distance(check_entry=distance_list[2], expected_entry=scenario["distances"][3]) + + +def test_remove_transmission_distances_by_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test deleting all transmission distances of a component. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + # delete the distance entries of the first component + component_id = scenario["transmissions"][0].component.id + response = client.delete(f"/distances/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # check that the database only has distance entries from the second component + get_response = client.get("/distances/", headers=normal_user_headers) + assert get_response.status_code == status.HTTP_200_OK + + distance_list = get_response.json() + assert len(distance_list) == 2 + assert_transmission_distance(check_entry=distance_list[0], expected_entry=scenario["distances"][2]) + assert_transmission_distance(check_entry=distance_list[1], expected_entry=scenario["distances"][3]) diff --git a/tests/api/test_energy_transmission_losses.py b/tests/api/test_energy_transmission_losses.py new file mode 100644 index 0000000..9857234 --- /dev/null +++ b/tests/api/test_energy_transmission_losses.py @@ -0,0 +1,195 @@ +from typing import Dict + +from fastapi import status +from fastapi.encoders import jsonable_encoder +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from ensysmod.schemas import EnergyTransmissionLossUpdate +from tests.utils import data_generator +from tests.utils.assertions import assert_transmission_loss +from tests.utils.data_generator.energy_transmissions import create_transmission_scenario +from tests.utils.utils import clear_database + + +def test_create_transmission_loss(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission loss. + """ + create_request = data_generator.fixed_transmission_loss_create(db) + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + + created_loss = response.json() + assert created_loss["loss"] == create_request.loss + assert created_loss["transmission"]["component"]["name"] == create_request.component + assert created_loss["region_from"]["name"] == create_request.region_from + assert created_loss["region_to"]["name"] == create_request.region_to + + +def test_create_existing_transmission_loss(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an existing transmission loss. + """ + clear_database(db) + create_request = data_generator.fixed_transmission_loss_create(db) + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_409_CONFLICT + + +def test_create_transmission_loss_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission loss with unknown dataset. + """ + create_request = data_generator.fixed_transmission_loss_create(db) + create_request.ref_dataset = 123456 + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_create_transmission_loss_unknown_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission loss with unknown component. + """ + create_request = data_generator.fixed_transmission_loss_create(db) + create_request.component = "Unknown Component" + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_create_transmission_loss_unknown_regions(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating a transmission loss with unknown regions. + """ + create_request = data_generator.fixed_transmission_loss_create(db) + create_request.region_from = "Unknown Region" + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + create_request = data_generator.fixed_transmission_loss_create(db) + create_request.region_to = "Unknown Region" + response = client.post("/losses/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_get_all_transmission_losses(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all transmission losses. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + response = client.get("/losses/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + loss_list = response.json() + assert len(loss_list) == 4 + assert_transmission_loss(check_entry=loss_list[0], expected_entry=scenario["losses"][0]) + assert_transmission_loss(check_entry=loss_list[1], expected_entry=scenario["losses"][1]) + assert_transmission_loss(check_entry=loss_list[2], expected_entry=scenario["losses"][2]) + assert_transmission_loss(check_entry=loss_list[3], expected_entry=scenario["losses"][3]) + + +def test_get_transmission_loss(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving a transmission loss. + """ + existing_loss = data_generator.fixed_existing_transmission_loss(db) + response = client.get(f"/losses/{existing_loss.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_loss = response.json() + assert_transmission_loss(check_entry=retrieved_loss, expected_entry=existing_loss) + + +def test_get_transmission_losses_by_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all transmission losses of a component. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + component_id = scenario["transmissions"][0].component.id + response = client.get(f"/losses/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + loss_list = response.json() + assert len(loss_list) == 2 + assert_transmission_loss(check_entry=loss_list[0], expected_entry=scenario["losses"][0]) + assert_transmission_loss(check_entry=loss_list[1], expected_entry=scenario["losses"][1]) + + component_id = scenario["transmissions"][1].component.id + response = client.get(f"/losses/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + loss_list = response.json() + assert len(loss_list) == 2 + assert_transmission_loss(check_entry=loss_list[0], expected_entry=scenario["losses"][2]) + assert_transmission_loss(check_entry=loss_list[1], expected_entry=scenario["losses"][3]) + + +def test_update_transmission_loss(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test updating a transmission loss. + """ + existing_loss = data_generator.fixed_existing_transmission_loss(db) + print(existing_loss.loss) + + update_request = EnergyTransmissionLossUpdate(**jsonable_encoder(existing_loss)) + update_request.loss = 0.0001234 + + response = client.put( + f"/losses/{existing_loss.id}", + headers=normal_user_headers, + data=update_request.json(), + ) + assert response.status_code == status.HTTP_200_OK + + updated_loss = response.json() + assert updated_loss["loss"] == update_request.loss + + +def test_remove_transmission_loss(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test deleting a transmission loss. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + # delete the first distance entry + response = client.delete(f"/losses/{scenario['distances'][0].id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # check that the database only has the remaining distance entries + get_response = client.get("/losses/", headers=normal_user_headers) + assert get_response.status_code == status.HTTP_200_OK + + loss_list = get_response.json() + assert len(loss_list) == 3 + assert_transmission_loss(check_entry=loss_list[0], expected_entry=scenario["losses"][1]) + assert_transmission_loss(check_entry=loss_list[1], expected_entry=scenario["losses"][2]) + assert_transmission_loss(check_entry=loss_list[2], expected_entry=scenario["losses"][3]) + + +def test_remove_transmission_losses_by_component(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test deleting all transmission losses of a component. + """ + clear_database(db) + scenario = create_transmission_scenario(db) + + # delete the distance entries of the first component + component_id = scenario["transmissions"][0].component.id + response = client.delete(f"/losses/component/{component_id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # check that the database only has distance entries from the second component + get_response = client.get("/losses/", headers=normal_user_headers) + assert get_response.status_code == status.HTTP_200_OK + + loss_list = get_response.json() + assert len(loss_list) == 2 + assert_transmission_loss(check_entry=loss_list[0], expected_entry=scenario["losses"][2]) + assert_transmission_loss(check_entry=loss_list[1], expected_entry=scenario["losses"][3]) diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py index 3b354d0..b459227 100644 --- a/tests/utils/assertions.py +++ b/tests/utils/assertions.py @@ -8,3 +8,19 @@ def assert_energy_component(component: Dict, expected: EnergyComponentCreate, ex assert component["type"] == expected_type.value assert component["name"] == expected.name assert component["description"] == expected.description + + +def assert_transmission_distance(check_entry, expected_entry): + assert check_entry["id"] == expected_entry.id + assert check_entry["distance"] == expected_entry.distance + assert check_entry["transmission"]["component"]["id"] == expected_entry.ref_component + assert check_entry["region_from"]["id"] == expected_entry.ref_region_from + assert check_entry["region_to"]["id"] == expected_entry.ref_region_to + + +def assert_transmission_loss(check_entry, expected_entry): + assert check_entry["id"] == expected_entry.id + assert check_entry["loss"] == expected_entry.loss + assert check_entry["transmission"]["component"]["id"] == expected_entry.ref_component + assert check_entry["region_from"]["id"] == expected_entry.ref_region_from + assert check_entry["region_to"]["id"] == expected_entry.ref_region_to diff --git a/tests/utils/data_generator/__init__.py b/tests/utils/data_generator/__init__.py index 31e0352..1b9d2ab 100644 --- a/tests/utils/data_generator/__init__.py +++ b/tests/utils/data_generator/__init__.py @@ -38,12 +38,16 @@ ) from .energy_transmissions import ( fixed_existing_energy_transmission, + fixed_existing_transmission_distance, + fixed_existing_transmission_loss, + fixed_transmission_distance_create, + fixed_transmission_loss_create, random_energy_transmission_create, random_existing_energy_transmission, ) from .regions import ( - fixed_alternative_existing_region, fixed_existing_region, + fixed_existing_region_alternative, random_existing_region, random_region_create, ) diff --git a/tests/utils/data_generator/energy_transmissions.py b/tests/utils/data_generator/energy_transmissions.py index 8073fef..37b64bd 100644 --- a/tests/utils/data_generator/energy_transmissions.py +++ b/tests/utils/data_generator/energy_transmissions.py @@ -1,19 +1,28 @@ +from typing import Any + from sqlalchemy.orm import Session from ensysmod import crud -from ensysmod.model import EnergyTransmission -from ensysmod.schemas import EnergyTransmissionCreate -from ensysmod.schemas.energy_transmission_distance import ( +from ensysmod.model import ( + EnergyTransmission, + EnergyTransmissionDistance, + EnergyTransmissionLoss, +) +from ensysmod.schemas import ( + EnergyTransmissionCreate, EnergyTransmissionDistanceCreate, + EnergyTransmissionLossCreate, + RegionCreate, ) from tests.utils.data_generator import ( fixed_existing_dataset, fixed_existing_energy_commodity, + random_energy_commodity_create, + random_existing_dataset, ) from tests.utils.data_generator.regions import ( - fixed_alternative_alternative_existing_region, - fixed_alternative_existing_region, fixed_existing_region, + fixed_existing_region_alternative, ) from tests.utils.utils import random_lower_string @@ -21,27 +30,11 @@ def random_energy_transmission_create(db: Session) -> EnergyTransmissionCreate: dataset = fixed_existing_dataset(db) commodity = fixed_existing_energy_commodity(db) - region = fixed_existing_region(db) - region_to = fixed_alternative_existing_region(db) - region_to_alt = fixed_alternative_alternative_existing_region(db) return EnergyTransmissionCreate( ref_dataset=dataset.id, name=f"EnergyTransmission-{dataset.id}-{random_lower_string()}", description="Description", commodity=commodity.name, - loss_per_unit=0.000001, - distances=[ - EnergyTransmissionDistanceCreate( - distance=42.3, - ref_region_from=region.id, - region_to=region_to.name, - ), - EnergyTransmissionDistanceCreate( - distance=44.3, - region_from=region.name, - ref_region_to=region_to_alt.id, - ) - ] ) @@ -53,27 +46,143 @@ def random_existing_energy_transmission(db: Session) -> EnergyTransmission: def fixed_energy_transmission_create(db: Session) -> EnergyTransmissionCreate: dataset = fixed_existing_dataset(db) commodity = fixed_existing_energy_commodity(db) - region = fixed_existing_region(db) - region_to = fixed_existing_region(db) return EnergyTransmissionCreate( ref_dataset=dataset.id, name=f"EnergyTransmission-{dataset.id}-Fixed", description="Description", commodity=commodity.name, - distances=[ - EnergyTransmissionDistanceCreate( - distance=42.3, - ref_region_from=region.id, - ref_region_to=region_to.id, - ) - ] ) def fixed_existing_energy_transmission(db: Session) -> EnergyTransmission: create_request = fixed_energy_transmission_create(db) - transmission = crud.energy_transmission.get_by_dataset_and_name(db=db, dataset_id=create_request.ref_dataset, - name=create_request.name) + transmission = crud.energy_transmission.get_by_dataset_and_name(db=db, dataset_id=create_request.ref_dataset, name=create_request.name) if transmission is None: transmission = crud.energy_transmission.create(db=db, obj_in=create_request) return transmission + + +def fixed_transmission_distance_create(db: Session) -> EnergyTransmissionDistanceCreate: + dataset = fixed_existing_dataset(db) + transmission = fixed_existing_energy_transmission(db) + region_from = fixed_existing_region(db) + region_to = fixed_existing_region_alternative(db) + return EnergyTransmissionDistanceCreate( + distance=1000, + ref_dataset=dataset.id, + component=transmission.component.name, + region_from=region_from.name, + region_to=region_to.name, + ) + + +def fixed_existing_transmission_distance(db: Session) -> EnergyTransmissionDistance: + create_request = fixed_transmission_distance_create(db) + distance = crud.energy_transmission_distance.get_by_dataset_id_component_region_names( + db=db, + dataset_id=create_request.ref_dataset, + component_name=create_request.component, + region_from_name=create_request.region_from, + region_to_name=create_request.region_to, + ) + if distance is None: + distance = crud.energy_transmission_distance.create(db=db, obj_in=create_request) + return distance + + +def fixed_transmission_loss_create(db: Session) -> EnergyTransmissionLossCreate: + dataset = fixed_existing_dataset(db) + transmission = fixed_existing_energy_transmission(db) + region_from = fixed_existing_region(db) + region_to = fixed_existing_region_alternative(db) + return EnergyTransmissionLossCreate( + loss=0.00001, + ref_dataset=dataset.id, + component=transmission.component.name, + region_from=region_from.name, + region_to=region_to.name, + ) + + +def fixed_existing_transmission_loss(db: Session) -> EnergyTransmissionLoss: + create_request = fixed_transmission_loss_create(db) + loss = crud.energy_transmission_loss.get_by_dataset_id_component_region_names( + db=db, + dataset_id=create_request.ref_dataset, + component_name=create_request.component, + region_from_name=create_request.region_from, + region_to_name=create_request.region_to, + ) + if loss is None: + loss = crud.energy_transmission_loss.create(db=db, obj_in=create_request) + return loss + + +def create_transmission_scenario(db: Session) -> dict[str, Any]: + """ + Creates a random dataset, commodity, two transmission components and two regions in the same dataset, + then add entries of transmission distances and losses between those regions. + + Returns a dictionary of the dataset, list of transmission components, list of transmission distances and losses. + """ + dataset = random_existing_dataset(db) + commodity_request = random_energy_commodity_create(db) + commodity_request.ref_dataset = dataset.id + commodity = crud.energy_commodity.create(db=db, obj_in=commodity_request) + + transmission1_request = random_energy_transmission_create(db) + transmission1_request.ref_dataset = dataset.id + transmission1_request.commodity = commodity.name + transmission1 = crud.energy_transmission.create(db=db, obj_in=transmission1_request) + + transmission2_request = random_energy_transmission_create(db) + transmission2_request.ref_dataset = dataset.id + transmission2_request.commodity = commodity.name + transmission2 = crud.energy_transmission.create(db=db, obj_in=transmission2_request) + + region1 = crud.region.create(db=db, obj_in=RegionCreate(name=f"Region1-{random_lower_string()}", ref_dataset=dataset.id)) + region2 = crud.region.create(db=db, obj_in=RegionCreate(name=f"Region2-{random_lower_string()}", ref_dataset=dataset.id)) + + def create_distance_entry(transmission, region_from, region_to, distance): + return crud.energy_transmission_distance.create( + db=db, + obj_in=EnergyTransmissionDistanceCreate( + distance=distance, + ref_dataset=dataset.id, + component=transmission.component.name, + region_from=region_from.name, + region_to=region_to.name, + ), + ) + + def create_loss_entry(transmission, region_from, region_to, loss): + return crud.energy_transmission_loss.create( + db=db, + obj_in=EnergyTransmissionLossCreate( + loss=loss, + ref_dataset=dataset.id, + component=transmission.component.name, + region_from=region_from.name, + region_to=region_to.name, + ), + ) + + distances = [ + create_distance_entry(transmission1, region1, region2, 1000), + create_distance_entry(transmission1, region2, region1, 2000), + create_distance_entry(transmission2, region1, region2, 3000), + create_distance_entry(transmission2, region2, region1, 4000), + ] + losses = [ + create_loss_entry(transmission1, region1, region2, 0.00001), + create_loss_entry(transmission1, region2, region1, 0.00002), + create_loss_entry(transmission2, region1, region2, 0.00003), + create_loss_entry(transmission2, region2, region1, 0.00004), + ] + + return { + "dataset": dataset, + "transmissions": [transmission1, transmission2], + "distances": distances, + "losses": losses, + } diff --git a/tests/utils/data_generator/regions.py b/tests/utils/data_generator/regions.py index 17c72d6..274e403 100644 --- a/tests/utils/data_generator/regions.py +++ b/tests/utils/data_generator/regions.py @@ -9,8 +9,7 @@ def random_region_create(db: Session) -> RegionCreate: dataset = random_existing_dataset(db) - return RegionCreate(name=f"Region-{dataset.id}-{random_lower_string()}", - ref_dataset=dataset.id) + return RegionCreate(name=f"Region-{dataset.id}-{random_lower_string()}", ref_dataset=dataset.id) def random_existing_region(db: Session) -> Region: @@ -20,20 +19,12 @@ def random_existing_region(db: Session) -> Region: def fixed_region_create(db: Session) -> RegionCreate: dataset = fixed_existing_dataset(db) - return RegionCreate(name=f"Region-{dataset.id}-Fixed", - ref_dataset=dataset.id) + return RegionCreate(name=f"Region-{dataset.id}-Fixed", ref_dataset=dataset.id) -def fixed_alternative_region_create(db: Session) -> RegionCreate: +def fixed_region_alternative_create(db: Session) -> RegionCreate: dataset = fixed_existing_dataset(db) - return RegionCreate(name=f"Region-{dataset.id}-Fixed-alternative", - ref_dataset=dataset.id) - - -def fixed_alternative_alternative_region_create(db: Session) -> RegionCreate: - dataset = fixed_existing_dataset(db) - return RegionCreate(name=f"Region-{dataset.id}-Fixed-alternative-alternative", - ref_dataset=dataset.id) + return RegionCreate(name=f"Region-{dataset.id}-Fixed-alternative", ref_dataset=dataset.id) def fixed_existing_region(db: Session) -> Region: @@ -44,16 +35,8 @@ def fixed_existing_region(db: Session) -> Region: return region -def fixed_alternative_existing_region(db: Session) -> Region: - create_request = fixed_alternative_region_create(db) - region = crud.region.get_by_dataset_and_name(db=db, dataset_id=create_request.ref_dataset, name=create_request.name) - if region is None: - region = crud.region.create(db=db, obj_in=create_request) - return region - - -def fixed_alternative_alternative_existing_region(db: Session) -> Region: - create_request = fixed_alternative_alternative_region_create(db) +def fixed_existing_region_alternative(db: Session) -> Region: + create_request = fixed_region_alternative_create(db) region = crud.region.get_by_dataset_and_name(db=db, dataset_id=create_request.ref_dataset, name=create_request.name) if region is None: region = crud.region.create(db=db, obj_in=create_request) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 1d53d8e..ad2bb2a 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -24,6 +24,7 @@ EnergyStorage, EnergyTransmission, EnergyTransmissionDistance, + EnergyTransmissionLoss, OperationRateFix, OperationRateMax, Region, @@ -100,6 +101,7 @@ def clear_database(db: Session): EnergyStorage, EnergyTransmission, EnergyTransmissionDistance, + EnergyTransmissionLoss, OperationRateFix, OperationRateMax, Region, diff --git a/tests/validators/test_component_or_ref_validator.py b/tests/validators/test_component_or_ref_validator.py deleted file mode 100644 index ce13d96..0000000 --- a/tests/validators/test_component_or_ref_validator.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Type, List, Tuple, Dict, Any - -import pytest -from pydantic import BaseModel, ValidationError - -schemas_with_component_or_ref_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_component_or_ref_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_component_or_ref = schemas_with_component_or_ref_required + schemas_with_component_or_ref_optional - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref_required) -def test_error_missing_component_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref is required for a schema - """ - with pytest.raises(ValidationError) as exc_info: - schema(**data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("component_or_ref",) - assert exc_info.value.errors()[0]["msg"] == "field required" - assert exc_info.value.errors()[0]["type"] == "value_error.missing" - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref_optional) -def test_ok_missing_component_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref is optional for a schema - """ - schema(**data) - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref) -def test_error_on_negative_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_component=-1, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to a component must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref) -def test_error_on_zero_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_component=0, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to a component must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref) -def test_error_on_long_component(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(component='a' * 101, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "The component must not be longer than 100 characters." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_component_or_ref) -def test_ok_component_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a component_or_ref with everything over 0 is valid - """ - schema(component='a', **data) - schema(component='a' * 100, **data) - schema(ref_component=1, **data) diff --git a/tests/validators/test_distance_validator.py b/tests/validators/test_distance_validator.py index 44be8d5..65d33f0 100644 --- a/tests/validators/test_distance_validator.py +++ b/tests/validators/test_distance_validator.py @@ -1,22 +1,39 @@ -from typing import Type, List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple, Type import pytest from pydantic import BaseModel, ValidationError -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate, \ - EnergyTransmissionDistanceUpdate +from ensysmod.schemas import ( + EnergyTransmissionDistanceCreate, + EnergyTransmissionDistanceUpdate, +) schemas_with_distance_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionDistanceCreate, {"ref_region_from": 42, "ref_region_to": 1337}) + (EnergyTransmissionDistanceCreate, {"ref_dataset": 1, "component": "test", "region_from": "Region 1", "region_to": "Region 2"}), + (EnergyTransmissionDistanceUpdate, {}), ] -schemas_with_distance_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionDistanceUpdate, {}) -] +schemas_with_distance_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] schemas_with_distance = schemas_with_distance_required + schemas_with_distance_optional +@pytest.mark.parametrize("schema,data", schemas_with_distance_optional) +def test_ok_missing_distance(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a distance is optional for a schema + """ + schema(**data) + + +@pytest.mark.parametrize("schema,data", schemas_with_distance_optional) +def test_ok_none_distance(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a distance is optional for a schema + """ + schema(distance=None, **data) + + @pytest.mark.parametrize("schema,data", schemas_with_distance_required) def test_error_missing_distance(schema: Type[BaseModel], data: Dict[str, Any]): """ @@ -31,21 +48,13 @@ def test_error_missing_distance(schema: Type[BaseModel], data: Dict[str, Any]): assert exc_info.value.errors()[0]["type"] == "value_error.missing" -@pytest.mark.parametrize("schema,data", schemas_with_distance_optional) -def test_ok_missing_distance(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a distance is optional for a schema - """ - schema(**data) - - @pytest.mark.parametrize("schema,data", schemas_with_distance) -def test_error_on_negative_distance(schema: Type[BaseModel], data: Dict[str, Any]): +def test_error_negative_distance(schema: Type[BaseModel], data: Dict[str, Any]): """ - Test that a distance is not under zero + Test that a distance is not negative """ with pytest.raises(ValidationError) as exc_info: - schema(distance=-0.5, **data) + schema(distance=-1, **data) assert len(exc_info.value.errors()) == 1 assert exc_info.value.errors()[0]["loc"] == ("distance",) @@ -54,9 +63,9 @@ def test_error_on_negative_distance(schema: Type[BaseModel], data: Dict[str, Any @pytest.mark.parametrize("schema,data", schemas_with_distance) -def test_ok_distance(schema: Type[BaseModel], data: Dict[str, Any]): +def test_ok_distances(schema: Type[BaseModel], data: Dict[str, Any]): """ - Test that a distance with everything over 0 is valid + Test that a zero or positive distance is valid """ + schema(distance=1000, **data) schema(distance=0, **data) - schema(distance=1, **data) diff --git a/tests/validators/test_distances_validator.py b/tests/validators/test_distances_validator.py deleted file mode 100644 index 448de50..0000000 --- a/tests/validators/test_distances_validator.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Type, List, Tuple, Dict, Any - -import pytest -from pydantic import BaseModel, ValidationError - -from ensysmod.model import EnergyComponentType -from ensysmod.schemas.energy_transmission import EnergyTransmissionCreate -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate - -schemas_with_distances_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionCreate, {"name": "test", "description": "bar", "ref_region_from": 42, "ref_region_to": 1337, - "ref_dataset": 42, "type": EnergyComponentType.TRANSMISSION, "commodity": "bar"}) -] - -schemas_with_distances_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_distances = schemas_with_distances_required + schemas_with_distances_optional - - -@pytest.mark.parametrize("schema,data", schemas_with_distances_optional) -def test_ok_missing_distances(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a distances is optional for a schema - """ - schema(**data) - - -@pytest.mark.parametrize("schema,data", schemas_with_distances_optional) -def test_error_empty_distances(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a distances is optional for a schema - """ - with pytest.raises(ValidationError) as exc_info: - schema(distances=[], **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("distances",) - assert exc_info.value.errors()[0]["msg"] == "List of distances must not be empty." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_distances) -def test_ok_distances(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a distances with everything over 0 is valid - """ - schema(distances=[EnergyTransmissionDistanceCreate(distance=5, ref_region_to=4, ref_region_from=3)], **data) diff --git a/tests/validators/test_loss_per_unit_validator.py b/tests/validators/test_loss_per_unit_validator.py deleted file mode 100644 index 00a138d..0000000 --- a/tests/validators/test_loss_per_unit_validator.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Type, List, Tuple, Dict, Any - -import pytest -from pydantic import BaseModel, ValidationError - -from ensysmod.model import EnergyComponentType -from ensysmod.schemas.energy_transmission import EnergyTransmissionCreate, EnergyTransmissionUpdate - -schemas_with_loss_per_unit_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_loss_per_unit_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionUpdate, {}), - (EnergyTransmissionCreate, - {"name": "test", "ref_dataset": 42, "type": EnergyComponentType.TRANSMISSION, "commodity": "bar"}) -] - -schemas_with_loss_per_unit = schemas_with_loss_per_unit_required + schemas_with_loss_per_unit_optional - - -@pytest.mark.parametrize("schema,data", schemas_with_loss_per_unit_optional) -def test_ok_missing_loss_per_unit(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a loss per unit is optional for a schema - """ - schema(**data) - - -@pytest.mark.parametrize("schema,data", schemas_with_loss_per_unit_optional) -def test_ok_none_loss_per_unit(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a loss per unit is optional for a schema - """ - schema(loss_per_unit=None, **data) - - -@pytest.mark.parametrize("schema,data", schemas_with_loss_per_unit) -def test_error_on_negative_loss_per_unit(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a loss per unit is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(loss_per_unit=-0.5, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("loss_per_unit",) - assert exc_info.value.errors()[0]["msg"] == "The loss per unit must be zero or positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_loss_per_unit) -def test_error_on_positive_loss_per_unit(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a loss per unit is not over 1 - """ - with pytest.raises(ValidationError) as exc_info: - schema(loss_per_unit=1.5, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("loss_per_unit",) - assert exc_info.value.errors()[0]["msg"] == "The loss per unit must be zero or positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_loss_per_unit) -def test_ok_loss_per_units(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a loss per unit with everything over 0 is valid - """ - schema(loss_per_unit=0, **data) - schema(loss_per_unit=0.5, **data) - schema(loss_per_unit=1, **data) diff --git a/tests/validators/test_loss_validator.py b/tests/validators/test_loss_validator.py new file mode 100644 index 0000000..3993218 --- /dev/null +++ b/tests/validators/test_loss_validator.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Tuple, Type + +import pytest +from pydantic import BaseModel, ValidationError + +from ensysmod.schemas import EnergyTransmissionLossCreate, EnergyTransmissionLossUpdate + +schemas_with_loss_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ + (EnergyTransmissionLossCreate, {"ref_dataset": 1, "component": "test", "region_from": "Region 1", "region_to": "Region 2"}), + (EnergyTransmissionLossUpdate, {}), +] + +schemas_with_loss_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] + +schemas_with_loss = schemas_with_loss_required + schemas_with_loss_optional + + +@pytest.mark.parametrize("schema,data", schemas_with_loss_optional) +def test_ok_missing_loss(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is optional for a schema + """ + schema(**data) + + +@pytest.mark.parametrize("schema,data", schemas_with_loss_optional) +def test_ok_none_loss(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is optional for a schema + """ + schema(loss=None, **data) + + +@pytest.mark.parametrize("schema,data", schemas_with_loss_required) +def test_error_missing_loss(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is required for a schema + """ + with pytest.raises(ValidationError) as exc_info: + schema(**data) + + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["loc"] == ("loss",) + assert exc_info.value.errors()[0]["msg"] == "field required" + assert exc_info.value.errors()[0]["type"] == "value_error.missing" + + +@pytest.mark.parametrize("schema,data", schemas_with_loss) +def test_error_negative_loss(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is not negative + """ + with pytest.raises(ValidationError) as exc_info: + schema(loss=-1, **data) + + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["loc"] == ("loss",) + assert exc_info.value.errors()[0]["msg"] == "The loss must be between zero and one." + assert exc_info.value.errors()[0]["type"] == "value_error" + + +@pytest.mark.parametrize("schema,data", schemas_with_loss) +def test_error_loss_above_one(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is not above one + """ + with pytest.raises(ValidationError) as exc_info: + schema(loss=1.01, **data) + + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["loc"] == ("loss",) + assert exc_info.value.errors()[0]["msg"] == "The loss must be between zero and one." + assert exc_info.value.errors()[0]["type"] == "value_error" + + +@pytest.mark.parametrize("schema,data", schemas_with_loss) +def test_ok_losses(schema: Type[BaseModel], data: Dict[str, Any]): + """ + Test that a loss is between zero and one valid + """ + schema(loss=0, **data) + schema(loss=0.5, **data) + schema(loss=1, **data) diff --git a/tests/validators/test_ref_dataset_validator.py b/tests/validators/test_ref_dataset_validator.py index 9ef83fc..54b4a92 100644 --- a/tests/validators/test_ref_dataset_validator.py +++ b/tests/validators/test_ref_dataset_validator.py @@ -1,24 +1,30 @@ -from typing import Type, List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple, Type import pytest from pydantic import BaseModel, ValidationError from ensysmod.model.energy_component import EnergyComponentType -from ensysmod.schemas import EnergyCommodityCreate, EnergyComponentCreate, EnergyConversionFactorCreate +from ensysmod.schemas import ( + EnergyCommodityCreate, + EnergyComponentCreate, + EnergyConversionFactorCreate, + EnergyTransmissionDistanceCreate, + EnergyTransmissionLossCreate, +) from ensysmod.schemas.energy_model import EnergyModelCreate -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate from ensysmod.schemas.region import RegionCreate schemas_with_ref_dataset_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ (EnergyCommodityCreate, {"name": "test", "description": "foo", "unit": "bar"}), (EnergyComponentCreate, {"name": "test", "description": "foo", "type": EnergyComponentType.SOURCE}), (EnergyModelCreate, {"name": "test"}), - (RegionCreate, {"name": "test"}) + (RegionCreate, {"name": "test"}), + (EnergyTransmissionDistanceCreate, {"distance": 1000, "component": "test", "region_from": "Region 1", "region_to": "Region 2"}), + (EnergyTransmissionLossCreate, {"loss": 0.00001, "component": "test", "region_from": "Region 1", "region_to": "Region 2"}), ] schemas_with_ref_dataset_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ (EnergyConversionFactorCreate, {"conversion_factor": 4.2, "commodity": "bar"}), - (EnergyTransmissionDistanceCreate, {"distance": 5, "ref_region_from": 42, "ref_region_to": 1337}) ] schemas_with_ref_dataset = schemas_with_ref_dataset_required + schemas_with_ref_dataset_optional diff --git a/tests/validators/test_region_from_or_ref_validator.py b/tests/validators/test_region_from_or_ref_validator.py deleted file mode 100644 index bc0f6b8..0000000 --- a/tests/validators/test_region_from_or_ref_validator.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Type, List, Tuple, Dict, Any - -import pytest -from pydantic import BaseModel, ValidationError - -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate - -schemas_with_region_from_or_ref_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionDistanceCreate, {"distance": 4, "ref_region_to": 1337}) -] - -schemas_with_region_from_or_ref_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_region_from_or_ref = schemas_with_region_from_or_ref_required + schemas_with_region_from_or_ref_optional - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref_required) -def test_error_missing_region_from_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref is required for a schema - """ - with pytest.raises(ValidationError) as exc_info: - schema(**data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Either region_from or ref_region_from must be provided." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref_optional) -def test_ok_missing_region_from_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref is optional for a schema - """ - schema(**data) - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref) -def test_error_on_negative_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_region_from=-1, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to the region_from must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref) -def test_error_on_zero_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_region_from=0, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to the region_from must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref) -def test_error_on_long_region_from(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(region_from='a' * 101, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "The region_from must not be longer than 100 characters." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_from_or_ref) -def test_ok_region_from_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_from_or_ref with everything over 0 is valid - """ - schema(region_from='a', **data) - schema(region_from='a' * 100, **data) - schema(ref_region_from=1, **data) diff --git a/tests/validators/test_region_to_or_ref_validator.py b/tests/validators/test_region_to_or_ref_validator.py deleted file mode 100644 index 53e962f..0000000 --- a/tests/validators/test_region_to_or_ref_validator.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Type, List, Tuple, Dict, Any - -import pytest -from pydantic import BaseModel, ValidationError - -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate - -schemas_with_region_to_or_ref_required: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [ - (EnergyTransmissionDistanceCreate, {"distance": 4, "ref_region_from": 42}) -] - -schemas_with_region_to_or_ref_optional: List[Tuple[Type[BaseModel], Dict[str, Any]]] = [] - -schemas_with_region_to_or_ref = schemas_with_region_to_or_ref_required + schemas_with_region_to_or_ref_optional - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref_required) -def test_error_missing_region_to_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref is required for a schema - """ - with pytest.raises(ValidationError) as exc_info: - schema(**data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Either region_to or ref_region_to must be provided." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref_optional) -def test_ok_missing_region_to_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref is optional for a schema - """ - schema(**data) - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref) -def test_error_on_negative_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_region_to=-1, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to the region_to must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref) -def test_error_on_zero_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(ref_region_to=0, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "Reference to the region_to must be positive." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref) -def test_error_on_long_region_to(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref is not under zero - """ - with pytest.raises(ValidationError) as exc_info: - schema(region_to='a' * 101, **data) - - assert len(exc_info.value.errors()) == 1 - assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "The region_to must not be longer than 100 characters." - assert exc_info.value.errors()[0]["type"] == "value_error" - - -@pytest.mark.parametrize("schema,data", schemas_with_region_to_or_ref) -def test_ok_region_to_or_ref(schema: Type[BaseModel], data: Dict[str, Any]): - """ - Test that a region_to_or_ref with everything over 0 is valid - """ - schema(region_to='a', **data) - schema(region_to='a' * 100, **data) - schema(ref_region_to=1, **data)