Skip to content

Commit

Permalink
test(export): correct unit tests for `VariantStudyService.generate_ta…
Browse files Browse the repository at this point in the history
…sk` method
  • Loading branch information
skamril committed Aug 22, 2023
1 parent 518921e commit cb81235
Show file tree
Hide file tree
Showing 28 changed files with 1,452 additions and 375 deletions.
18 changes: 10 additions & 8 deletions antarest/study/storage/abstract_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,23 @@ def export_study(
self, metadata: T, target: Path, outputs: bool = True
) -> Path:
"""
Export and compresses study inside zip
Args:
metadata: study
target: path of the file to export to
outputs: ask to integrated output folder inside exportation
Export and compress the study inside a ZIP file.
Returns: zip file with study files compressed inside
Args:
metadata: Study metadata object.
target: Path of the file to export to.
outputs: Flag to indicate whether to include the output folder inside the exportation.
Returns:
The ZIP file containing the study files compressed inside.
"""
path_study = Path(metadata.path)
with tempfile.TemporaryDirectory(
dir=self.config.storage.tmp_dir
) as tmpdir:
logger.info(f"Exporting study {metadata.id} to tmp path {tmpdir}")
assert_this(target.name.endswith(".zip"))
logger.info(
f"Exporting study {metadata.id} to temporary path {tmpdir}"
)
tmp_study_path = Path(tmpdir) / "tmp_copy"
self.export_study_flat(metadata, tmp_study_path, outputs)
stopwatch = StopWatch()
Expand Down
41 changes: 38 additions & 3 deletions antarest/study/storage/storage_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""
This module provides the ``StudyStorageService`` class, which acts as a dispatcher for study storage services.
It determines the appropriate study storage service based on the type of study provided.
"""

from typing import Union

from antarest.core.exceptions import StudyTypeUnsupported
from antarest.study.common.studystorage import IStudyStorageService
from antarest.study.model import Study, RawStudy
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.raw_study_service import RawStudyService
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from antarest.study.storage.variantstudy.variant_study_service import (
Expand All @@ -11,17 +16,47 @@


class StudyStorageService:
"""
A class that acts as a dispatcher for study storage services.
This class determines the appropriate study storage service based on the type of study provided.
It delegates the study storage operations to the corresponding service.
Attributes:
raw_study_service: The service for managing raw studies.
variant_study_service: The service for managing variant studies.
"""

def __init__(
self,
raw_study_service: RawStudyService,
variante_study_service: VariantStudyService,
variant_study_service: VariantStudyService,
):
"""
Initialize the ``StudyStorageService`` with raw and variant study services.
Args:
raw_study_service: The service for managing raw studies.
variant_study_service: The service for managing variant studies.
"""
self.raw_study_service = raw_study_service
self.variant_study_service = variante_study_service
self.variant_study_service = variant_study_service

def get_storage(
self, study: Study
) -> IStudyStorageService[Union[RawStudy, VariantStudy]]:
"""
Get the appropriate study storage service based on the type of study.
Args:
study: The study object for which the storage service is required.
Returns:
The study storage service associated with the study type.
Raises:
StudyTypeUnsupported: If the study type is not supported by the available storage services.
"""
if isinstance(study, RawStudy):
return self.raw_study_service
elif isinstance(study, VariantStudy):
Expand Down
41 changes: 22 additions & 19 deletions antarest/study/storage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,26 +390,29 @@ def export_study_flat(

shutil.copytree(src=path_study, dst=dest, ignore=ignore_patterns)

if outputs and output_src_path.is_dir():
if output_dest_path.is_dir():
shutil.rmtree(output_dest_path)
if output_list_filter is not None:
os.mkdir(output_dest_path)
for output in output_list_filter:
zip_path = output_src_path / f"{output}.zip"
if zip_path.exists():
with ZipFile(zip_path) as zf:
zf.extractall(output_dest_path / output)
else:
shutil.copytree(
src=output_src_path / output,
dst=output_dest_path / output,
)
else:
shutil.copytree(
src=output_src_path,
dst=output_dest_path,
if outputs and output_src_path.exists():
if output_list_filter is None:
# Retrieve all directories or ZIP files without duplicates
output_list_filter = list(
{
f.with_suffix("").name
for f in output_src_path.iterdir()
if f.is_dir() or f.suffix == ".zip"
}
)
# Copy each folder or uncompress each ZIP file to the destination dir.
shutil.rmtree(output_dest_path, ignore_errors=True)
output_dest_path.mkdir()
for output in output_list_filter:
zip_path = output_src_path / f"{output}.zip"
if zip_path.exists():
with ZipFile(zip_path) as zf:
zf.extractall(output_dest_path / output)
else:
shutil.copytree(
src=output_src_path / output,
dst=output_dest_path / output,
)

stop_time = time.time()
duration = "{:.3f}".format(stop_time - start_time)
Expand Down
102 changes: 7 additions & 95 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,19 @@
import time
from datetime import datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, cast

import numpy as np
import numpy.typing as npt
import pytest
from antarest.core.model import SUB_JSON
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db
from antarest.dbmodel import Base
from sqlalchemy import create_engine # type: ignore

# noinspection PyUnresolvedReferences
from tests.conftest_db import *

# noinspection PyUnresolvedReferences
from tests.conftest_services import *

# fmt: off
HERE = Path(__file__).parent.resolve()
PROJECT_DIR = next(iter(p for p in HERE.parents if p.joinpath("antarest").exists()))
# fmt: on


@pytest.fixture
@pytest.fixture(scope="session")
def project_path() -> Path:
return PROJECT_DIR


def with_db_context(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
engine = create_engine("sqlite:///:memory:", echo=False)
Base.metadata.create_all(engine)
# noinspection SpellCheckingInspection
DBSessionMiddleware(
None,
custom_engine=engine,
session_args={"autocommit": False, "autoflush": False},
)
with db():
return f(*args, **kwargs)

return wrapper


def _assert_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None:
if a.keys() != b.keys():
raise AssertionError(
f"study level has not the same keys {a.keys()} != {b.keys()}"
)
for k, v in a.items():
assert_study(v, b[k])


def _assert_list(a: List[Any], b: List[Any]) -> None:
for i, j in zip(a, b):
assert_study(i, j)


def _assert_pointer_path(a: str, b: str) -> None:
# pointer is like studyfile://study-id/a/b/c
# we should compare a/b/c only
if a.split("/")[4:] != b.split("/")[4:]:
raise AssertionError(f"element in study not the same {a} != {b}")


def _assert_others(a: Any, b: Any) -> None:
if a != b:
raise AssertionError(f"element in study not the same {a} != {b}")


def _assert_array(
a: npt.NDArray[np.float64],
b: npt.NDArray[np.float64],
) -> None:
if not (a == b).all():
raise AssertionError(f"element in study not the same {a} != {b}")


def assert_study(a: SUB_JSON, b: SUB_JSON) -> None:
if isinstance(a, dict) and isinstance(b, dict):
_assert_dict(a, b)
elif isinstance(a, list) and isinstance(b, list):
_assert_list(a, b)
elif (
isinstance(a, str)
and isinstance(b, str)
and "studyfile://" in a
and "studyfile://" in b
):
_assert_pointer_path(a, b)
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
_assert_array(a, b)
elif isinstance(a, np.ndarray) and isinstance(b, list):
_assert_list(cast(List[float], a.tolist()), b)
elif isinstance(a, list) and isinstance(b, np.ndarray):
_assert_list(a, cast(List[float], b.tolist()))
else:
_assert_others(a, b)


def auto_retry_assert(
predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2
) -> None:
threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout)
while datetime.now(timezone.utc) < threshold:
if predicate():
return
time.sleep(delay)
raise AssertionError()
64 changes: 64 additions & 0 deletions tests/conftest_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import contextlib
from typing import Any, Generator

import pytest
from sqlalchemy import create_engine # type: ignore
from sqlalchemy.orm import sessionmaker

from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.dbmodel import Base

__all__ = ("db_engine_fixture", "db_session_fixture", "db_middleware_fixture")


@pytest.fixture(name="db_engine")
def db_engine_fixture() -> Generator[Any, None, None]:
"""
Fixture that creates an in-memory SQLite database engine for testing.
Yields:
An instance of the created SQLite database engine.
"""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
engine.dispose()


@pytest.fixture(name="db_session")
def db_session_fixture(db_engine) -> Generator:
"""
Fixture that creates a database session for testing purposes.
This fixture uses the provided db engine fixture to create a session maker,
which in turn generates a new database session bound to the specified engine.
Args:
db_engine: The database engine instance provided by the db_engine fixture.
Yields:
A new SQLAlchemy session object for database operations.
"""
make_session = sessionmaker(bind=db_engine)
with contextlib.closing(make_session()) as session:
yield session


@pytest.fixture(name="db_middleware", autouse=True)
def db_middleware_fixture(
db_engine: Any,
) -> Generator[DBSessionMiddleware, None, None]:
"""
Fixture that sets up a database session middleware with custom engine settings.
Args:
db_engine: The database engine instance created by the db_engine fixture.
Yields:
An instance of the configured DBSessionMiddleware.
"""
yield DBSessionMiddleware(
None,
custom_engine=db_engine,
session_args={"autocommit": False, "autoflush": False},
)
Loading

0 comments on commit cb81235

Please sign in to comment.