diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ed62347..9905a9d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -86,6 +86,7 @@ jobs: - name: Install requirements run: | + sudo apt-get update && sudo apt-get install --no-install-recommends -y glpk-utils python3-swiglpk && sudo rm -rf /var/lib/apt/lists/* pip install --upgrade --upgrade-strategy eager -r requirements-dev.txt -e . - name: Run tests @@ -102,7 +103,7 @@ jobs: result.xml - uplaod_coverage_results: + upload_coverage_results: needs: test runs-on: ubuntu-latest name: "Upload code coverage" @@ -217,66 +218,3 @@ jobs: - name: Push the Docker image to GitHub Container Registry run: | docker push ${{ steps.get_tag.outputs.DOCKER_TAG }} - - - deploy_dev: - needs: [ build_dev_image ] - runs-on: ubuntu-latest - concurrency: ssh-connection # only one ssh connection at a time - name: "Deploy dev image" - if: ${{ success() && github.actor != 'dependabot[bot]' }} - steps: - - name: Update deployment status - start - uses: bobheadxi/deployments@v1.4.0 - id: deployment - with: - step: start - token: ${{ github.token }} - env: Development - no_override: false - desc: "Development deployment for main branch" - ref: "main" # dev deployment of main branch - transient: true - - - name: Install VPN - run: | - sudo /sbin/modprobe tun - sudo apt install openconnect - - - name: Connect VPN - run: | - echo "${{ secrets.VPN_PASS }}" | sudo openconnect ${{ secrets.VPN_URL }} --background --user=${{ secrets.VPN_USER }} --passwd-on-stdin - - - name: Deploy docker container on private server - uses: appleboy/ssh-action@v0.1.4 - with: - host: ${{ secrets.SSH_URL }} - username: ${{ secrets.SSH_USER }} - password: ${{ secrets.SSH_PASS }} - script: | - docker system prune -af - docker pull ${{ needs.build_dev_image.outputs.image_tag }} - docker ps --filter publish=9000 - docker rm -f $(docker ps --filter publish=9000 -aq) - docker run -d -p 9000:8080 --name "dev" ${{ needs.build_dev_image.outputs.image_tag }} - - - name: Disconnect VPN - if: ${{ always() }} - run: | - sudo pkill openconnect - - - name: Get env url - id: get_env_url - run: | - ENV_URL="http://${{ secrets.SSH_URL }}:9000" - echo ::set-output name=ENV_URL::"${ENV_URL}" - - - name: Update deployment status - finish - uses: bobheadxi/deployments@v1.4.0 - if: always() - with: - step: finish - token: ${{ github.token }} - status: ${{ job.status }} - deployment_id: ${{ steps.deployment.outputs.deployment_id }} - env_url: ${{ steps.get_env_url.outputs.env_url }} diff --git a/.github/workflows/pull-request-done.yml b/.github/workflows/pull-request-done.yml index b59b85a..3fe5a4a 100644 --- a/.github/workflows/pull-request-done.yml +++ b/.github/workflows/pull-request-done.yml @@ -40,46 +40,6 @@ jobs: MATRIX_CONTEXT: ${{ toJSON(matrix) }} run: echo "$MATRIX_CONTEXT" - preview_delete: - runs-on: ubuntu-latest - concurrency: ssh-connection # only one ssh connection at a time - name: "Delete preview" - if: ${{ github.actor != 'dependabot[bot]' }} - steps: - - name: Update deployment status - deactivate - uses: bobheadxi/deployments@v1.4.0 - id: deactivate - with: - step: deactivate-env - token: ${{ github.token }} - env: PR-${{ github.event.number }}-Preview - desc: "Preview deployment for PR #${{ github.event.number }} was pruned." - - - name: Install VPN - run: | - sudo /sbin/modprobe tun - sudo apt install openconnect - - - name: Connect VPN - run: | - echo "${{ secrets.VPN_PASS }}" | sudo openconnect ${{ secrets.VPN_URL }} --background --user=${{ secrets.VPN_USER }} --passwd-on-stdin - - - name: Stop docker container on private server - uses: appleboy/ssh-action@v0.1.4 - with: - host: ${{ secrets.SSH_URL }} - username: ${{ secrets.SSH_USER }} - password: ${{ secrets.SSH_PASS }} - script: | - docker ps --filter publish=$((9000 + ${{ github.event.number }})) - docker rm -f $(docker ps --filter publish=$((9000 + ${{ github.event.number }})) -aq) > /dev/null || true - - - name: Disconnect VPN - if: ${{ always() }} - run: | - sudo pkill openconnect - - create_release: runs-on: "ubuntu-latest" if: github.event.pull_request.merged == true && startsWith( github.head_ref, 'release/') @@ -180,65 +140,3 @@ jobs: run: | docker push ${{ steps.get_tag.outputs.DOCKER_TAG }} - - deploy_prod: - needs: [ create_release, build_prod_image ] - runs-on: ubuntu-latest - concurrency: ssh-connection # only one ssh connection at a time - name: "Deploy production image" - if: ${{ success() && github.actor != 'dependabot[bot]' }} - steps: - - name: Update deployment status - start - uses: bobheadxi/deployments@v1.4.0 - id: deployment - with: - step: start - token: ${{ github.token }} - env: Production - no_override: false - desc: "Production deployment for latest release" - ref: "v${{ needs.create_release.outputs.version }}" # tag of current release - transient: true - - - name: Install VPN - run: | - sudo /sbin/modprobe tun - sudo apt install openconnect - - - name: Connect VPN - run: | - echo "${{ secrets.VPN_PASS }}" | sudo openconnect ${{ secrets.VPN_URL }} --background --user=${{ secrets.VPN_USER }} --passwd-on-stdin - - - name: Deploy docker container on private server - uses: appleboy/ssh-action@v0.1.4 - with: - host: ${{ secrets.SSH_URL }} - username: ${{ secrets.SSH_USER }} - password: ${{ secrets.SSH_PASS }} - script: | - docker system prune -af - docker pull ${{ needs.build_prod_image.outputs.image_tag }} - docker ps --filter publish=8080 - docker rm -f $(docker ps --filter publish=8080 -aq) - docker run -d -p 8080:8080 --name "production" ${{ needs.build_prod_image.outputs.image_tag }} - - - name: Disconnect VPN - if: ${{ always() }} - run: | - sudo pkill openconnect - - - name: Get env url - id: get_env_url - run: | - ENV_URL="http://${{ secrets.SSH_URL }}:8080" - echo ::set-output name=ENV_URL::"${ENV_URL}" - - - name: Update deployment status - finish - uses: bobheadxi/deployments@v1.4.0 - if: always() - with: - step: finish - token: ${{ github.token }} - status: ${{ job.status }} - deployment_id: ${{ steps.deployment.outputs.deployment_id }} - env_url: ${{ steps.get_env_url.outputs.env_url }} \ No newline at end of file diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index cc64d77..105bab1 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -91,6 +91,7 @@ jobs: - name: Install requirements run: | + sudo apt-get update && sudo apt-get install --no-install-recommends -y glpk-utils python3-swiglpk && sudo rm -rf /var/lib/apt/lists/* pip install --upgrade --upgrade-strategy eager -r requirements-dev.txt -e . - name: Run tests @@ -107,7 +108,7 @@ jobs: result.xml - uplaod_coverage_results: + upload_coverage_results: needs: test runs-on: ubuntu-latest name: "Upload code coverage" @@ -183,67 +184,3 @@ jobs: - name: Push the Docker image to GitHub Container Registry run: | docker push ${{ steps.get_tag.outputs.DOCKER_TAG }} - - - deploy_pr: - needs: [ build_pr_image ] - runs-on: ubuntu-latest - concurrency: ssh-connection # only one ssh connection at a time - name: "Deploy preview image" - if: ${{ success() && github.actor != 'dependabot[bot]' }} - steps: - - name: Update deployment status - start - uses: bobheadxi/deployments@v1.4.0 - id: deployment - with: - step: start - token: ${{ github.token }} - env: PR-${{ github.event.number }}-Preview - no_override: false - desc: "Preview deployment for PR #${{ github.event.number }}" - ref: ${{ github.head_ref }} - transient: true - - - name: Install VPN - run: | - sudo /sbin/modprobe tun - sudo apt install openconnect - - - name: Connect VPN - run: | - echo "${{ secrets.VPN_PASS }}" | sudo openconnect ${{ secrets.VPN_URL }} --background --user=${{ secrets.VPN_USER }} --passwd-on-stdin - - - name: Deploy docker container on private server - uses: appleboy/ssh-action@v0.1.4 - with: - host: ${{ secrets.SSH_URL }} - username: ${{ secrets.SSH_USER }} - password: ${{ secrets.SSH_PASS }} - script: | - docker system prune -af - docker pull ${{ needs.build_pr_image.outputs.image_tag }} - docker ps --filter publish=$((9000 + ${{ github.event.number }})) - docker rm -f $(docker ps --filter publish=$((9000 + ${{ github.event.number }})) -aq) - docker run -d -p $((9000 + ${{ github.event.number }})):8080 --name "pr-preview-$((9000 + ${{ github - .event.number }}))" ${{ needs.build_pr_image.outputs.image_tag }} - - - name: Disconnect VPN - if: ${{ always() }} - run: | - sudo pkill openconnect - - - name: Get env url - id: get_env_url - run: | - ENV_URL="http://${{ secrets.SSH_URL }}:$((9000 + ${{ github.event.number }} ))" - echo ::set-output name=ENV_URL::"${ENV_URL}" - - - name: Update deployment status - finish - uses: bobheadxi/deployments@v1.4.0 - if: always() - with: - step: finish - token: ${{ github.token }} - status: ${{ job.status }} - deployment_id: ${{ steps.deployment.outputs.deployment_id }} - env_url: ${{ steps.get_env_url.outputs.env_url }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1c621cf..d71ada2 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +generated/ # Translations *.mo diff --git a/ensysmod/api/api.py b/ensysmod/api/api.py index 7126084..86ee061 100644 --- a/ensysmod/api/api.py +++ b/ensysmod/api/api.py @@ -1,8 +1,23 @@ from fastapi import APIRouter -from .endpoints import users, authentication, energy_sources, datasets, energy_commodities, energy_sinks, \ - energy_storages, energy_transmissions, energy_conversions, regions, ts_capacity_max, ts_operation_rate_fix, \ - ts_operation_rate_max, ts_capacity_fix, energy_models, datasets_permissions +from .endpoints import ( + authentication, + datasets, + datasets_permissions, + energy_commodities, + energy_conversions, + energy_models, + energy_sinks, + energy_sources, + energy_storages, + energy_transmissions, + regions, + ts_capacity_fix, + ts_capacity_max, + ts_operation_rate_fix, + ts_operation_rate_max, + users, +) api_router = APIRouter() api_router.include_router(authentication.router, prefix="/auth", tags=["Authentication"]) @@ -18,7 +33,7 @@ api_router.include_router(energy_transmissions.router, prefix="/transmissions", tags=["Energy Transmissions"]) api_router.include_router(energy_models.router, prefix="/models", tags=["Energy Models"]) -api_router.include_router(ts_capacity_fix.router, prefix="/fix-capacities", tags=["TS Capacities Fix"]) -api_router.include_router(ts_capacity_max.router, prefix="/max-capacities", tags=["TS Capacities Max"]) -api_router.include_router(ts_operation_rate_fix.router, prefix="/fix-operation-rates", tags=["TS Operation Rates Fix"]) -api_router.include_router(ts_operation_rate_max.router, prefix="/max-operation-rates", tags=["TS Operation Rates Max"]) +api_router.include_router(ts_capacity_fix.router, prefix="/fix-capacities", tags=["Fix Capacities"]) +api_router.include_router(ts_capacity_max.router, prefix="/max-capacities", tags=["Max Capacities"]) +api_router.include_router(ts_operation_rate_fix.router, prefix="/fix-operation-rates", tags=["Fix Operation Rates"]) +api_router.include_router(ts_operation_rate_max.router, prefix="/max-operation-rates", tags=["Max Operation Rates"]) diff --git a/ensysmod/api/endpoints/datasets.py b/ensysmod/api/endpoints/datasets.py index 79dc1d0..a82bc5e 100644 --- a/ensysmod/api/endpoints/datasets.py +++ b/ensysmod/api/endpoints/datasets.py @@ -4,12 +4,12 @@ from io import BytesIO from typing import List -from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.encoders import jsonable_encoder from fastapi.responses import FileResponse from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions from ensysmod.core.file_download import export_data from ensysmod.core.file_upload import process_dataset_zip_archive @@ -19,10 +19,10 @@ @router.get("/", response_model=List[schemas.Dataset]) -def all_datasets(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.Dataset]: +def get_all_datasets(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.Dataset]: """ Retrieve all datasets. """ @@ -36,7 +36,7 @@ def get_dataset(dataset_id: int, """ Retrieve a dataset. """ - return crud.dataset.get(db, dataset_id) + return crud.dataset.get(db=db, id=dataset_id) @router.post("/", response_model=schemas.Dataset, @@ -83,10 +83,13 @@ def remove_dataset(dataset_id: int, @router.post("/{dataset_id}/upload", response_model=schemas.ZipArchiveUploadResult) -def upload_zip_archive(dataset_id: int, +def upload_dataset_zip(dataset_id: int, file: UploadFile = File(...), db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): + """ + Upload a dataset as zip. + """ if file.content_type not in ["application/x-zip-compressed", "application/zip", "application/zip-compressed"]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File must be a zip archive. You provided {file.content_type}!") @@ -107,11 +110,11 @@ def upload_zip_archive(dataset_id: int, @router.get("/{dataset_id}/download") -def download_zip_archive(dataset_id: int, +def download_dataset_zip(dataset_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Downloads the dataset as zip + Download a dataset as zip. """ dataset = crud.dataset.get(db=db, id=dataset_id) if dataset is None: diff --git a/ensysmod/api/endpoints/datasets_permissions.py b/ensysmod/api/endpoints/datasets_permissions.py index d811f2a..a84b188 100644 --- a/ensysmod/api/endpoints/datasets_permissions.py +++ b/ensysmod/api/endpoints/datasets_permissions.py @@ -3,19 +3,19 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps router = APIRouter() @router.get("/", response_model=List[schemas.DatasetPermission]) -def all_dataset_permission(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - dataset_id: int = 0, - user_id: int = 0, - skip: int = 0, - limit: int = 100) -> List[schemas.DatasetPermission]: +def get_all_dataset_permissions(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + dataset_id: int = 0, + user_id: int = 0, + skip: int = 0, + limit: int = 100) -> List[schemas.DatasetPermission]: """ Retrieve all dataset permissions. diff --git a/ensysmod/api/endpoints/energy_commodities.py b/ensysmod/api/endpoints/energy_commodities.py index 3d8f371..73da1b1 100644 --- a/ensysmod/api/endpoints/energy_commodities.py +++ b/ensysmod/api/endpoints/energy_commodities.py @@ -1,27 +1,27 @@ -from typing import List, Union +from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergyCommodity]) -def all_commodities(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100, - dataset: Union[None, int] = None) -> List[schemas.EnergyCommodity]: +def get_all_commodities(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100, + dataset_id: Optional[int] = None) -> List[schemas.EnergyCommodity]: """ Retrieve all energy commodities. """ - if dataset is None: - return crud.energy_commodity.get_multi(db, skip=skip, limit=limit) + if dataset_id is None: + return crud.energy_commodity.get_multi(db=db, skip=skip, limit=limit) else: - return crud.energy_commodity.get_multi_by_dataset(db, dataset_id=dataset, skip=skip, limit=limit) + return crud.energy_commodity.get_multi_by_dataset(db=db, skip=skip, limit=limit, dataset_id=dataset_id) @router.get("/{commodity_id}", response_model=schemas.EnergyCommodity) @@ -29,9 +29,9 @@ def get_commodity(commodity_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Retrieve a energy commodity. + Retrieve an energy commodity. """ - return crud.energy_commodity.get(db, commodity_id) + return crud.energy_commodity.get(db=db, id=commodity_id) @router.post("/", response_model=schemas.EnergyCommodity, @@ -62,7 +62,7 @@ def update_commodity(commodity_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Update a energy commodity. + Update an energy commodity. """ commodity = crud.energy_commodity.get(db=db, id=commodity_id) if commodity is None: @@ -76,7 +76,7 @@ def remove_commodity(commodity_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Delete a energy commodity. + Delete an energy commodity. """ commodity = crud.energy_commodity.get(db=db, id=commodity_id) if commodity is None: diff --git a/ensysmod/api/endpoints/energy_conversions.py b/ensysmod/api/endpoints/energy_conversions.py index 19733af..5d7c440 100644 --- a/ensysmod/api/endpoints/energy_conversions.py +++ b/ensysmod/api/endpoints/energy_conversions.py @@ -1,24 +1,23 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException -from fastapi import status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergyConversion]) -def all_energy_conversions(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.EnergyConversion]: +def get_all_energy_conversions(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.EnergyConversion]: """ Retrieve all energy conversions. """ - return crud.energy_conversion.get_multi(db, skip, limit) + return crud.energy_conversion.get_multi(db=db, skip=skip, limit=limit) @router.post("/", response_model=schemas.EnergyConversion, diff --git a/ensysmod/api/endpoints/energy_models.py b/ensysmod/api/endpoints/energy_models.py index f4905b4..25cef74 100644 --- a/ensysmod/api/endpoints/energy_models.py +++ b/ensysmod/api/endpoints/energy_models.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import FileResponse @@ -16,18 +16,18 @@ @router.get("/", response_model=List[schemas.EnergyModel]) -def all_models(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100, - dataset: Union[None, int] = None) -> List[schemas.EnergyModel]: +def get_all_models(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100, + dataset_id: Optional[int] = None) -> List[schemas.EnergyModel]: """ Retrieve all energy models. """ - if dataset is None: - return crud.energy_model.get_multi(db, skip=skip, limit=limit) + if dataset_id is None: + return crud.energy_model.get_multi(db=db, skip=skip, limit=limit) else: - return crud.energy_model.get_multi_by_dataset(db, dataset_id=dataset, skip=skip, limit=limit) + return crud.energy_model.get_multi_by_dataset(db=db, dataset_id=dataset_id, skip=skip, limit=limit) @router.get("/{model_id}", response_model=schemas.EnergyModel) @@ -35,7 +35,7 @@ def get_model(model_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Retrieve a energy model. + Retrieve an energy model. """ return crud.energy_model.get(db, id=model_id) @@ -68,7 +68,7 @@ def update_model(model_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Update a energy model. + Update an energy model. """ energy_model = crud.energy_model.get(db=db, id=model_id) if energy_model is None: @@ -82,7 +82,7 @@ def remove_model(model_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ - Delete a energy model. + Delete an energy model. """ energy_model = crud.energy_model.get(db=db, id=model_id) if energy_model is None: @@ -154,6 +154,8 @@ def myopic_optimize_model(model_id: int, esM = generate_esm_from_model(db=db, model=energy_model) zipped_result_file_path = myopic_optimize_esm(esM=esM, optimization_parameters=energy_model_optimization_parameters) - return FileResponse(zipped_result_file_path, - media_type="application/zip", - filename=f"{energy_model.name} {energy_model_optimization_parameters.start_year}-{energy_model_optimization_parameters.end_year}.zip") + return FileResponse( + zipped_result_file_path, + media_type="application/zip", + filename=f"{energy_model.name} {energy_model_optimization_parameters.start_year}-{energy_model_optimization_parameters.end_year}.zip" + ) diff --git a/ensysmod/api/endpoints/energy_sinks.py b/ensysmod/api/endpoints/energy_sinks.py index 13da20b..1bc666b 100644 --- a/ensysmod/api/endpoints/energy_sinks.py +++ b/ensysmod/api/endpoints/energy_sinks.py @@ -1,24 +1,23 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException -from fastapi import status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergySink]) -def all_energy_sinks(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.EnergySink]: +def get_all_energy_sinks(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.EnergySink]: """ Retrieve all energy sinks. """ - return crud.energy_sink.get_multi(db, skip, limit) + return crud.energy_sink.get_multi(db=db, skip=skip, limit=limit) @router.post("/", response_model=schemas.EnergySink, diff --git a/ensysmod/api/endpoints/energy_sources.py b/ensysmod/api/endpoints/energy_sources.py index 504aad5..45fff40 100644 --- a/ensysmod/api/endpoints/energy_sources.py +++ b/ensysmod/api/endpoints/energy_sources.py @@ -1,24 +1,23 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException -from fastapi import status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergySource]) -def all_energy_sources(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.EnergySource]: +def get_all_energy_sources(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.EnergySource]: """ Retrieve all energy sources. """ - return crud.energy_source.get_multi(db, skip, limit) + return crud.energy_source.get_multi(db=db, skip=skip, limit=limit) @router.post("/", response_model=schemas.EnergySource, diff --git a/ensysmod/api/endpoints/energy_storages.py b/ensysmod/api/endpoints/energy_storages.py index 78577af..64dcfbc 100644 --- a/ensysmod/api/endpoints/energy_storages.py +++ b/ensysmod/api/endpoints/energy_storages.py @@ -1,24 +1,23 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException -from fastapi import status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergyStorage]) -def all_energy_storages(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.EnergyStorage]: +def get_all_energy_storages(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.EnergyStorage]: """ Retrieve all energy storages. """ - return crud.energy_storage.get_multi(db, skip, limit) + return crud.energy_storage.get_multi(db=db, skip=skip, limit=limit) @router.post("/", response_model=schemas.EnergyStorage, diff --git a/ensysmod/api/endpoints/energy_transmissions.py b/ensysmod/api/endpoints/energy_transmissions.py index bc57149..fd473eb 100644 --- a/ensysmod/api/endpoints/energy_transmissions.py +++ b/ensysmod/api/endpoints/energy_transmissions.py @@ -1,24 +1,23 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException -from fastapi import status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.EnergyTransmission]) -def all_energy_transmissions(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.EnergyTransmission]: +def get_all_energy_transmissions(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.EnergyTransmission]: """ Retrieve all energy transmissions. """ - return crud.energy_transmission.get_multi(db, skip, limit) + return crud.energy_transmission.get_multi(db=db, skip=skip, limit=limit) @router.post("/", response_model=schemas.EnergyTransmission, diff --git a/ensysmod/api/endpoints/regions.py b/ensysmod/api/endpoints/regions.py index 08dd4ec..25382bc 100644 --- a/ensysmod/api/endpoints/regions.py +++ b/ensysmod/api/endpoints/regions.py @@ -1,27 +1,27 @@ -from typing import List, Union +from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions router = APIRouter() @router.get("/", response_model=List[schemas.Region]) -def all_regions(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100, - dataset: Union[None, int] = None) -> List[schemas.Region]: +def get_all_regions(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100, + dataset_id: Optional[int] = None) -> List[schemas.Region]: """ Retrieve all energy regions. """ - if dataset is None: - return crud.region.get_multi(db, skip=skip, limit=limit) + if dataset_id is None: + return crud.region.get_multi(db=db, skip=skip, limit=limit) else: - return crud.region.get_multi_by_dataset(db, dataset_id=dataset, skip=skip, limit=limit) + return crud.region.get_multi_by_dataset(db=db, dataset_id=dataset_id, skip=skip, limit=limit) @router.get("/{region_id}", response_model=schemas.Region) diff --git a/ensysmod/api/endpoints/ts_capacity_fix.py b/ensysmod/api/endpoints/ts_capacity_fix.py index cab97e3..93e4890 100644 --- a/ensysmod/api/endpoints/ts_capacity_fix.py +++ b/ensysmod/api/endpoints/ts_capacity_fix.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions from ensysmod.schemas import CapacityFix @@ -11,18 +11,18 @@ @router.get("/", response_model=List[schemas.CapacityFix]) -def all_fix_capacities(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.CapacityFix]: +def get_all_fix_capacities(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.CapacityFix]: """ Retrieve all fix capacities. """ - return crud.capacity_fix.get_multi(db, skip=skip, limit=limit) + return crud.capacity_fix.get_multi(db=db, skip=skip, limit=limit) @router.get("/{ts_id}", response_model=schemas.CapacityFix) -def get_capacity_fix(ts_id: int, +def get_fix_capacity(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -33,7 +33,7 @@ def get_capacity_fix(ts_id: int, @router.post("/", response_model=schemas.CapacityFix) -def create_capacity_fix(request: schemas.CapacityFixCreate, +def create_fix_capacity(request: schemas.CapacityFixCreate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -77,7 +77,7 @@ def create_capacity_fix(request: schemas.CapacityFixCreate, @router.put("/{ts_id}", response_model=schemas.CapacityFix) -def update_capacity_fix(ts_id: int, +def update_fix_capacity(ts_id: int, request: schemas.CapacityFixUpdate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): @@ -93,7 +93,7 @@ def update_capacity_fix(ts_id: int, @router.delete("/{ts_id}", response_model=schemas.CapacityFix) -def remove_capacity_fix(ts_id: int, +def remove_fix_capacity(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ diff --git a/ensysmod/api/endpoints/ts_capacity_max.py b/ensysmod/api/endpoints/ts_capacity_max.py index 6548bd2..ff49f8c 100644 --- a/ensysmod/api/endpoints/ts_capacity_max.py +++ b/ensysmod/api/endpoints/ts_capacity_max.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions from ensysmod.schemas import CapacityMax @@ -11,18 +11,18 @@ @router.get("/", response_model=List[schemas.CapacityMax]) -def all_max_capacities(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.CapacityMax]: +def get_all_max_capacities(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.CapacityMax]: """ Retrieve all max capacities. """ - return crud.capacity_max.get_multi(db, skip=skip, limit=limit) + return crud.capacity_max.get_multi(db=db, skip=skip, limit=limit) @router.get("/{ts_id}", response_model=schemas.CapacityMax) -def get_capacity_max(ts_id: int, +def get_max_capacity(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -32,7 +32,7 @@ def get_capacity_max(ts_id: int, @router.post("/", response_model=schemas.CapacityMax) -def create_capacity_max(request: schemas.CapacityMaxCreate, +def create_max_capacity(request: schemas.CapacityMaxCreate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -75,7 +75,7 @@ def create_capacity_max(request: schemas.CapacityMaxCreate, @router.put("/{ts_id}", response_model=schemas.CapacityMax) -def update_capacity_max(ts_id: int, +def update_max_capacity(ts_id: int, request: schemas.CapacityMaxUpdate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): @@ -90,7 +90,7 @@ def update_capacity_max(ts_id: int, @router.delete("/{ts_id}", response_model=schemas.CapacityMax) -def remove_capacity_max(ts_id: int, +def remove_max_capacity(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ diff --git a/ensysmod/api/endpoints/ts_operation_rate_fix.py b/ensysmod/api/endpoints/ts_operation_rate_fix.py index 17793dc..cf52ae1 100644 --- a/ensysmod/api/endpoints/ts_operation_rate_fix.py +++ b/ensysmod/api/endpoints/ts_operation_rate_fix.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions from ensysmod.schemas import OperationRateFix @@ -11,18 +11,18 @@ @router.get("/", response_model=List[schemas.OperationRateFix]) -def all_fix_operation_rates(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.OperationRateFix]: +def get_all_fix_operation_rates(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.OperationRateFix]: """ Retrieve all fix operation rates. """ - return crud.operation_rate_fix.get_multi(db, skip=skip, limit=limit) + return crud.operation_rate_fix.get_multi(db=db, skip=skip, limit=limit) @router.get("/{ts_id}", response_model=schemas.OperationRateFix) -def get_operation_rate_fix(ts_id: int, +def get_fix_operation_rate(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -32,7 +32,7 @@ def get_operation_rate_fix(ts_id: int, @router.post("/", response_model=schemas.OperationRateFix) -def create_operation_rate_fix(request: schemas.OperationRateFixCreate, +def create_fix_operation_rate(request: schemas.OperationRateFixCreate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -76,7 +76,7 @@ def create_operation_rate_fix(request: schemas.OperationRateFixCreate, @router.put("/{ts_id}", response_model=schemas.OperationRateFix) -def update_operation_rate_fix(ts_id: int, +def update_fix_operation_rate(ts_id: int, request: schemas.OperationRateFixUpdate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): @@ -91,7 +91,7 @@ def update_operation_rate_fix(ts_id: int, @router.delete("/{ts_id}", response_model=schemas.OperationRateFix) -def remove_operation_rate_fix(ts_id: int, +def remove_fix_operation_rate(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ diff --git a/ensysmod/api/endpoints/ts_operation_rate_max.py b/ensysmod/api/endpoints/ts_operation_rate_max.py index 6e9d0c0..efc4d43 100644 --- a/ensysmod/api/endpoints/ts_operation_rate_max.py +++ b/ensysmod/api/endpoints/ts_operation_rate_max.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ensysmod import schemas, model, crud +from ensysmod import crud, model, schemas from ensysmod.api import deps, permissions from ensysmod.schemas import OperationRateMax @@ -11,18 +11,18 @@ @router.get("/", response_model=List[schemas.OperationRateMax]) -def all_max_operation_rates(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.OperationRateMax]: +def get_all_max_operation_rates(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.OperationRateMax]: """ Retrieve all max operation rates. """ - return crud.operation_rate_max.get_multi(db, skip=skip, limit=limit) + return crud.operation_rate_max.get_multi(db=db, skip=skip, limit=limit) @router.get("/{ts_id}", response_model=schemas.OperationRateMax) -def get_operation_rate_max(ts_id: int, +def get_max_operation_rate(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -32,7 +32,7 @@ def get_operation_rate_max(ts_id: int, @router.post("/", response_model=schemas.OperationRateMax) -def create_operation_rate_max(request: schemas.OperationRateMaxCreate, +def create_max_operation_rate(request: schemas.OperationRateMaxCreate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ @@ -76,7 +76,7 @@ def create_operation_rate_max(request: schemas.OperationRateMaxCreate, @router.put("/{ts_id}", response_model=schemas.OperationRateMax) -def update_operation_rate_max(ts_id: int, +def update_max_operation_rate(ts_id: int, request: schemas.OperationRateMaxUpdate, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): @@ -91,7 +91,7 @@ def update_operation_rate_max(ts_id: int, @router.delete("/{ts_id}", response_model=schemas.OperationRateMax) -def remove_operation_rate_max(ts_id: int, +def remove_max_operation_rate(ts_id: int, db: Session = Depends(deps.get_db), current: model.User = Depends(deps.get_current_user)): """ diff --git a/ensysmod/api/endpoints/users.py b/ensysmod/api/endpoints/users.py index fb491b8..a925e79 100644 --- a/ensysmod/api/endpoints/users.py +++ b/ensysmod/api/endpoints/users.py @@ -3,19 +3,19 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from ensysmod import schemas, crud, model +from ensysmod import crud, model, schemas from ensysmod.api import deps router = APIRouter() @router.get("/", response_model=List[schemas.User]) -def all_users(db: Session = Depends(deps.get_db), - current: model.User = Depends(deps.get_current_user), - skip: int = 0, - limit: int = 100) -> List[schemas.User]: +def get_all_users(db: Session = Depends(deps.get_db), + current: model.User = Depends(deps.get_current_user), + skip: int = 0, + limit: int = 100) -> List[schemas.User]: """ - Retrieve all user names from database. + Retrieve all users. """ - users = crud.user.get_multi(db, skip=skip, limit=limit) + users = crud.user.get_multi(db=db, skip=skip, limit=limit) return users diff --git a/ensysmod/crud/energy_model.py b/ensysmod/crud/energy_model.py index 0c93be9..dfc9d8d 100644 --- a/ensysmod/crud/energy_model.py +++ b/ensysmod/crud/energy_model.py @@ -1,8 +1,11 @@ +from typing import Optional + +from sqlalchemy import delete from sqlalchemy.orm import Session from ensysmod import crud from ensysmod.crud.base_depends_dataset import CRUDBaseDependsDataset -from ensysmod.model import EnergyModel +from ensysmod.model import EnergyModel, EnergyModelOptimization, EnergyModelOverride from ensysmod.schemas import EnergyModelCreate, EnergyModelUpdate @@ -31,5 +34,18 @@ def create(self, db: Session, *, obj_in: EnergyModelCreate) -> EnergyModel: return db_obj + def remove(self, db: Session, *, id: int) -> Optional[EnergyModel]: + if self.model.override_parameters is not None: + db.execute(delete(EnergyModelOverride).filter(EnergyModelOverride.ref_model == id)) + + if self.model.optimization_parameters is not None: + db.execute(delete(EnergyModelOptimization).filter(EnergyModelOptimization.ref_model == id)) + + model = db.query(self.model).get(id) + db.delete(model) + + db.commit() + return model + energy_model = CRUDEnergyModel(EnergyModel) diff --git a/ensysmod/crud/energy_model_override.py b/ensysmod/crud/energy_model_override.py index 774f155..a07cbc8 100644 --- a/ensysmod/crud/energy_model_override.py +++ b/ensysmod/crud/energy_model_override.py @@ -6,7 +6,7 @@ class CRUDEnergyModelOverride(CRUDBaseDependsComponentRegion[EnergyModelOverride, EnergyModelOverrideCreate, - EnergyModelOverrideUpdate]): + EnergyModelOverrideUpdate]): """ CRUD operations for EnergyModelOverride """ diff --git a/ensysmod/database/session.py b/ensysmod/database/session.py index f1bae44..65926b9 100644 --- a/ensysmod/database/session.py +++ b/ensysmod/database/session.py @@ -8,4 +8,4 @@ pool_pre_ping=True, connect_args={"check_same_thread": False}, poolclass=StaticPool) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, expire_on_commit=False, bind=engine) diff --git a/ensysmod/schemas/energy_model_optimization.py b/ensysmod/schemas/energy_model_optimization.py index 1f134f3..179d18c 100644 --- a/ensysmod/schemas/energy_model_optimization.py +++ b/ensysmod/schemas/energy_model_optimization.py @@ -13,8 +13,16 @@ class EnergyModelOptimizationBase(BaseModel): end_year: Optional[int] = Field(None, description="Year of the last optimization", example="2050") number_of_steps: Optional[int] = Field(None, description="Number of optimization runs excluding the start year", example="3") years_per_step: Optional[int] = Field(None, description="Number of years represented by one optimization run", example="10") - CO2_reference: Optional[float] = Field(None, description="CO2 emission reference value to which the reduction should be applied to", example="366") - CO2_reduction_targets: Optional[List[float]] = Field(None, description="CO2 reduction targets for all optimization periods, in percentages. If specified, the length of the list must equal the number of optimization steps, and a sink component named 'CO2 to environment' is required.", example="[0, 25, 50, 100]") + CO2_reference: Optional[float] = Field( + None, + description="CO2 emission reference value to which the reduction should be applied to", + example="366", + ) + CO2_reduction_targets: Optional[List[float]] = Field( + None, + description="CO2 reduction targets for all optimization periods, in percentages. If specified, the length of the list must equal the number of optimization steps, and a sink component named 'CO2 to environment' is required.", # noqa: E501 + example="[0, 25, 50, 100]", + ) # validators _valid_optimization_timeframe = root_validator(allow_reuse=True)(validators.validate_optimization_timeframe) diff --git a/ensysmod/schemas/energy_sink.py b/ensysmod/schemas/energy_sink.py index 4d304ce..cd63a44 100644 --- a/ensysmod/schemas/energy_sink.py +++ b/ensysmod/schemas/energy_sink.py @@ -18,15 +18,21 @@ class EnergySinkBase(BaseModel): Shared attributes for an energy sink. Used as a base class for all schemas. """ type = EnergyComponentType.SINK - commodity_cost: Optional[float] = Field(None, - description="Cost of the energy sink per unit of energy.", - example=42.2) - yearly_limit: Optional[float] = Field(None, - description="The yearly limit of the energy sink. If specified, commodity_limit_id must be specified as well.", - example=366.5) - commodity_limit_id: Optional[str] = Field(None, - description="Commodity limit ID of the energy sink. Required if yearly_limit is specified. The limit is shared among all components of the same commodity_limit_id.", - example="CO2") + commodity_cost: Optional[float] = Field( + None, + description="Cost of the energy sink per unit of energy.", + example=42.2, + ) + yearly_limit: Optional[float] = Field( + None, + description="The yearly limit of the energy sink. If specified, commodity_limit_id must be specified as well.", + example=366.5, + ) + commodity_limit_id: Optional[str] = Field( + None, + description="Commodity limit ID of the energy sink. Required if yearly_limit is specified. The limit is shared among all components of the same commodity_limit_id.", # noqa: E501 + example="CO2", + ) # validators _valid_type = validator("type", allow_reuse=True)(validators.validate_energy_component_type) diff --git a/ensysmod/schemas/energy_source.py b/ensysmod/schemas/energy_source.py index b22f451..fd056d4 100644 --- a/ensysmod/schemas/energy_source.py +++ b/ensysmod/schemas/energy_source.py @@ -18,15 +18,21 @@ class EnergySourceBase(BaseModel): Shared attributes for an energy source. Used as a base class for all schemas. """ type = EnergyComponentType.SOURCE - commodity_cost: Optional[float] = Field(None, - description="Cost of the energy source per unit of energy.", - example=42.2) - yearly_limit: Optional[float] = Field(None, - description="The yearly limit of the energy sink. If specified, commodity_limit_id must be specified as well.", - example=366.5) - commodity_limit_id: Optional[str] = Field(None, - description="Commodity limit ID of the energy sink. If specified, yearly_limit must be specified as well.", - example="CO2") + commodity_cost: Optional[float] = Field( + None, + description="Cost of the energy source per unit of energy.", + example=42.2, + ) + yearly_limit: Optional[float] = Field( + None, + description="The yearly limit of the energy sink. If specified, commodity_limit_id must be specified as well.", + example=366.5, + ) + commodity_limit_id: Optional[str] = Field( + None, + description="Commodity limit ID of the energy sink. If specified, yearly_limit must be specified as well.", + example="CO2", + ) # validators _valid_type = validator("type", allow_reuse=True)(validators.validate_energy_component_type) diff --git a/ensysmod/utils/validators.py b/ensysmod/utils/validators.py index e41b1c1..42ceef6 100644 --- a/ensysmod/utils/validators.py +++ b/ensysmod/utils/validators.py @@ -646,7 +646,7 @@ def validate_CO2_optimization(cls, values): :param CO2_reduction_targets: CO2 reduction targets for all optimization periods, in percentages. If specified, the length of the list must equal the number of optimization steps. :return: The validated CO2 optimization parameters. - """ + """ # noqa: E501 CO2_reference = values.get('CO2_reference') CO2_reduction_targets = values.get('CO2_reduction_targets') number_of_steps = values.get('number_of_steps') @@ -665,6 +665,8 @@ def validate_CO2_optimization(cls, values): raise ValueError("Values of CO2_reduction_targets must be between 0 and 100.") if len(CO2_reduction_targets) != number_of_steps+1: - raise ValueError(f"The number of values given in CO2_reduction_targets must match the number of optimization runs. Expected: {number_of_steps+1}, given: {len(CO2_reduction_targets)}.") + raise ValueError( + f"The number of values given in CO2_reduction_targets must match the number of optimization runs. Expected: {number_of_steps+1}, given: {len(CO2_reduction_targets)}." # noqa: E501 + ) return values diff --git a/pyproject.toml b/pyproject.toml index ea8e71c..627acfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "passlib>=1.7.4", "bcrypt>=3.2.0", "numpy~=1.24.4", - "pandas>=2.0.1", + "pandas~=1.5.3", "openpyxl>=3.0.9", "FINE>=2.2.2", "python-multipart>=0.0.6", @@ -60,11 +60,14 @@ where = ["ensysmod"] testpaths = "tests" [tool.pytest.ini_options] -adopts = [ - "--strict", +addopts = [ + "--strict-markers", "--doctest-modules", "--durations=0", - ] +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", +] [tool.coverage.report] # https://stackoverflow.com/a/5850364 diff --git a/requirements.txt b/requirements.txt index f20fd54..b6c792f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ passlib==1.7.4 bcrypt==3.2.0 numpy==1.24.4 -pandas==2.0.1 +pandas==1.5.3 openpyxl==3.0.9 FINE==2.2.2 diff --git a/tests/api/test_dataset_download.py b/tests/api/test_dataset_download.py new file mode 100644 index 0000000..6557e36 --- /dev/null +++ b/tests/api/test_dataset_download.py @@ -0,0 +1,21 @@ +from typing import Dict + +import pytest +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.utils import data_generator + + +@pytest.mark.slow +@pytest.mark.parametrize("data_folder", ["1node_Example", "Multi-regional_Example"]) +def test_download_dataset_zip(client: TestClient, db: Session, normal_user_headers: Dict[str, str], data_folder: str): + """ + Test downloading a dataset. + """ + dataset = data_generator.create_example_dataset(db, data_folder) + + response = client.get(f"/datasets/{dataset.id}/download", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + assert response.headers["Content-Type"] == "application/zip" diff --git a/tests/api/test_dataset_upload.py b/tests/api/test_dataset_upload.py new file mode 100644 index 0000000..9dbddc5 --- /dev/null +++ b/tests/api/test_dataset_upload.py @@ -0,0 +1,38 @@ +from typing import Dict +from zipfile import ZipFile + +import pytest +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.utils import data_generator + + +@pytest.mark.slow +@pytest.mark.parametrize("data_folder", ["1node_Example", "Multi-regional_Example"]) +def test_upload_dataset_zip(client: TestClient, db: Session, normal_user_headers: Dict[str, str], data_folder: str): + """ + Test uploading a dataset. + """ + # Create a dataset + dataset = data_generator.random_existing_dataset(db) + + # Upload a zip file + zip_file_path = data_generator.get_dataset_zip(data_folder) + + # print all the contents of the zip file + print(f"Zip file contents of {zip_file_path}:") + with ZipFile(zip_file_path, 'r') as zip_file: + for file in zip_file.namelist(): + print(file) + + response = client.post( + f"/datasets/{dataset.id}/upload", + headers=normal_user_headers, + files={"file": ("dataset.zip", open(zip_file_path, "rb"), "application/zip")}, + ) + print(response.text) + assert response.status_code == status.HTTP_200_OK + + # TODO Check that the dataset has been updated diff --git a/tests/api/test_datasets.py b/tests/api/test_datasets.py index 46a827a..af21e53 100644 --- a/tests/api/test_datasets.py +++ b/tests/api/test_datasets.py @@ -6,8 +6,40 @@ from sqlalchemy.orm import Session from ensysmod.schemas import DatasetCreate, DatasetUpdate -from tests.utils import data_generator as data_gen -from tests.utils.utils import random_lower_string +from tests.utils import data_generator +from tests.utils.utils import clear_database, random_lower_string + + +def test_get_all_datasets(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all datasets. + """ + clear_database(db) + dataset1 = data_generator.random_existing_dataset(db) + dataset2 = data_generator.random_existing_dataset(db) + + response = client.get("/datasets/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + dataset_list = response.json() + assert len(dataset_list) == 2 + assert dataset_list[0]["name"] == dataset1.name + assert dataset_list[0]["id"] == dataset1.id + assert dataset_list[1]["name"] == dataset2.name + assert dataset_list[1]["id"] == dataset2.id + + +def test_get_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving a dataset. + """ + dataset = data_generator.random_existing_dataset(db) + response = client.get(f"/datasets/{dataset.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_dataset = response.json() + assert retrieved_dataset["name"] == dataset.name + assert retrieved_dataset["id"] == dataset.id def test_create_dataset(client: TestClient, normal_user_headers: Dict[str, str]): @@ -15,7 +47,7 @@ def test_create_dataset(client: TestClient, normal_user_headers: Dict[str, str]) Test creating a dataset. """ # Create a dataset - create_request = data_gen.random_dataset_create() + create_request = data_generator.random_dataset_create() response = client.post( "/datasets/", @@ -33,7 +65,7 @@ def test_create_existing_dataset(db: Session, client: TestClient, normal_user_he """ Test creating an existing dataset. """ - existing_dataset = data_gen.random_existing_dataset(db) + existing_dataset = data_generator.random_existing_dataset(db) print(existing_dataset.name) create_request = DatasetCreate(**jsonable_encoder(existing_dataset)) response = client.post( @@ -44,15 +76,17 @@ def test_create_existing_dataset(db: Session, client: TestClient, normal_user_he assert response.status_code == status.HTTP_409_CONFLICT -def test_update_existing_dataset(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): +def test_update_dataset(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): """ - Test updating an existing dataset. + Test updating a dataset. """ - existing_dataset = data_gen.random_existing_dataset(db) + existing_dataset = data_generator.random_existing_dataset(db) print(existing_dataset.name) + update_request = DatasetUpdate(**jsonable_encoder(existing_dataset)) - new_description = random_lower_string() - update_request.description = new_description + update_request.name = f"New Dataset Name-{random_lower_string()}" + update_request.description = f"New Dataset Description-{random_lower_string()}" + response = client.put( f"/datasets/{existing_dataset.id}", headers=normal_user_headers, @@ -65,11 +99,11 @@ def test_update_existing_dataset(db: Session, client: TestClient, normal_user_he assert updated_dataset['description'] == update_request.description -def test_delete_existing_dataset(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): +def test_remove_dataset(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): """ - Test deleting an existing dataset. + Test deleting a dataset. """ - existing_dataset = data_gen.random_existing_dataset(db) + existing_dataset = data_generator.random_existing_dataset(db) response = client.delete( f"/datasets/{existing_dataset.id}", headers=normal_user_headers diff --git a/tests/api/test_energy_commodities.py b/tests/api/test_energy_commodities.py index 691bf40..8b5824d 100644 --- a/tests/api/test_energy_commodities.py +++ b/tests/api/test_energy_commodities.py @@ -5,15 +5,74 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import EnergyCommodityCreate -from tests.utils import data_generator as data_gen +from ensysmod.schemas import EnergyCommodityCreate, EnergyCommodityUpdate +from tests.utils import data_generator +from tests.utils.utils import clear_database, random_lower_string -def test_create_energy_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_commodities(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy commodity. + Test retrieving all commodities. """ - create_request = data_gen.random_energy_commodity_create(db) + clear_database(db) + commodity1 = data_generator.random_existing_energy_commodity(db) + commodity2 = data_generator.random_existing_energy_commodity(db) + + response = client.get("/commodities/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + commodity_list = response.json() + assert len(commodity_list) == 2 + assert commodity_list[0]["name"] == commodity1.name + assert commodity_list[0]["id"] == commodity1.id + assert commodity_list[1]["name"] == commodity2.name + assert commodity_list[1]["id"] == commodity2.id + + +def test_get_all_commodities_specific_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all commodities belonging to a specific dataset. + """ + clear_database(db) + commodity1 = data_generator.random_existing_energy_commodity(db) + commodity2 = data_generator.random_existing_energy_commodity(db) + + response1 = client.get("/commodities/", headers=normal_user_headers, params={"dataset_id": commodity1.dataset.id}) + assert response1.status_code == status.HTTP_200_OK + + commodity_list1 = response1.json() + assert len(commodity_list1) == 1 + assert commodity_list1[0]["name"] == commodity1.name + assert commodity_list1[0]["id"] == commodity1.id + + response2 = client.get("/commodities/", headers=normal_user_headers, params={"dataset_id": commodity2.dataset.id}) + assert response2.status_code == status.HTTP_200_OK + + commodity_list2 = response2.json() + assert len(commodity_list2) == 1 + assert commodity_list2[0]["name"] == commodity2.name + assert commodity_list2[0]["id"] == commodity2.id + + +def test_get_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving a commodity. + """ + commodity = data_generator.random_existing_energy_commodity(db) + response = client.get(f"/commodities/{commodity.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_commodity = response.json() + print(retrieved_commodity) + assert retrieved_commodity["name"] == commodity.name + assert retrieved_commodity["id"] == commodity.id + + +def test_create_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy commodity. + """ + create_request = data_generator.random_energy_commodity_create(db) response = client.post("/commodities/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -24,23 +83,110 @@ def test_create_energy_commodity(client: TestClient, normal_user_headers: Dict[s assert created_commodity["unit"] == create_request.unit -def test_create_existing_energy_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy commodity. + Test creating an existing energy commodity. """ - existing_commodity = data_gen.random_existing_energy_commodity(db) + existing_commodity = data_generator.random_existing_energy_commodity(db) create_request = EnergyCommodityCreate(**jsonable_encoder(existing_commodity)) response = client.post("/commodities/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_commodity_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_commodity_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy commodity. + Test creating an energy commodity. """ - create_request = data_gen.random_energy_commodity_create(db) + create_request = data_generator.random_energy_commodity_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/commodities/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -# TODO Add more test cases + +def test_create_multiple_commodities_same_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating multiple commodities on the same dataset. + """ + existing_commodity = data_generator.random_existing_energy_commodity(db) + dataset_id = existing_commodity.dataset.id + + # Create a new commodity on the same dataset + create_request = data_generator.random_energy_commodity_create(db) + create_request.ref_dataset = dataset_id + + response = client.post("/commodities/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + second_commodity = response.json() + + # Check that the dataset has two commodities + get_response = client.get( + "/commodities/", + headers=normal_user_headers, + params={"dataset_id": dataset_id}, + ) + assert get_response.status_code == status.HTTP_200_OK + + commodity_list = get_response.json() + assert len(commodity_list) == 2 + assert commodity_list[0]["name"] == existing_commodity.name + assert commodity_list[0]["id"] == existing_commodity.id + assert commodity_list[1]["name"] == second_commodity["name"] + assert commodity_list[1]["id"] == second_commodity["id"] + + +def test_update_commodity(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test updating a commodity. + """ + existing_commodity = data_generator.random_existing_energy_commodity(db) + print(existing_commodity.name) + + update_request = EnergyCommodityUpdate(**jsonable_encoder(existing_commodity)) + update_request.name = f"New Commodity Name-{random_lower_string()}" + update_request.unit = f"New Commodity Unit-{random_lower_string()}" + update_request.description = f"New Commodity Description-{random_lower_string()}" + + response = client.put( + f"/commodities/{existing_commodity.id}", + headers=normal_user_headers, + data=update_request.json(), + ) + assert response.status_code == status.HTTP_200_OK + + updated_commodity = response.json() + assert updated_commodity["name"] == update_request.name + assert updated_commodity["unit"] == update_request.unit + assert updated_commodity["description"] == update_request.description + + +def test_remove_commodity(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test deleting a commodity. + """ + # Create a dataset with two commodities + first_commodity = data_generator.random_existing_energy_commodity(db) + dataset_id = first_commodity.dataset.id + + create_request = data_generator.random_energy_commodity_create(db) + create_request.ref_dataset = dataset_id + + response = client.post("/commodities/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + second_commodity = response.json() + + # Delete the first commodity + response = client.delete(f"/commodities/{first_commodity.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # Check that the dataset only has the second commodity + get_response = client.get( + "/commodities/", + headers=normal_user_headers, + params={"dataset_id": dataset_id}, + ) + assert get_response.status_code == status.HTTP_200_OK + + commodity_list = get_response.json() + assert len(commodity_list) == 1 + assert commodity_list[0]["name"] == second_commodity["name"] + assert commodity_list[0]["id"] == second_commodity["id"] diff --git a/tests/api/test_energy_conversions.py b/tests/api/test_energy_conversions.py index 562a91f..237d5f9 100644 --- a/tests/api/test_energy_conversions.py +++ b/tests/api/test_energy_conversions.py @@ -5,15 +5,35 @@ from sqlalchemy.orm import Session from ensysmod.model import EnergyComponentType -from tests.utils import data_generator as data_gen +from tests.utils import data_generator from tests.utils.assertions import assert_energy_component +from tests.utils.utils import clear_database -def test_create_energy_conversion(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_energy_conversions(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy conversion. + Test retrieving all energy converesions. """ - create_request = data_gen.random_energy_conversion_create(db) + clear_database(db) + conversion1 = data_generator.random_existing_energy_conversion(db) + conversion2 = data_generator.random_existing_energy_conversion(db) + + response = client.get("/conversions/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + conversion_list = response.json() + assert len(conversion_list) == 2 + assert conversion_list[0]["component"]["name"] == conversion1.component.name + assert conversion_list[0]["component"]["id"] == conversion1.component.id + assert conversion_list[1]["component"]["name"] == conversion2.component.name + assert conversion_list[1]["component"]["id"] == conversion2.component.id + + +def test_create_conversion(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy conversion. + """ + create_request = data_generator.random_energy_conversion_create(db) response = client.post("/conversions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -22,33 +42,32 @@ def test_create_energy_conversion(client: TestClient, normal_user_headers: Dict[ assert created_conversion["commodity_unit"]["name"] == create_request.commodity_unit -def test_create_existing_energy_conversion(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_conversion(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy conversion. + Test creating an existing energy conversion. """ - create_request = data_gen.random_energy_conversion_create(db) + create_request = data_generator.random_energy_conversion_create(db) response = client.post("/conversions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK response = client.post("/conversions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_conversion_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_conversion_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy conversion. + Test creating an energy conversion. """ - create_request = data_gen.random_energy_conversion_create(db) + create_request = data_generator.random_energy_conversion_create(db) create_request.ref_dataset = 132456 # ungültige Anfrage response = client.post("/conversions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_energy_conversion_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], - db: Session): +def test_create_conversion_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy conversion. + Test creating an energy conversion. """ - create_request = data_gen.random_energy_conversion_create(db) + create_request = data_generator.random_energy_conversion_create(db) create_request.commodity_unit = "0" # ungültige Anfrage response = client.post("/conversions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/api/test_energy_model_optimization.py b/tests/api/test_energy_model_optimization.py new file mode 100644 index 0000000..3be3509 --- /dev/null +++ b/tests/api/test_energy_model_optimization.py @@ -0,0 +1,22 @@ +from typing import Dict + +import pytest +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.utils import data_generator + + +@pytest.mark.slow +@pytest.mark.parametrize("data_folder", ["1node_Example", "Multi-regional_Example"]) +def test_optimize_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session, data_folder: str): + """ + Test optimizing an energy model. + """ + example_model = data_generator.create_example_model(db, data_folder) + response = client.get(f"/models/{example_model.id}/optimize/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + assert response.headers["Content-Type"] == "application/vnd.openxmlformats-officedocument. spreadsheetml.sheet" + +# TODO Add test for myopic_optimize_model diff --git a/tests/api/test_energy_models.py b/tests/api/test_energy_models.py index e943b6c..8923bdb 100644 --- a/tests/api/test_energy_models.py +++ b/tests/api/test_energy_models.py @@ -4,16 +4,49 @@ from fastapi.encoders import jsonable_encoder from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from tests.utils import data_generator as data_gen -from ensysmod.schemas import EnergyModelCreate +from ensysmod.schemas import EnergyModelCreate, EnergyModelUpdate +from tests.utils import data_generator +from tests.utils.utils import clear_database, random_lower_string -def test_create_energy_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_models(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy model. + Test retrieving all energy models. """ - create_request = data_gen.random_energy_model_create(db) + clear_database(db) + model1 = data_generator.random_existing_energy_model(db) + model2 = data_generator.random_existing_energy_model(db) + + response = client.get("/models/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + model_list = response.json() + assert len(model_list) == 2 + assert model_list[0]["name"] == model1.name + assert model_list[0]["dataset"]["id"] == model1.dataset.id + assert model_list[1]["name"] == model2.name + assert model_list[1]["dataset"]["id"] == model2.dataset.id + + +def test_get_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving an energy model. + """ + model = data_generator.random_existing_energy_model(db) + response = client.get(f"/models/{model.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_model = response.json() + assert retrieved_model["name"] == model.name + assert retrieved_model["dataset"]["id"] == model.dataset.id + + +def test_create_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy model. + """ + create_request = data_generator.random_energy_model_create(db) response = client.post("/models/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -23,11 +56,11 @@ def test_create_energy_model(client: TestClient, normal_user_headers: Dict[str, assert created_model["description"] == create_request.description -def test_create_existing_energy_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_model(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy model. + Test creating an existing energy model. """ - existing_model = data_gen.random_existing_energy_model(db) + existing_model = data_generator.random_existing_energy_model(db) existing_model.override_parameters = [] create_request = EnergyModelCreate(**jsonable_encoder(existing_model)) response = client.post("/models/", headers=normal_user_headers, data=create_request.json()) @@ -36,9 +69,9 @@ def test_create_existing_energy_model(client: TestClient, normal_user_headers: D def test_create_energy_model_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy model. + Test creating an energy model. """ - create_request = data_gen.random_energy_model_create(db) + create_request = data_generator.random_energy_model_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/models/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -46,9 +79,9 @@ def test_create_energy_model_unknown_dataset(client: TestClient, normal_user_hea def test_create_energy_model_with_override_parameters(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy model with override parameters. + Test creating an energy model with override parameters. """ - create_request = data_gen.random_energy_model_create(db) + create_request = data_generator.random_energy_model_create(db) response = client.post("/models/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -61,9 +94,9 @@ def test_create_energy_model_with_override_parameters(client: TestClient, normal def test_create_energy_model_with_optimization_parameters(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy model with optimization parameters. + Test creating an energy model with optimization parameters. """ - create_request = data_gen.random_energy_model_create(db) + create_request = data_generator.random_energy_model_create(db) response = client.post("/models/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -75,4 +108,36 @@ def test_create_energy_model_with_optimization_parameters(client: TestClient, no assert created_model["optimization_parameters"][0]["years_per_step"] == 10 -# TODO Add more test cases +def test_update_energy_model(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test updating an energy model. + """ + existing_model = data_generator.random_existing_energy_model(db) + print(existing_model.name) + + update_request = EnergyModelUpdate(**jsonable_encoder(existing_model)) + update_request.name = f"New Energy Model Name-{random_lower_string()}" + + print(update_request.json()) + + response = client.put( + f"/models/{existing_model.id}", + headers=normal_user_headers, + data=update_request.json(), + ) + assert response.status_code == status.HTTP_200_OK + + updated_model = response.json() + assert updated_model["name"] == update_request.name + + +def test_remove_energy_model(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test deleting an energy_model. + """ + existing_model = data_generator.random_existing_energy_model(db) + response = client.delete( + f"/models/{existing_model.id}", + headers=normal_user_headers + ) + assert response.status_code == status.HTTP_200_OK diff --git a/tests/api/test_energy_sinks.py b/tests/api/test_energy_sinks.py index 5d19ff8..fe2cffc 100644 --- a/tests/api/test_energy_sinks.py +++ b/tests/api/test_energy_sinks.py @@ -5,15 +5,35 @@ from sqlalchemy.orm import Session from ensysmod.model import EnergyComponentType -from tests.utils import data_generator as data_gen +from tests.utils import data_generator from tests.utils.assertions import assert_energy_component +from tests.utils.utils import clear_database -def test_create_energy_sink(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_energy_sinks(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy sink. + Test retrieving all energy sinks. """ - create_request = data_gen.random_energy_sink_create(db) + clear_database(db) + sink1 = data_generator.random_existing_energy_sink(db) + sink2 = data_generator.random_existing_energy_sink(db) + + response = client.get("/sinks/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + sink_list = response.json() + assert len(sink_list) == 2 + assert sink_list[0]["component"]["name"] == sink1.component.name + assert sink_list[0]["component"]["id"] == sink1.component.id + assert sink_list[1]["component"]["name"] == sink2.component.name + assert sink_list[1]["component"]["id"] == sink2.component.id + + +def test_create_sink(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy sink. + """ + create_request = data_generator.random_energy_sink_create(db) response = client.post("/sinks/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -22,32 +42,32 @@ def test_create_energy_sink(client: TestClient, normal_user_headers: Dict[str, s assert created_sinks["commodity"]["name"] == create_request.commodity -def test_create_existing_energy_sink(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_sink(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy sink. + Test creating an existing energy sink. """ - create_request = data_gen.random_energy_sink_create(db) + create_request = data_generator.random_energy_sink_create(db) response = client.post("/sinks/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK response = client.post("/sinks/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_sink_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_sink_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy sink. + Test creating an energy sink. """ - create_request = data_gen.random_energy_sink_create(db) + create_request = data_generator.random_energy_sink_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/sinks/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_energy_sink_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_sink_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy sink. + Test creating an energy sink. """ - create_request = data_gen.random_energy_sink_create(db) + create_request = data_generator.random_energy_sink_create(db) create_request.commodity = "0" # ungültige Anfrage response = client.post("/sinks/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/api/test_energy_sources.py b/tests/api/test_energy_sources.py index dc47b3e..b6159af 100644 --- a/tests/api/test_energy_sources.py +++ b/tests/api/test_energy_sources.py @@ -5,15 +5,35 @@ from sqlalchemy.orm import Session from ensysmod.model import EnergyComponentType -from tests.utils import data_generator as data_gen +from tests.utils import data_generator from tests.utils.assertions import assert_energy_component +from tests.utils.utils import clear_database -def test_create_energy_source(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_energy_sources(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy source. + Test retrieving all energy sources. """ - create_request = data_gen.random_energy_source_create(db) + clear_database(db) + source1 = data_generator.random_existing_energy_source(db) + source2 = data_generator.random_existing_energy_source(db) + + response = client.get("/sources/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + source_list = response.json() + assert len(source_list) == 2 + assert source_list[0]["component"]["name"] == source1.component.name + assert source_list[0]["component"]["id"] == source1.component.id + assert source_list[1]["component"]["name"] == source2.component.name + assert source_list[1]["component"]["id"] == source2.component.id + + +def test_create_source(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy source. + """ + create_request = data_generator.random_energy_source_create(db) response = client.post("/sources/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -22,32 +42,32 @@ def test_create_energy_source(client: TestClient, normal_user_headers: Dict[str, assert created_source["commodity"]["name"] == create_request.commodity -def test_create_existing_energy_source(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_source(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy source. + Test creating an existing energy source. """ - create_request = data_gen.random_energy_source_create(db) + create_request = data_generator.random_energy_source_create(db) response = client.post("/sources/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK response = client.post("/sources/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_source_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_source_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy source. + Test creating an energy source. """ - create_request = data_gen.random_energy_source_create(db) + create_request = data_generator.random_energy_source_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/sources/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_energy_source_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_source_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy source. + Test creating an energy source. """ - create_request = data_gen.random_energy_source_create(db) + create_request = data_generator.random_energy_source_create(db) create_request.commodity = "0" # ungültige Anfrage response = client.post("/sources/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/api/test_energy_storages.py b/tests/api/test_energy_storages.py index 4a47986..96aa053 100644 --- a/tests/api/test_energy_storages.py +++ b/tests/api/test_energy_storages.py @@ -5,15 +5,35 @@ from sqlalchemy.orm import Session from ensysmod.model import EnergyComponentType -from tests.utils import data_generator as data_gen +from tests.utils import data_generator from tests.utils.assertions import assert_energy_component +from tests.utils.utils import clear_database -def test_create_energy_storage(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_energy_storages(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy storage. + Test retrieving all energy storages. """ - create_request = data_gen.random_energy_storage_create(db) + clear_database(db) + storage1 = data_generator.random_existing_energy_storage(db) + storage2 = data_generator.random_existing_energy_storage(db) + + response = client.get("/storages/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + storage_list = response.json() + assert len(storage_list) == 2 + assert storage_list[0]["component"]["name"] == storage1.component.name + assert storage_list[0]["component"]["id"] == storage1.component.id + assert storage_list[1]["component"]["name"] == storage2.component.name + assert storage_list[1]["component"]["id"] == storage2.component.id + + +def test_create_storage(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy storage. + """ + create_request = data_generator.random_energy_storage_create(db) response = client.post("/storages/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -22,32 +42,32 @@ def test_create_energy_storage(client: TestClient, normal_user_headers: Dict[str assert created_storage["commodity"]["name"] == create_request.commodity -def test_create_existing_energy_storage(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_storage(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy storage. + Test creating an existing energy storage. """ - create_request = data_gen.random_energy_storage_create(db) + create_request = data_generator.random_energy_storage_create(db) response = client.post("/storages/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK response = client.post("/storages/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_storage_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_storage_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy storage. + Test creating an energy storage. """ - create_request = data_gen.random_energy_storage_create(db) + create_request = data_generator.random_energy_storage_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/storages/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_energy_storage_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_storage_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy storage. + Test creating an energy storage. """ - create_request = data_gen.random_energy_storage_create(db) + create_request = data_generator.random_energy_storage_create(db) create_request.commodity = "0" # ungültige Anfrage response = client.post("/storages/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/api/test_energy_transmissions.py b/tests/api/test_energy_transmissions.py index 2b3e55d..4f4866e 100644 --- a/tests/api/test_energy_transmissions.py +++ b/tests/api/test_energy_transmissions.py @@ -5,15 +5,35 @@ from sqlalchemy.orm import Session from ensysmod.model import EnergyComponentType -from tests.utils import data_generator as data_gen +from tests.utils import data_generator from tests.utils.assertions import assert_energy_component +from tests.utils.utils import clear_database -def test_create_energy_transmission(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_get_all_energy_transmissions(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy transmission. + Test retrieving all energy transmissions. """ - create_request = data_gen.random_energy_transmission_create(db) + clear_database(db) + transmission1 = data_generator.random_existing_energy_transmission(db) + transmission2 = data_generator.random_existing_energy_transmission(db) + + response = client.get("/transmissions/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + transmission_list = response.json() + assert len(transmission_list) == 2 + assert transmission_list[0]["component"]["name"] == transmission1.component.name + assert transmission_list[0]["component"]["id"] == transmission1.component.id + assert transmission_list[1]["component"]["name"] == transmission2.component.name + assert transmission_list[1]["component"]["id"] == transmission2.component.id + + +def test_create_transmission(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating an energy transmission. + """ + create_request = data_generator.random_energy_transmission_create(db) response = client.post("/transmissions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -22,34 +42,32 @@ def test_create_energy_transmission(client: TestClient, normal_user_headers: Dic assert created_transmission["commodity"]["name"] == create_request.commodity -def test_create_existing_energy_transmission(client: TestClient, normal_user_headers: Dict[str, str], db: Session): +def test_create_existing_transmission(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing energy transmission. + Test creating an existing energy transmission. """ - create_request = data_gen.random_energy_transmission_create(db) + create_request = data_generator.random_energy_transmission_create(db) response = client.post("/transmissions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK response = client.post("/transmissions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT -def test_create_energy_transmission_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], - db: Session): +def test_create_transmission_unknown_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy transmission. + Test creating an energy transmission. """ - create_request = data_gen.random_energy_transmission_create(db) + create_request = data_generator.random_energy_transmission_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/transmissions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_energy_transmission_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], - db: Session): +def test_create_transmission_unknown_commodity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a energy transmission. + Test creating an energy transmission. """ - create_request = data_gen.random_energy_transmission_create(db) + create_request = data_generator.random_energy_transmission_create(db) create_request.commodity = "0" # ungültige Anfrage response = client.post("/transmissions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/api/test_regions.py b/tests/api/test_regions.py index b63a99c..cbb7979 100644 --- a/tests/api/test_regions.py +++ b/tests/api/test_regions.py @@ -5,15 +5,74 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import RegionCreate -from tests.utils import data_generator as data_gen +from ensysmod.schemas import RegionCreate, RegionUpdate +from tests.utils import data_generator +from tests.utils.utils import clear_database, random_lower_string + + +def test_get_all_regions(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all regions. + """ + clear_database(db) + region1 = data_generator.random_existing_region(db) + region2 = data_generator.random_existing_region(db) + + response = client.get("/regions/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + region_list = response.json() + assert len(region_list) == 2 + assert region_list[0]["name"] == region1.name + assert region_list[0]["id"] == region1.id + assert region_list[1]["name"] == region2.name + assert region_list[1]["id"] == region2.id + + +def test_get_all_regions_specific_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all regions belonging to a specific dataset. + """ + clear_database(db) + region1 = data_generator.random_existing_region(db) + region2 = data_generator.random_existing_region(db) + + response1 = client.get("/regions/", headers=normal_user_headers, params={"dataset_id": region1.dataset.id}) + assert response1.status_code == status.HTTP_200_OK + + region_list1 = response1.json() + assert len(region_list1) == 1 + assert region_list1[0]["name"] == region1.name + assert region_list1[0]["id"] == region1.id + + response2 = client.get("/regions/", headers=normal_user_headers, params={"dataset_id": region2.dataset.id}) + assert response2.status_code == status.HTTP_200_OK + + region_list2 = response2.json() + assert len(region_list2) == 1 + assert region_list2[0]["name"] == region2.name + assert region_list2[0]["id"] == region2.id + + +def test_get_region(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving a region. + """ + region = data_generator.random_existing_region(db) + response = client.get(f"/regions/{region.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + retrieved_region = response.json() + print(retrieved_region) + assert retrieved_region["name"] == region.name + assert retrieved_region["id"] == region.id def test_create_region(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ Test creating a region. """ - create_request = data_gen.random_region_create(db) + create_request = data_generator.random_region_create(db) response = client.post("/regions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK @@ -24,9 +83,9 @@ def test_create_region(client: TestClient, normal_user_headers: Dict[str, str], def test_create_existing_region(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ - Test creating a existing region. + Test creating an existing region. """ - existing_region = data_gen.random_existing_region(db) + existing_region = data_generator.random_existing_region(db) create_request = RegionCreate(**jsonable_encoder(existing_region)) response = client.post("/regions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_409_CONFLICT @@ -36,9 +95,92 @@ def test_create_region_unknown_dataset(client: TestClient, normal_user_headers: """ Test creating a region. """ - create_request = data_gen.random_region_create(db) + create_request = data_generator.random_region_create(db) create_request.ref_dataset = 123456 # ungültige Anfrage response = client.post("/regions/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_404_NOT_FOUND -# TODO Add more test cases + +def test_create_multiple_regions_same_dataset(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test creating multiple regions on the same dataset. + """ + existing_region = data_generator.random_existing_region(db) + dataset_id = existing_region.dataset.id + + # Create a new region on the same dataset + create_request = data_generator.random_region_create(db) + create_request.ref_dataset = dataset_id + + response = client.post("/regions/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + second_region = response.json() + + # Check that the dataset has two regions + get_response = client.get( + "/regions/", + headers=normal_user_headers, + params={"dataset_id": dataset_id}, + ) + assert get_response.status_code == status.HTTP_200_OK + + region_list = get_response.json() + assert len(region_list) == 2 + assert region_list[0]["name"] == existing_region.name + assert region_list[0]["id"] == existing_region.id + assert region_list[1]["name"] == second_region["name"] + assert region_list[1]["id"] == second_region["id"] + + +def test_update_region(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test updating a region. + """ + existing_region = data_generator.random_existing_region(db) + print(existing_region.name) + + update_request = RegionUpdate(**jsonable_encoder(existing_region)) + update_request.name = f"New Region Name-{random_lower_string()}" + + response = client.put( + f"/regions/{existing_region.id}", + headers=normal_user_headers, + data=update_request.json(), + ) + assert response.status_code == status.HTTP_200_OK + + updated_region = response.json() + assert updated_region["name"] == update_request.name + + +def test_remove_region(db: Session, client: TestClient, normal_user_headers: Dict[str, str]): + """ + Test deleting a region. + """ + # Create a dataset with two commodities + first_region = data_generator.random_existing_region(db) + dataset_id = first_region.dataset.id + + create_request = data_generator.random_region_create(db) + create_request.ref_dataset = dataset_id + + response = client.post("/regions/", headers=normal_user_headers, data=create_request.json()) + assert response.status_code == status.HTTP_200_OK + second_region = response.json() + + # Delete the first region + response = client.delete(f"/regions/{first_region.id}", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + # Check that the dataset only has the second region + get_response = client.get( + "/regions/", + headers=normal_user_headers, + params={"dataset_id": dataset_id}, + ) + assert get_response.status_code == status.HTTP_200_OK + + region_list = get_response.json() + assert len(region_list) == 1 + assert region_list[0]["name"] == second_region["name"] + assert region_list[0]["id"] == second_region["id"] diff --git a/tests/api/test_ts_capacity_fix.py b/tests/api/test_ts_capacity_fix.py index 1ab803e..934160e 100644 --- a/tests/api/test_ts_capacity_fix.py +++ b/tests/api/test_ts_capacity_fix.py @@ -4,27 +4,14 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import CapacityFixCreate -from tests.utils import data_generator as data_gen -from tests.utils.utils import random_float_numbers - - -def get_random_fix_capacity_create(db: Session) -> CapacityFixCreate: - source = data_gen.fixed_existing_energy_sink(db) - region = data_gen.fixed_existing_region(db) - return CapacityFixCreate( - ref_dataset=region.ref_dataset, - component=source.component.name, - region=region.name, - fix_capacities=random_float_numbers(8760) - ) +from tests.utils import data_generator def test_create_fix_capacity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ Test creating a fix capacity time series. """ - create_request = get_random_fix_capacity_create(db) + create_request = data_generator.get_random_fix_capacity_create(db) response = client.post("/fix-capacities/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK diff --git a/tests/api/test_ts_capacity_max.py b/tests/api/test_ts_capacity_max.py index 71ac320..a7c5a29 100644 --- a/tests/api/test_ts_capacity_max.py +++ b/tests/api/test_ts_capacity_max.py @@ -4,27 +4,14 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import CapacityMaxCreate -from tests.utils import data_generator as data_gen -from tests.utils.utils import random_float_numbers - - -def get_random_max_capacity_create(db: Session) -> CapacityMaxCreate: - source = data_gen.fixed_existing_energy_sink(db) - region = data_gen.fixed_existing_region(db) - return CapacityMaxCreate( - ref_dataset=region.ref_dataset, - component=source.component.name, - region=region.name, - max_capacities=random_float_numbers(8760) - ) +from tests.utils import data_generator def test_create_max_capacity(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ Test creating a max capacity time series. """ - create_request = get_random_max_capacity_create(db) + create_request = data_generator.get_random_max_capacity_create(db) response = client.post("/max-capacities/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK diff --git a/tests/api/test_ts_operation_rate_fix.py b/tests/api/test_ts_operation_rate_fix.py index 9d6a21f..9202fc5 100644 --- a/tests/api/test_ts_operation_rate_fix.py +++ b/tests/api/test_ts_operation_rate_fix.py @@ -4,27 +4,14 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import OperationRateFixCreate -from tests.utils import data_generator as data_gen -from tests.utils.utils import random_float_numbers - - -def get_random_fix_operation_rate_create(db: Session) -> OperationRateFixCreate: - source = data_gen.fixed_existing_energy_sink(db) - region = data_gen.fixed_existing_region(db) - return OperationRateFixCreate( - ref_dataset=region.ref_dataset, - component=source.component.name, - region=region.name, - fix_operation_rates=random_float_numbers(8760) - ) +from tests.utils import data_generator def test_create_fix_operation_rate(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ Test creating a fix operation rate time series. """ - create_request = get_random_fix_operation_rate_create(db) + create_request = data_generator.get_random_fix_operation_rate_create(db) response = client.post("/fix-operation-rates/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK diff --git a/tests/api/test_ts_operation_rate_max.py b/tests/api/test_ts_operation_rate_max.py index 2f4fdec..08b287e 100644 --- a/tests/api/test_ts_operation_rate_max.py +++ b/tests/api/test_ts_operation_rate_max.py @@ -4,27 +4,14 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from ensysmod.schemas import OperationRateMaxCreate -from tests.utils import data_generator as data_gen -from tests.utils.utils import random_float_numbers - - -def get_random_max_operation_rate_create(db: Session) -> OperationRateMaxCreate: - source = data_gen.fixed_existing_energy_sink(db) - region = data_gen.fixed_existing_region(db) - return OperationRateMaxCreate( - ref_dataset=region.ref_dataset, - component=source.component.name, - region=region.name, - max_operation_rates=random_float_numbers(8760) - ) +from tests.utils import data_generator def test_create_max_operation_rate(client: TestClient, normal_user_headers: Dict[str, str], db: Session): """ Test creating a max operation rate time series. """ - create_request = get_random_max_operation_rate_create(db) + create_request = data_generator.get_random_max_operation_rate_create(db) response = client.post("/max-operation-rates/", headers=normal_user_headers, data=create_request.json()) assert response.status_code == status.HTTP_200_OK diff --git a/tests/api/test_users.py b/tests/api/test_users.py new file mode 100644 index 0000000..02d8a72 --- /dev/null +++ b/tests/api/test_users.py @@ -0,0 +1,26 @@ +from typing import Dict + +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.utils.utils import clear_database, create_random_user + + +def test_get_all_users(client: TestClient, normal_user_headers: Dict[str, str], db: Session): + """ + Test retrieving all users. + """ + clear_database(db) + user1 = create_random_user(db) + user2 = create_random_user(db) + + response = client.get("/users/", headers=normal_user_headers) + assert response.status_code == status.HTTP_200_OK + + users_list = response.json() + assert len(users_list) == 2 + assert users_list[0]["username"] == user1.username + assert users_list[0]["id"] == user1.id + assert users_list[1]["username"] == user2.username + assert users_list[1]["id"] == user2.id diff --git a/tests/api/test_zip_download.py b/tests/api/test_zip_download.py deleted file mode 100644 index 744f73d..0000000 --- a/tests/api/test_zip_download.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Dict - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session -from tests.api.test_zip_upload import get_dataset_zip -from tests.utils import data_generator - - -@pytest.mark.parametrize("data_folder", ["1node_Example", "Multi-regional_Example"]) -def test_download_dataset(client: TestClient, db: Session, normal_user_headers: Dict[str, str], data_folder: str): - """ - Test downloading a dataset. - """ - # Create a dataset - dataset = data_generator.random_existing_dataset(db) - - # Upload a zip file - zip_file_path = get_dataset_zip(data_folder) - - response = client.post( - f"/datasets/{dataset.id}/upload", - headers=normal_user_headers, - files={"file": ("dataset.zip", open(zip_file_path, "rb"), "application/zip")}, - ) - assert response.status_code == 200 - - response = client.get(f"/datasets/{dataset.id}/download", headers=normal_user_headers) - assert response.status_code == 200 - assert response.headers["Content-Type"] == "application/zip" diff --git a/tests/api/test_zip_upload.py b/tests/api/test_zip_upload.py deleted file mode 100644 index c029969..0000000 --- a/tests/api/test_zip_upload.py +++ /dev/null @@ -1,57 +0,0 @@ -import os -import tempfile -import zipfile -from typing import Dict -from zipfile import ZipFile - -import pytest -from fastapi.testclient import TestClient -from sqlalchemy.orm import Session -from tests.utils import data_generator - - -def get_dataset_zip(folder_name: str) -> str: - """ - Creates a zip archive from folder structure ../../examples/datasets/ - """ - # Create a temporary directory - temp_dir = tempfile.mkdtemp() - # create a zip file from the directory - zip_file_path = os.path.join(temp_dir, f"{folder_name}.zip") - with zipfile.ZipFile(zip_file_path, 'w') as zip_file: - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) - print(f"Project root: {project_root}") - for root, dirs, files in os.walk(f"{project_root}/examples/datasets/{folder_name}/"): - acr_path = os.path.relpath(root, f"{project_root}/examples/datasets/{folder_name}/") - zip_file.write(root, acr_path) - for file in files: - zip_file.write(os.path.join(root, file), arcname=os.path.join(acr_path, file)) - return zip_file_path - - -@pytest.mark.parametrize("data_folder", ["1node_Example", "Multi-regional_Example"]) -def test_upload_dataset(client: TestClient, db: Session, normal_user_headers: Dict[str, str], data_folder: str): - """ - Test uploading a dataset. - """ - # Create a dataset - dataset = data_generator.random_existing_dataset(db) - - # Upload a zip file - zip_file_path = get_dataset_zip(data_folder) - - # print all the contents of the zip file - print(f"Zip file contents of {zip_file_path}:") - with ZipFile(zip_file_path, 'r') as zip_file: - for file in zip_file.namelist(): - print(file) - - response = client.post( - f"/datasets/{dataset.id}/upload", - headers=normal_user_headers, - files={"file": ("dataset.zip", open(zip_file_path, "rb"), "application/zip")}, - ) - print(response.text) - assert response.status_code == 200 - - # TODO Check that the dataset has been updated diff --git a/tests/utils/data_generator/__init__.py b/tests/utils/data_generator/__init__.py index e5280e4..31e0352 100644 --- a/tests/utils/data_generator/__init__.py +++ b/tests/utils/data_generator/__init__.py @@ -1,14 +1,55 @@ -from .datasets import random_existing_dataset, random_dataset_create, fixed_existing_dataset -from .energy_commodities import random_existing_energy_commodity, fixed_existing_energy_commodity, \ - random_energy_commodity_create -from .energy_conversions import random_existing_energy_conversion, fixed_existing_energy_conversion, \ - random_energy_conversion_create -from .energy_sinks import random_existing_energy_sink, fixed_existing_energy_sink, random_energy_sink_create -from .energy_sources import random_existing_energy_source, fixed_existing_energy_source, random_energy_source_create -from .energy_storages import random_existing_energy_storage, fixed_existing_energy_storage, \ - random_energy_storage_create -from .energy_transmissions import random_existing_energy_transmission, fixed_existing_energy_transmission, \ - random_energy_transmission_create -from .regions import random_existing_region, fixed_existing_region, random_region_create, \ - fixed_alternative_existing_region -from .energy_models import random_energy_model_create, random_existing_energy_model +from .datasets import ( + create_example_dataset, + fixed_existing_dataset, + get_dataset_zip, + random_dataset_create, + random_existing_dataset, +) +from .energy_commodities import ( + fixed_existing_energy_commodity, + random_energy_commodity_create, + random_existing_energy_commodity, +) +from .energy_conversions import ( + fixed_existing_energy_conversion, + random_energy_conversion_create, + random_existing_energy_conversion, +) +from .energy_models import ( + create_example_model, + fixed_existing_energy_model, + random_energy_model_create, + random_existing_energy_model, +) +from .energy_sinks import ( + fixed_existing_energy_sink, + random_energy_sink_create, + random_existing_energy_sink, +) +from .energy_sources import ( + fixed_existing_energy_source, + random_energy_source_create, + random_existing_energy_source, +) +from .energy_storages import ( + fixed_existing_energy_storage, + random_energy_storage_create, + random_existing_energy_storage, +) +from .energy_transmissions import ( + fixed_existing_energy_transmission, + random_energy_transmission_create, + random_existing_energy_transmission, +) +from .regions import ( + fixed_alternative_existing_region, + fixed_existing_region, + random_existing_region, + random_region_create, +) +from .ts import ( + get_random_fix_capacity_create, + get_random_fix_operation_rate_create, + get_random_max_capacity_create, + get_random_max_operation_rate_create, +) diff --git a/tests/utils/data_generator/datasets.py b/tests/utils/data_generator/datasets.py index 52827de..32b3792 100644 --- a/tests/utils/data_generator/datasets.py +++ b/tests/utils/data_generator/datasets.py @@ -1,9 +1,14 @@ +import os +import tempfile +from zipfile import ZipFile + from sqlalchemy.orm import Session from ensysmod import crud +from ensysmod.core.file_upload import process_dataset_zip_archive from ensysmod.model import Dataset -from ensysmod.schemas import DatasetCreate -from tests.utils.utils import random_lower_string, create_random_user +from ensysmod.schemas import DatasetCreate, FileStatus +from tests.utils.utils import create_random_user, random_lower_string def random_dataset_create() -> DatasetCreate: @@ -50,3 +55,50 @@ def fixed_existing_dataset(db: Session) -> Dataset: if dataset is None: dataset = crud.dataset.create(db=db, obj_in=create_request) return dataset + + +def get_dataset_zip(folder_name: str) -> str: + """ + Create a zip archive from folder structure /examples/datasets/ + """ + # Create a temporary directory + temp_dir = tempfile.mkdtemp() + # create a zip file from the directory + zip_file_path = os.path.join(temp_dir, f"{folder_name}.zip") + with ZipFile(zip_file_path, 'w') as zip_file: + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) + print(f"Project root: {project_root}") + for root, dirs, files in os.walk(f"{project_root}/examples/datasets/{folder_name}/"): + acr_path = os.path.relpath(root, f"{project_root}/examples/datasets/{folder_name}/") + zip_file.write(root, acr_path) + for file in files: + zip_file.write(os.path.join(root, file), arcname=os.path.join(acr_path, file)) + return zip_file_path + + +def create_example_dataset(db: Session, data_folder: str): + """ + Create dataset from examples folder. + """ + # Create dataset + dataset_name = f"DS-{data_folder}" + random_lower_string() + dataset_description = "DS desc " + random_lower_string() + create_request = DatasetCreate( + name=dataset_name, + description=dataset_description, + hours_per_time_step=1, + number_of_time_steps=8760, + cost_unit='1e9 Euro', + length_unit='km', + ref_created_by=1 + ) + dataset = crud.dataset.create(db=db, obj_in=create_request) + + # Zip and upload the example dataset from data_folder + zip_file_path = get_dataset_zip(data_folder) + + with ZipFile(zip_file_path, 'r') as zip_archive: + result = process_dataset_zip_archive(zip_archive, dataset.id, db) + assert result.status == FileStatus.OK + + return dataset diff --git a/tests/utils/data_generator/energy_commodities.py b/tests/utils/data_generator/energy_commodities.py index 754cd48..ecf9b23 100644 --- a/tests/utils/data_generator/energy_commodities.py +++ b/tests/utils/data_generator/energy_commodities.py @@ -3,7 +3,10 @@ from ensysmod import crud from ensysmod.model import EnergyCommodity from ensysmod.schemas import EnergyCommodityCreate -from tests.utils.data_generator.datasets import fixed_existing_dataset, random_existing_dataset +from tests.utils.data_generator.datasets import ( + fixed_existing_dataset, + random_existing_dataset, +) from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/energy_conversions.py b/tests/utils/data_generator/energy_conversions.py index 20cd34d..75aac7c 100644 --- a/tests/utils/data_generator/energy_conversions.py +++ b/tests/utils/data_generator/energy_conversions.py @@ -4,7 +4,9 @@ from ensysmod.model import EnergyConversion from ensysmod.schemas import EnergyConversionCreate, EnergyConversionFactorCreate from tests.utils.data_generator.datasets import fixed_existing_dataset -from tests.utils.data_generator.energy_commodities import fixed_existing_energy_commodity +from tests.utils.data_generator.energy_commodities import ( + fixed_existing_energy_commodity, +) from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/energy_models.py b/tests/utils/data_generator/energy_models.py index fe6931a..70d4e24 100644 --- a/tests/utils/data_generator/energy_models.py +++ b/tests/utils/data_generator/energy_models.py @@ -1,7 +1,4 @@ from sqlalchemy.orm import Session -from tests.utils.data_generator import fixed_existing_energy_source -from tests.utils.data_generator.datasets import fixed_existing_dataset -from tests.utils.utils import random_lower_string from ensysmod import crud from ensysmod.model import EnergyModel @@ -10,6 +7,12 @@ EnergyModelOptimizationCreate, EnergyModelOverrideCreate, ) +from tests.utils.data_generator.datasets import ( + create_example_dataset, + fixed_existing_dataset, +) +from tests.utils.data_generator.energy_sources import fixed_existing_energy_source +from tests.utils.utils import random_lower_string def random_energy_model_create(db: Session) -> EnergyModelCreate: @@ -74,3 +77,21 @@ def fixed_existing_energy_model(db: Session) -> EnergyModel: if model is None: return crud.energy_model.create(db=db, obj_in=create_request) return model + + +def create_example_model(db: Session, data_folder: str): + """ + Create model from example dataset. + """ + dataset = create_example_dataset(db, data_folder) + + create_request = EnergyModelCreate( + name=f"Example_Model-{data_folder}-" + random_lower_string(), + ref_dataset=dataset.id, + description="Example_Model description", + override_parameters=None, + optimization_parameters=None + ) + model = crud.energy_model.create(db=db, obj_in=create_request) + + return model diff --git a/tests/utils/data_generator/energy_sinks.py b/tests/utils/data_generator/energy_sinks.py index 07e3dd5..7621547 100644 --- a/tests/utils/data_generator/energy_sinks.py +++ b/tests/utils/data_generator/energy_sinks.py @@ -1,14 +1,14 @@ from sqlalchemy.orm import Session + +from ensysmod import crud +from ensysmod.model import EnergySink +from ensysmod.schemas import EnergySinkCreate from tests.utils.data_generator.datasets import fixed_existing_dataset from tests.utils.data_generator.energy_commodities import ( fixed_existing_energy_commodity, ) from tests.utils.utils import random_lower_string -from ensysmod import crud -from ensysmod.model import EnergySink -from ensysmod.schemas import EnergySinkCreate - def random_energy_sink_create(db: Session) -> EnergySinkCreate: dataset = fixed_existing_dataset(db) diff --git a/tests/utils/data_generator/energy_sources.py b/tests/utils/data_generator/energy_sources.py index 1f0e734..9e884fc 100644 --- a/tests/utils/data_generator/energy_sources.py +++ b/tests/utils/data_generator/energy_sources.py @@ -3,7 +3,10 @@ from ensysmod import crud from ensysmod.model import EnergySource from ensysmod.schemas import EnergySourceCreate -from tests.utils.data_generator import fixed_existing_dataset, fixed_existing_energy_commodity +from tests.utils.data_generator import ( + fixed_existing_dataset, + fixed_existing_energy_commodity, +) from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/energy_storages.py b/tests/utils/data_generator/energy_storages.py index 28a1c8e..f50629c 100644 --- a/tests/utils/data_generator/energy_storages.py +++ b/tests/utils/data_generator/energy_storages.py @@ -3,7 +3,10 @@ from ensysmod import crud from ensysmod.model import EnergyStorage from ensysmod.schemas import EnergyStorageCreate -from tests.utils.data_generator import fixed_existing_dataset, fixed_existing_energy_commodity +from tests.utils.data_generator import ( + fixed_existing_dataset, + fixed_existing_energy_commodity, +) from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/energy_transmissions.py b/tests/utils/data_generator/energy_transmissions.py index e85599a..8073fef 100644 --- a/tests/utils/data_generator/energy_transmissions.py +++ b/tests/utils/data_generator/energy_transmissions.py @@ -3,10 +3,18 @@ from ensysmod import crud from ensysmod.model import EnergyTransmission from ensysmod.schemas import EnergyTransmissionCreate -from ensysmod.schemas.energy_transmission_distance import EnergyTransmissionDistanceCreate -from tests.utils.data_generator import fixed_existing_dataset, fixed_existing_energy_commodity -from tests.utils.data_generator.regions import fixed_existing_region, fixed_alternative_existing_region, \ - fixed_alternative_alternative_existing_region +from ensysmod.schemas.energy_transmission_distance import ( + EnergyTransmissionDistanceCreate, +) +from tests.utils.data_generator import ( + fixed_existing_dataset, + fixed_existing_energy_commodity, +) +from tests.utils.data_generator.regions import ( + fixed_alternative_alternative_existing_region, + fixed_alternative_existing_region, + fixed_existing_region, +) from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/regions.py b/tests/utils/data_generator/regions.py index 78c6313..17c72d6 100644 --- a/tests/utils/data_generator/regions.py +++ b/tests/utils/data_generator/regions.py @@ -3,7 +3,7 @@ from ensysmod import crud from ensysmod.model import Region from ensysmod.schemas import RegionCreate -from tests.utils.data_generator import random_existing_dataset, fixed_existing_dataset +from tests.utils.data_generator import fixed_existing_dataset, random_existing_dataset from tests.utils.utils import random_lower_string diff --git a/tests/utils/data_generator/ts.py b/tests/utils/data_generator/ts.py new file mode 100644 index 0000000..b442f43 --- /dev/null +++ b/tests/utils/data_generator/ts.py @@ -0,0 +1,54 @@ +from sqlalchemy.orm import Session + +from ensysmod.schemas import ( + CapacityFixCreate, + CapacityMaxCreate, + OperationRateFixCreate, + OperationRateMaxCreate, +) +from tests.utils import data_generator +from tests.utils.utils import random_float_numbers + + +def get_random_fix_capacity_create(db: Session) -> CapacityFixCreate: + source = data_generator.fixed_existing_energy_sink(db) + region = data_generator.fixed_existing_region(db) + return CapacityFixCreate( + ref_dataset=region.ref_dataset, + component=source.component.name, + region=region.name, + fix_capacities=random_float_numbers(8760) + ) + + +def get_random_max_capacity_create(db: Session) -> CapacityMaxCreate: + source = data_generator.fixed_existing_energy_sink(db) + region = data_generator.fixed_existing_region(db) + return CapacityMaxCreate( + ref_dataset=region.ref_dataset, + component=source.component.name, + region=region.name, + max_capacities=random_float_numbers(8760) + ) + + +def get_random_fix_operation_rate_create(db: Session) -> OperationRateFixCreate: + source = data_generator.fixed_existing_energy_sink(db) + region = data_generator.fixed_existing_region(db) + return OperationRateFixCreate( + ref_dataset=region.ref_dataset, + component=source.component.name, + region=region.name, + fix_operation_rates=random_float_numbers(8760) + ) + + +def get_random_max_operation_rate_create(db: Session) -> OperationRateMaxCreate: + source = data_generator.fixed_existing_energy_sink(db) + region = data_generator.fixed_existing_region(db) + return OperationRateMaxCreate( + ref_dataset=region.ref_dataset, + component=source.component.name, + region=region.name, + max_operation_rates=random_float_numbers(8760) + ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 2e3759d..1d53d8e 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -3,10 +3,32 @@ from typing import Dict, List from fastapi.testclient import TestClient +from sqlalchemy import delete from sqlalchemy.orm import Session from ensysmod import crud -from ensysmod.model import User +from ensysmod.model import ( + CapacityFix, + CapacityMax, + Dataset, + DatasetPermission, + EnergyCommodity, + EnergyComponent, + EnergyConversion, + EnergyConversionFactor, + EnergyModel, + EnergyModelOptimization, + EnergyModelOverride, + EnergySink, + EnergySource, + EnergyStorage, + EnergyTransmission, + EnergyTransmissionDistance, + OperationRateFix, + OperationRateMax, + Region, + User, +) from ensysmod.schemas import UserCreate, UserUpdate @@ -55,3 +77,33 @@ def authentication_token_from_username( crud.user.update(db, db_obj=user, obj_in=user_in_update) return user_authentication_headers(client=client, username=username, password=password) + + +def clear_database(db: Session): + """ + Clear entries in the database but keep the database structure intact. + """ + tables = [ + CapacityFix, + CapacityMax, + Dataset, + DatasetPermission, + EnergyCommodity, + EnergyComponent, + EnergyConversion, + EnergyConversionFactor, + EnergyModel, + EnergyModelOptimization, + EnergyModelOverride, + EnergySink, + EnergySource, + EnergyStorage, + EnergyTransmission, + EnergyTransmissionDistance, + OperationRateFix, + OperationRateMax, + Region, + User, + ] + for table in tables: + db.execute(delete(table)) diff --git a/tests/validators/test_optimization_parameters_validator.py b/tests/validators/test_optimization_parameters_validator.py index 28efe28..79c824c 100644 --- a/tests/validators/test_optimization_parameters_validator.py +++ b/tests/validators/test_optimization_parameters_validator.py @@ -147,7 +147,9 @@ def test_error_invalid_timeframe_parameter(schema: Type[BaseModel], data: Dict[s assert len(exc_info.value.errors()) == 1 assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "The parameters must satisfy the equation: (end_year - start_year) = number_of_steps * years_per_step." + assert ( + exc_info.value.errors()[0]["msg"] == "The parameters must satisfy the equation: (end_year - start_year) = number_of_steps * years_per_step." + ) assert exc_info.value.errors()[0]["type"] == "value_error" @@ -217,5 +219,8 @@ def test_error_invalid_CO2_reduction_target_length(schema: Type[BaseModel], data assert len(exc_info.value.errors()) == 1 assert exc_info.value.errors()[0]["loc"] == ("__root__",) - assert exc_info.value.errors()[0]["msg"] == "The number of values given in CO2_reduction_targets must match the number of optimization runs. Expected: 4, given: 3." + assert ( + exc_info.value.errors()[0]["msg"] + == "The number of values given in CO2_reduction_targets must match the number of optimization runs. Expected: 4, given: 3." + ) assert exc_info.value.errors()[0]["type"] == "value_error"