diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 404019a289..d3bbeedfd8 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -256,21 +256,23 @@ def export_study( self, metadata: T, target: Path, outputs: bool = True ) -> Path: """ - Export and compresses study inside zip - Args: - metadata: study - target: path of the file to export to - outputs: ask to integrated output folder inside exportation + Export and compress the study inside a ZIP file. - Returns: zip file with study files compressed inside + Args: + metadata: Study metadata object. + target: Path of the file to export to. + outputs: Flag to indicate whether to include the output folder inside the exportation. + Returns: + The ZIP file containing the study files compressed inside. """ path_study = Path(metadata.path) with tempfile.TemporaryDirectory( dir=self.config.storage.tmp_dir ) as tmpdir: - logger.info(f"Exporting study {metadata.id} to tmp path {tmpdir}") - assert_this(target.name.endswith(".zip")) + logger.info( + f"Exporting study {metadata.id} to temporary path {tmpdir}" + ) tmp_study_path = Path(tmpdir) / "tmp_copy" self.export_study_flat(metadata, tmp_study_path, outputs) stopwatch = StopWatch() diff --git a/antarest/study/storage/storage_service.py b/antarest/study/storage/storage_service.py index 7602e7f51f..0dfb33d565 100644 --- a/antarest/study/storage/storage_service.py +++ b/antarest/study/storage/storage_service.py @@ -1,8 +1,13 @@ +""" +This module provides the ``StudyStorageService`` class, which acts as a dispatcher for study storage services. +It determines the appropriate study storage service based on the type of study provided. +""" + from typing import Union from antarest.core.exceptions import StudyTypeUnsupported from antarest.study.common.studystorage import IStudyStorageService -from antarest.study.model import Study, RawStudy +from antarest.study.model import RawStudy, Study from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from antarest.study.storage.variantstudy.variant_study_service import ( @@ -11,17 +16,47 @@ class StudyStorageService: + """ + A class that acts as a dispatcher for study storage services. + + This class determines the appropriate study storage service based on the type of study provided. + It delegates the study storage operations to the corresponding service. + + Attributes: + raw_study_service: The service for managing raw studies. + variant_study_service: The service for managing variant studies. + """ + def __init__( self, raw_study_service: RawStudyService, - variante_study_service: VariantStudyService, + variant_study_service: VariantStudyService, ): + """ + Initialize the ``StudyStorageService`` with raw and variant study services. + + Args: + raw_study_service: The service for managing raw studies. + variant_study_service: The service for managing variant studies. + """ self.raw_study_service = raw_study_service - self.variant_study_service = variante_study_service + self.variant_study_service = variant_study_service def get_storage( self, study: Study ) -> IStudyStorageService[Union[RawStudy, VariantStudy]]: + """ + Get the appropriate study storage service based on the type of study. + + Args: + study: The study object for which the storage service is required. + + Returns: + The study storage service associated with the study type. + + Raises: + StudyTypeUnsupported: If the study type is not supported by the available storage services. + """ if isinstance(study, RawStudy): return self.raw_study_service elif isinstance(study, VariantStudy): diff --git a/antarest/study/storage/utils.py b/antarest/study/storage/utils.py index 3c9ee5d901..117bf38683 100644 --- a/antarest/study/storage/utils.py +++ b/antarest/study/storage/utils.py @@ -390,26 +390,29 @@ def export_study_flat( shutil.copytree(src=path_study, dst=dest, ignore=ignore_patterns) - if outputs and output_src_path.is_dir(): - if output_dest_path.is_dir(): - shutil.rmtree(output_dest_path) - if output_list_filter is not None: - os.mkdir(output_dest_path) - for output in output_list_filter: - zip_path = output_src_path / f"{output}.zip" - if zip_path.exists(): - with ZipFile(zip_path) as zf: - zf.extractall(output_dest_path / output) - else: - shutil.copytree( - src=output_src_path / output, - dst=output_dest_path / output, - ) - else: - shutil.copytree( - src=output_src_path, - dst=output_dest_path, + if outputs and output_src_path.exists(): + if output_list_filter is None: + # Retrieve all directories or ZIP files without duplicates + output_list_filter = list( + { + f.with_suffix("").name + for f in output_src_path.iterdir() + if f.is_dir() or f.suffix == ".zip" + } ) + # Copy each folder or uncompress each ZIP file to the destination dir. + shutil.rmtree(output_dest_path, ignore_errors=True) + output_dest_path.mkdir() + for output in output_list_filter: + zip_path = output_src_path / f"{output}.zip" + if zip_path.exists(): + with ZipFile(zip_path) as zf: + zf.extractall(output_dest_path / output) + else: + shutil.copytree( + src=output_src_path / output, + dst=output_dest_path / output, + ) stop_time = time.time() duration = "{:.3f}".format(stop_time - start_time) diff --git a/tests/conftest.py b/tests/conftest.py index 899942232e..20a93910c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,12 @@ -import time -from datetime import datetime, timedelta, timezone -from functools import wraps from pathlib import Path -from typing import Any, Callable, Dict, List, cast -import numpy as np -import numpy.typing as npt import pytest -from antarest.core.model import SUB_JSON -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db -from antarest.dbmodel import Base -from sqlalchemy import create_engine # type: ignore + +# noinspection PyUnresolvedReferences +from tests.conftest_db import * + +# noinspection PyUnresolvedReferences +from tests.conftest_services import * # fmt: off HERE = Path(__file__).parent.resolve() @@ -18,90 +14,6 @@ # fmt: on -@pytest.fixture +@pytest.fixture(scope="session") def project_path() -> Path: return PROJECT_DIR - - -def with_db_context(f: Callable[..., Any]) -> Callable[..., Any]: - @wraps(f) - def wrapper(*args: Any, **kwargs: Any) -> Any: - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - with db(): - return f(*args, **kwargs) - - return wrapper - - -def _assert_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: - if a.keys() != b.keys(): - raise AssertionError( - f"study level has not the same keys {a.keys()} != {b.keys()}" - ) - for k, v in a.items(): - assert_study(v, b[k]) - - -def _assert_list(a: List[Any], b: List[Any]) -> None: - for i, j in zip(a, b): - assert_study(i, j) - - -def _assert_pointer_path(a: str, b: str) -> None: - # pointer is like studyfile://study-id/a/b/c - # we should compare a/b/c only - if a.split("/")[4:] != b.split("/")[4:]: - raise AssertionError(f"element in study not the same {a} != {b}") - - -def _assert_others(a: Any, b: Any) -> None: - if a != b: - raise AssertionError(f"element in study not the same {a} != {b}") - - -def _assert_array( - a: npt.NDArray[np.float64], - b: npt.NDArray[np.float64], -) -> None: - if not (a == b).all(): - raise AssertionError(f"element in study not the same {a} != {b}") - - -def assert_study(a: SUB_JSON, b: SUB_JSON) -> None: - if isinstance(a, dict) and isinstance(b, dict): - _assert_dict(a, b) - elif isinstance(a, list) and isinstance(b, list): - _assert_list(a, b) - elif ( - isinstance(a, str) - and isinstance(b, str) - and "studyfile://" in a - and "studyfile://" in b - ): - _assert_pointer_path(a, b) - elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): - _assert_array(a, b) - elif isinstance(a, np.ndarray) and isinstance(b, list): - _assert_list(cast(List[float], a.tolist()), b) - elif isinstance(a, list) and isinstance(b, np.ndarray): - _assert_list(a, cast(List[float], b.tolist())) - else: - _assert_others(a, b) - - -def auto_retry_assert( - predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2 -) -> None: - threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout) - while datetime.now(timezone.utc) < threshold: - if predicate(): - return - time.sleep(delay) - raise AssertionError() diff --git a/tests/conftest_db.py b/tests/conftest_db.py new file mode 100644 index 0000000000..877ca119d1 --- /dev/null +++ b/tests/conftest_db.py @@ -0,0 +1,64 @@ +import contextlib +from typing import Any, Generator + +import pytest +from sqlalchemy import create_engine # type: ignore +from sqlalchemy.orm import sessionmaker + +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.dbmodel import Base + +__all__ = ("db_engine_fixture", "db_session_fixture", "db_middleware_fixture") + + +@pytest.fixture(name="db_engine") +def db_engine_fixture() -> Generator[Any, None, None]: + """ + Fixture that creates an in-memory SQLite database engine for testing. + + Yields: + An instance of the created SQLite database engine. + """ + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + yield engine + engine.dispose() + + +@pytest.fixture(name="db_session") +def db_session_fixture(db_engine) -> Generator: + """ + Fixture that creates a database session for testing purposes. + + This fixture uses the provided db engine fixture to create a session maker, + which in turn generates a new database session bound to the specified engine. + + Args: + db_engine: The database engine instance provided by the db_engine fixture. + + Yields: + A new SQLAlchemy session object for database operations. + """ + make_session = sessionmaker(bind=db_engine) + with contextlib.closing(make_session()) as session: + yield session + + +@pytest.fixture(name="db_middleware", autouse=True) +def db_middleware_fixture( + db_engine: Any, +) -> Generator[DBSessionMiddleware, None, None]: + """ + Fixture that sets up a database session middleware with custom engine settings. + + Args: + db_engine: The database engine instance created by the db_engine fixture. + + Yields: + An instance of the configured DBSessionMiddleware. + """ + yield DBSessionMiddleware( + None, + custom_engine=db_engine, + session_args={"autocommit": False, "autoflush": False}, + ) diff --git a/tests/conftest_services.py b/tests/conftest_services.py new file mode 100644 index 0000000000..277ff8ee9c --- /dev/null +++ b/tests/conftest_services.py @@ -0,0 +1,424 @@ +""" +This module provides various pytest fixtures for unit testing the AntaREST application. + +Fixtures in this module are used to set up and provide instances of different classes and services required during testing. +""" +import datetime +import uuid +from pathlib import Path +from typing import Dict, List, Optional, Union +from unittest.mock import Mock + +import pytest +from antarest.core.config import Config, StorageConfig, WorkspaceConfig +from antarest.core.interfaces.cache import ICache +from antarest.core.interfaces.eventbus import IEventBus +from antarest.core.requests import RequestParameters +from antarest.core.tasks.model import ( + CustomTaskEventMessages, + TaskDTO, + TaskListFilter, + TaskStatus, + TaskType, + TaskResult, +) +from antarest.core.tasks.service import ITaskService, Task +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.matrixstore.service import SimpleMatrixService +from antarest.matrixstore.uri_resolver_service import UriResolverService +from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.factory import ( + StudyFactory, +) +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.storage_service import StudyStorageService +from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( + GeneratorMatrixConstants, +) +from antarest.study.storage.variantstudy.command_factory import CommandFactory +from antarest.study.storage.variantstudy.repository import ( + VariantStudyRepository, +) +from antarest.study.storage.variantstudy.variant_study_service import ( + VariantStudyService, +) + +__all__ = ( + "bucket_dir_fixture", + "simple_matrix_service_fixture", + "generator_matrix_constants_fixture", + "uri_resolver_service_fixture", + "core_cache_fixture", + "study_factory_fixture", + "core_config_fixture", + "patch_service_fixture", + "task_service_fixture", + "event_bus_fixture", + "command_factory_fixture", + "variant_study_repository_fixture", + "raw_study_service_fixture", + "variant_study_service_fixture", + "study_storage_service_fixture", +) + + +class SynchTaskService(ITaskService): + def __init__(self) -> None: + self._task_result: Optional[TaskResult] = None + + def add_worker_task( + self, + task_type: TaskType, + task_queue: str, + task_args: Dict[str, Union[int, float, bool, str]], + name: Optional[str], + ref_id: Optional[str], + request_params: RequestParameters, + ) -> Optional[str]: + raise NotImplementedError() + + def add_task( + self, + action: Task, + name: Optional[str], + task_type: Optional[TaskType], + ref_id: Optional[str], + custom_event_messages: Optional[CustomTaskEventMessages], + request_params: RequestParameters, + ) -> str: + self._task_result = action(lambda message: None) + return str(uuid.uuid4()) + + def status_task( + self, + task_id: str, + request_params: RequestParameters, + with_logs: bool = False, + ) -> TaskDTO: + return TaskDTO( + id=task_id, + name="mock", + owner=None, + status=TaskStatus.COMPLETED, + creation_date_utc=datetime.datetime.now().isoformat(" "), + completion_date_utc=None, + result=self._task_result, + logs=None, + ) + + def list_tasks( + self, task_filter: TaskListFilter, request_params: RequestParameters + ) -> List[TaskDTO]: + return [] + + def await_task( + self, task_id: str, timeout_sec: Optional[int] = None + ) -> None: + pass + + +@pytest.fixture(name="bucket_dir", scope="session") +def bucket_dir_fixture(tmp_path_factory) -> Path: + """ + Fixture that creates a session-level temporary directory named "matrix_store" for storing matrices. + + This fixture is used with the "session" scope to share the same directory among all tests. + This sharing optimizes test execution speed and reduces disk space usage in the temporary directory. + It is safe to share the directory as matrices have unique identifiers. + + Args: + tmp_path_factory: A fixture provided by pytest to generate temporary directories. + + Returns: + A Path object representing the created temporary directory for storing matrices. + """ + return tmp_path_factory.mktemp("matrix_store", numbered=False) + + +@pytest.fixture(name="simple_matrix_service", scope="session") +def simple_matrix_service_fixture(bucket_dir: Path) -> SimpleMatrixService: + """ + Fixture that creates a SimpleMatrixService instance using the session-level temporary directory. + + Args: + bucket_dir: The session-level temporary directory for storing matrices. + + Returns: + An instance of the SimpleMatrixService class representing the matrix service. + """ + return SimpleMatrixService(bucket_dir) + + +@pytest.fixture(name="generator_matrix_constants", scope="session") +def generator_matrix_constants_fixture( + simple_matrix_service: SimpleMatrixService, +) -> GeneratorMatrixConstants: + """ + Fixture that creates a GeneratorMatrixConstants instance with a session-level scope. + + Args: + simple_matrix_service: An instance of the SimpleMatrixService class. + + Returns: + An instance of the GeneratorMatrixConstants class representing the matrix constants generator. + """ + return GeneratorMatrixConstants(matrix_service=simple_matrix_service) + + +@pytest.fixture(name="uri_resolver_service", scope="session") +def uri_resolver_service_fixture( + simple_matrix_service: SimpleMatrixService, +) -> UriResolverService: + """ + Fixture that creates an UriResolverService instance with a session-level scope. + + Args: + simple_matrix_service: An instance of the SimpleMatrixService class. + + Returns: + An instance of the UriResolverService class representing the URI resolver service. + """ + return UriResolverService(matrix_service=simple_matrix_service) + + +@pytest.fixture(name="core_cache", scope="session") +def core_cache_fixture() -> ICache: + """ + Fixture that creates an ICache instance with a session-level scope. + We use the Mock class to create a dummy cache object that doesn't store any data. + + Returns: + An instance of the Mock class representing the cache object. + """ + cache = Mock(spec=ICache) + cache.get.return_value = None # don't use cache + return cache + + +@pytest.fixture(name="study_factory", scope="session") +def study_factory_fixture( + simple_matrix_service: SimpleMatrixService, + uri_resolver_service: UriResolverService, + core_cache: ICache, +) -> StudyFactory: + """ + Fixture that creates a StudyFactory instance with a session-level scope. + + Args: + simple_matrix_service: An instance of the SimpleMatrixService class. + uri_resolver_service: An instance of the UriResolverService class. + core_cache: An instance of the ICache class. + + Returns: + An instance of the StudyFactory class representing the study factory used for all tests. + """ + return StudyFactory( + matrix=simple_matrix_service, + resolver=uri_resolver_service, + cache=core_cache, + ) + + +@pytest.fixture(name="core_config") +def core_config_fixture( + tmp_path: Path, + project_path: Path, + bucket_dir: Path, +) -> Config: + """ + Fixture that creates a Config instance for the core application configuration. + + Args: + tmp_path: A Path object representing the temporary directory provided by pytest. + project_path: A Path object representing the project's directory. + bucket_dir: A Path object representing the directory for storing matrices. + + Returns: + An instance of the Config class with the provided configuration settings. + """ + tmp_dir = tmp_path.joinpath("tmp") + tmp_dir.mkdir(exist_ok=True) + return Config( + storage=StorageConfig( + matrixstore=bucket_dir, + archive_dir=tmp_path.joinpath("archives"), + tmp_dir=tmp_dir, + workspaces={ + "default": WorkspaceConfig( + path=tmp_path.joinpath("internal_studies"), + ), + "studies": WorkspaceConfig( + path=tmp_path.joinpath("studies"), + ), + }, + ), + resources_path=project_path.joinpath("resources"), + root_path=str(tmp_path), + ) + + +@pytest.fixture(name="patch_service", scope="session") +def patch_service_fixture() -> PatchService: + """ + Fixture that creates a PatchService instance with a session-level scope. + + Returns: + An instance of the PatchService class with the default repository setting as None. + """ + return PatchService(repository=None) + + +@pytest.fixture(name="task_service", scope="session") +def task_service_fixture() -> ITaskService: + """ + Fixture that creates a Mock instance of ITaskService with a session-level scope. + + Returns: + A Mock instance of the ITaskService class for task-related testing. + """ + return SynchTaskService() + + +@pytest.fixture(name="event_bus", scope="session") +def event_bus_fixture() -> IEventBus: + """ + Fixture that creates a Mock instance of IEventBus with a session-level scope. + + Returns: + A Mock instance of the IEventBus class for event bus-related testing. + """ + return Mock(spec=IEventBus) + + +@pytest.fixture(name="command_factory", scope="session") +def command_factory_fixture( + generator_matrix_constants: GeneratorMatrixConstants, + simple_matrix_service: SimpleMatrixService, + patch_service: PatchService, +) -> CommandFactory: + """ + Fixture that creates a CommandFactory instance with a session-level scope. + + Args: + generator_matrix_constants: An instance of the GeneratorMatrixConstants class. + simple_matrix_service: An instance of the SimpleMatrixService class. + patch_service: An instance of the PatchService class. + + Returns: + An instance of the CommandFactory class with the provided dependencies. + """ + return CommandFactory( + generator_matrix_constants=generator_matrix_constants, + matrix_service=simple_matrix_service, + patch_service=patch_service, + ) + + +# noinspection PyUnusedLocal +@pytest.fixture(name="variant_study_repository") +def variant_study_repository_fixture( + core_cache: ICache, + db_middleware: DBSessionMiddleware, # required +) -> VariantStudyRepository: + """ + Fixture that creates a VariantStudyRepository instance. + + Args: + core_cache: An instance of the ICache class. + db_middleware: An instance of the DBSessionMiddleware class. + + Returns: + An instance of the VariantStudyRepository class with the provided cache service. + """ + return VariantStudyRepository(cache_service=core_cache) + + +@pytest.fixture(name="raw_study_service") +def raw_study_service_fixture( + core_config: Config, + study_factory: StudyFactory, + patch_service: PatchService, + core_cache: ICache, +) -> RawStudyService: + """ + Fixture that creates a RawStudyService instance. + + Args: + core_config: An instance of the Config class representing the core application configuration. + study_factory: An instance of the StudyFactory class. + patch_service: An instance of the PatchService class. + core_cache: An instance of the ICache class. + + Returns: + An instance of the RawStudyService class with the provided dependencies. + """ + return RawStudyService( + config=core_config, + study_factory=study_factory, + path_resources=core_config.resources_path, + patch_service=patch_service, + cache=core_cache, + ) + + +@pytest.fixture(name="variant_study_service") +def variant_study_service_fixture( + task_service: ITaskService, + core_cache: ICache, + raw_study_service: RawStudyService, + command_factory: CommandFactory, + study_factory: StudyFactory, + patch_service: PatchService, + variant_study_repository: VariantStudyRepository, + event_bus: IEventBus, + core_config: Config, +) -> VariantStudyService: + """ + Fixture that creates a VariantStudyService instance. + + Args: + task_service: An instance of the ITaskService class. + core_cache: An instance of the ICache class. + raw_study_service: An instance of the RawStudyService class. + command_factory: An instance of the CommandFactory class. + study_factory: An instance of the StudyFactory class. + patch_service: An instance of the PatchService class. + variant_study_repository: An instance of the VariantStudyRepository class. + event_bus: An instance of the IEventBus class. + core_config: An instance of the Config class representing the core application configuration. + + Returns: + An instance of the VariantStudyService class with the provided dependencies. + """ + return VariantStudyService( + task_service=task_service, + cache=core_cache, + raw_study_service=raw_study_service, + command_factory=command_factory, + study_factory=study_factory, + patch_service=patch_service, + repository=variant_study_repository, + event_bus=event_bus, + config=core_config, + ) + + +@pytest.fixture(name="study_storage_service") +def study_storage_service_fixture( + raw_study_service: RawStudyService, + variant_study_service: VariantStudyService, +) -> StudyStorageService: + """ + Fixture that creates a StudyStorageService instance for study storage-related testing. + + Args: + raw_study_service: The RawStudyService instance. + variant_study_service: The VariantStudyService instance. + + Returns: + An instance of the StudyStorageService class representing the study storage service. + """ + return StudyStorageService( + raw_study_service=raw_study_service, + variant_study_service=variant_study_service, + ) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 7375bff26c..3f342d5b7e 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -28,7 +28,7 @@ from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.service import EventBusService from antarest.worker.worker import AbstractWorker, WorkerTaskCommand -from tests.conftest import with_db_context +from tests.helpers import with_db_context def test_service() -> None: diff --git a/tests/eventbus/test_service.py b/tests/eventbus/test_service.py index 542daa5bd3..1846cad077 100644 --- a/tests/eventbus/test_service.py +++ b/tests/eventbus/test_service.py @@ -5,7 +5,7 @@ from antarest.core.interfaces.eventbus import Event, EventType from antarest.core.model import PermissionInfo, PublicMode from antarest.eventbus.main import build_eventbus -from tests.conftest import auto_retry_assert +from tests.helpers import auto_retry_assert def test_service_factory(): diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000000..974ac2a054 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,87 @@ +import time +from datetime import datetime, timedelta, timezone +from functools import wraps +from typing import Any, Callable, Dict, List, cast + +import numpy as np +from numpy import typing as npt + +from antarest.core.model import SUB_JSON +from antarest.core.utils.fastapi_sqlalchemy import db + + +def with_db_context(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def wrapper(*args: Any, **kwargs: Any) -> Any: + with db(): + return f(*args, **kwargs) + + return wrapper + + +def _assert_dict(a: Dict[str, Any], b: Dict[str, Any]) -> None: + if a.keys() != b.keys(): + raise AssertionError( + f"study level has not the same keys {a.keys()} != {b.keys()}" + ) + for k, v in a.items(): + assert_study(v, b[k]) + + +def _assert_list(a: List[Any], b: List[Any]) -> None: + for i, j in zip(a, b): + assert_study(i, j) + + +def _assert_pointer_path(a: str, b: str) -> None: + # pointer is like studyfile://study-id/a/b/c + # we should compare a/b/c only + if a.split("/")[4:] != b.split("/")[4:]: + raise AssertionError(f"element in study not the same {a} != {b}") + + +def _assert_others(a: Any, b: Any) -> None: + if a != b: + raise AssertionError(f"element in study not the same {a} != {b}") + + +def _assert_array( + a: npt.NDArray[np.float64], + b: npt.NDArray[np.float64], +) -> None: + # noinspection PyUnresolvedReferences + if not (a == b).all(): + raise AssertionError(f"element in study not the same {a} != {b}") + + +def assert_study(a: SUB_JSON, b: SUB_JSON) -> None: + if isinstance(a, dict) and isinstance(b, dict): + _assert_dict(a, b) + elif isinstance(a, list) and isinstance(b, list): + _assert_list(a, b) + elif ( + isinstance(a, str) + and isinstance(b, str) + and "studyfile://" in a + and "studyfile://" in b + ): + _assert_pointer_path(a, b) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + _assert_array(a, b) + elif isinstance(a, np.ndarray) and isinstance(b, list): + _assert_list(cast(List[float], a.tolist()), b) + elif isinstance(a, list) and isinstance(b, np.ndarray): + _assert_list(a, cast(List[float], b.tolist())) + else: + _assert_others(a, b) + + +def auto_retry_assert( + predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2 +) -> None: + threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout) + while datetime.now(timezone.utc) < threshold: + if predicate(): + return + time.sleep(delay) + raise AssertionError() diff --git a/tests/launcher/test_extension_adequacy_patch.py b/tests/launcher/test_extension_adequacy_patch.py index 41711ff098..54970088e7 100644 --- a/tests/launcher/test_extension_adequacy_patch.py +++ b/tests/launcher/test_extension_adequacy_patch.py @@ -7,7 +7,7 @@ from antarest.launcher.extensions.adequacy_patch.extension import ( AdequacyPatchExtension, ) -from tests.conftest import with_db_context +from tests.helpers import with_db_context @with_db_context diff --git a/tests/launcher/test_repository.py b/tests/launcher/test_repository.py index d103d2c3e8..a84b5099a0 100644 --- a/tests/launcher/test_repository.py +++ b/tests/launcher/test_repository.py @@ -6,12 +6,12 @@ from sqlalchemy import create_engine from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import db, DBSessionMiddleware -from antarest.launcher.model import JobResult, JobStatus, JobLog, JobLogType +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db +from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus from antarest.launcher.repository import JobResultRepository from antarest.study.model import RawStudy from antarest.study.repository import StudyMetadataRepository -from tests.conftest import with_db_context +from tests.helpers import with_db_context @pytest.mark.unit_test diff --git a/tests/matrixstore/conftest.py b/tests/matrixstore/conftest.py index 14e43a58b3..9567925b97 100644 --- a/tests/matrixstore/conftest.py +++ b/tests/matrixstore/conftest.py @@ -1,32 +1,13 @@ import unittest.mock import pytest -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware -from antarest.dbmodel import Base + from antarest.matrixstore.repository import ( MatrixContentRepository, MatrixDataSetRepository, MatrixRepository, ) from antarest.matrixstore.service import MatrixService -from sqlalchemy import create_engine - - -@pytest.fixture(name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(name="db_middleware", autouse=True) -def db_middleware_fixture(db_engine): - yield DBSessionMiddleware( - None, - custom_engine=db_engine, - session_args={"autocommit": False, "autoflush": False}, - ) @pytest.fixture(name="matrix_repo") diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index b462305662..3638959b64 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -5,6 +5,7 @@ from zipfile import ZipFile import pytest + from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db @@ -51,7 +52,7 @@ from antarest.study.storage.variantstudy.variant_study_service import ( VariantStudyService, ) -from tests.conftest import with_db_context +from tests.helpers import with_db_context from tests.storage.business.assets import ASSETS_DIR diff --git a/tests/storage/business/test_autoarchive_service.py b/tests/storage/business/test_autoarchive_service.py index 9cfc05ac2a..6187633a7b 100644 --- a/tests/storage/business/test_autoarchive_service.py +++ b/tests/storage/business/test_autoarchive_service.py @@ -6,10 +6,10 @@ from antarest.core.exceptions import TaskAlreadyRunning from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters -from antarest.study.model import RawStudy, DEFAULT_WORKSPACE_NAME +from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy -from tests.conftest import with_db_context +from tests.helpers import with_db_context @with_db_context diff --git a/tests/storage/business/test_export.py b/tests/storage/business/test_export.py index df4697a004..ab63e555b0 100644 --- a/tests/storage/business/test_export.py +++ b/tests/storage/business/test_export.py @@ -43,8 +43,8 @@ def test_export_file(tmp_path: Path, outputs: bool): (root / "test").mkdir() (root / "test/file.txt").write_text("Bonjour") (root / "file.txt").write_text("Hello, World") - (root / "output").mkdir() - (root / "output/file.txt").write_text("42") + (root / "output/results1").mkdir(parents=True) + (root / "output/results1/file.txt").write_text("42") export_path = tmp_path / "study.zip" @@ -61,12 +61,11 @@ def test_export_file(tmp_path: Path, outputs: bool): study_tree = Mock() study_factory.create_from_fs.return_value = study_tree - study_service.export_study(study, export_path, outputs) - zipf = ZipFile(export_path) - - assert "file.txt" in zipf.namelist() - assert "test/file.txt" in zipf.namelist() - assert ("output/file.txt" in zipf.namelist()) == outputs + study_service.export_study(study, export_path, outputs=outputs) + with ZipFile(export_path) as zipf: + assert "file.txt" in zipf.namelist() + assert "test/file.txt" in zipf.namelist() + assert ("output/results1/file.txt" in zipf.namelist()) == outputs @pytest.mark.unit_test @@ -78,8 +77,8 @@ def test_export_flat(tmp_path: Path): (root / "test/output").mkdir() (root / "test/output/file.txt").write_text("Test") (root / "file.txt").write_text("Hello, World") - (root / "output").mkdir() - (root / "output/file.txt").write_text("42") + (root / "output/result1").mkdir(parents=True) + (root / "output/result1/file.txt").write_text("42") root_without_output = tmp_path / "folder-without-output" root_without_output.mkdir() diff --git a/tests/storage/business/test_patch_service.py b/tests/storage/business/test_patch_service.py index 93820d23f8..6d7651a96e 100644 --- a/tests/storage/business/test_patch_service.py +++ b/tests/storage/business/test_patch_service.py @@ -6,6 +6,7 @@ from unittest.mock import Mock import pytest + from antarest.core.model import PublicMode from antarest.core.utils.fastapi_sqlalchemy import db from antarest.study.model import ( @@ -20,7 +21,7 @@ ) from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.patch_service import PatchService -from tests.conftest import with_db_context +from tests.helpers import with_db_context PATCH_CONTENT = """ { diff --git a/tests/storage/integration/test_STA_mini.py b/tests/storage/integration/test_STA_mini.py index e00b2efc7d..48da21a18a 100644 --- a/tests/storage/integration/test_STA_mini.py +++ b/tests/storage/integration/test_STA_mini.py @@ -7,6 +7,9 @@ from unittest.mock import Mock import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.model import JSON from antarest.core.requests import RequestParameters @@ -14,9 +17,7 @@ from antarest.matrixstore.service import MatrixService from antarest.study.main import build_study_service from antarest.study.service import StudyService -from fastapi import FastAPI -from starlette.testclient import TestClient -from tests.conftest import assert_study +from tests.helpers import assert_study from tests.storage.integration.data.de_details_hourly import de_details_hourly from tests.storage.integration.data.de_fr_values_hourly import ( de_fr_values_hourly, diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 20ba4ec310..c2a6c146ad 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -1,137 +1,109 @@ from datetime import datetime -from unittest.mock import Mock - -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore from antarest.core.cache.business.local_chache import LocalCache from antarest.core.interfaces.cache import CacheConstants -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db -from antarest.login.model import User, Group +from antarest.login.model import Group, User from antarest.study.common.utils import get_study_information from antarest.study.model import ( - Study, - RawStudy, DEFAULT_WORKSPACE_NAME, - StudyContentStatus, PublicMode, + RawStudy, + Study, + StudyContentStatus, ) from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy -from tests.conftest import with_db_context +from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore +from tests.helpers import with_db_context +@with_db_context def test_cyclelife(): - engine = create_engine("sqlite:///:memory:", echo=False) - user = User(id=0, name="admin") group = Group(id="my-group", name="group") - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, + repo = StudyMetadataRepository(LocalCache()) + a = Study( + name="a", + version="42", + author="John Smith", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + ) + b = RawStudy( + name="b", + version="43", + author="Morpheus", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + ) + c = RawStudy( + name="c", + version="43", + author="Trinity", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + missing=datetime.utcnow(), + ) + d = VariantStudy( + name="d", + version="43", + author="Mr. Anderson", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], ) - with db(): - repo = StudyMetadataRepository(LocalCache()) - a = Study( - name="a", - version="42", - author="John Smith", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - ) - b = RawStudy( - name="b", - version="43", - author="Morpheus", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - ) - c = RawStudy( - name="c", - version="43", - author="Trinity", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - missing=datetime.utcnow(), - ) - d = VariantStudy( - name="d", - version="43", - author="Mr. Anderson", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - ) - - a = repo.save(a) - b = repo.save(b) - repo.save(c) - repo.save(d) - assert b.id - c = repo.get(a.id) - assert a == c - - assert len(repo.get_all()) == 3 - assert len(repo.get_all_raw(show_missing=True)) == 2 - assert len(repo.get_all_raw(show_missing=False)) == 1 - - repo.delete(a.id) - assert repo.get(a.id) is None + a = repo.save(a) + b = repo.save(b) + repo.save(c) + repo.save(d) + assert b.id + c = repo.get(a.id) + assert a == c + assert len(repo.get_all()) == 3 + assert len(repo.get_all_raw(show_missing=True)) == 2 + assert len(repo.get_all_raw(show_missing=False)) == 1 + + repo.delete(a.id) + assert repo.get(a.id) is None -def test_study_inheritance(): - engine = create_engine("sqlite:///:memory:", echo=False) - sess = scoped_session( - sessionmaker(autocommit=False, autoflush=False, bind=engine) - ) +@with_db_context +def test_study_inheritance(): user = User(id=0, name="admin") group = Group(id="my-group", name="group") - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, + repo = StudyMetadataRepository(LocalCache()) + a = RawStudy( + name="a", + version="42", + author="John Smith", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + workspace=DEFAULT_WORKSPACE_NAME, + path="study", + content_status=StudyContentStatus.WARNING, ) - with db(): - repo = StudyMetadataRepository(LocalCache()) - a = RawStudy( - name="a", - version="42", - author="John Smith", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - workspace=DEFAULT_WORKSPACE_NAME, - path="study", - content_status=StudyContentStatus.WARNING, - ) - - repo.save(a) - b = repo.get(a.id) - - assert isinstance(b, RawStudy) - assert b.path == "study" + repo.save(a) + b = repo.get(a.id) + + assert isinstance(b, RawStudy) + assert b.path == "study" @with_db_context @@ -141,28 +113,27 @@ def test_cache(): cache = LocalCache() - with db(): - repo = StudyMetadataRepository(cache) - a = RawStudy( - name="a", - version="42", - author="John Smith", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - public_mode=PublicMode.FULL, - owner=user, - groups=[group], - workspace=DEFAULT_WORKSPACE_NAME, - path="study", - content_status=StudyContentStatus.WARNING, - ) - - repo.save(a) - cache.put( - CacheConstants.STUDY_LISTING.value, - {a.id: get_study_information(a)}, - ) - repo.save(a) - repo.delete(a.id) - - assert len(cache.get(CacheConstants.STUDY_LISTING.value)) == 0 + repo = StudyMetadataRepository(cache) + a = RawStudy( + name="a", + version="42", + author="John Smith", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + workspace=DEFAULT_WORKSPACE_NAME, + path="study", + content_status=StudyContentStatus.WARNING, + ) + + repo.save(a) + cache.put( + CacheConstants.STUDY_LISTING.value, + {a.id: get_study_information(a)}, + ) + repo.save(a) + repo.delete(a.id) + + assert len(cache.get(CacheConstants.STUDY_LISTING.value)) == 0 diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index f9ba58119b..a62a2a2ed8 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -8,6 +8,7 @@ from uuid import uuid4 import pytest + from antarest.core.config import Config, StorageConfig, WorkspaceConfig from antarest.core.exceptions import TaskAlreadyRunning from antarest.core.filetransfer.model import FileDownload, FileDownloadTaskDTO @@ -84,7 +85,7 @@ VariantStudyService, ) from antarest.worker.archive_worker import ArchiveTaskArgs -from tests.conftest import with_db_context +from tests.helpers import with_db_context def build_study_service( diff --git a/tests/study/business/test_allocation_manager.py b/tests/study/business/test_allocation_manager.py index 87326b9cea..0983a066d9 100644 --- a/tests/study/business/test_allocation_manager.py +++ b/tests/study/business/test_allocation_manager.py @@ -1,25 +1,21 @@ -import contextlib import datetime import re import uuid from unittest.mock import Mock, patch import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker from antarest.core.exceptions import AllocationDataNotFound, AreaNotFound from antarest.core.model import PublicMode -from antarest.dbmodel import Base -from antarest.login.model import User, Group +from antarest.login.model import Group, User from antarest.study.business.allocation_management import ( AllocationField, AllocationFormFields, - AllocationMatrix, AllocationManager, + AllocationMatrix, ) from antarest.study.business.area_management import AreaInfoDTO, AreaType -from antarest.study.model import Study, StudyContentStatus, RawStudy +from antarest.study.model import RawStudy, Study, StudyContentStatus from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import ( FileStudyTree, @@ -40,12 +36,6 @@ VariantStudyService, ) -from antarest.study.business.allocation_management import ( - AllocationManager, - AllocationField, - AllocationFormFields, -) - class TestAllocationField: def test_base(self): @@ -179,21 +169,6 @@ def test_validation_matrix_no_non_null_values(self): ) -@pytest.fixture(scope="function", name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(scope="function", name="db_session") -def db_session_fixture(db_engine): - make_session = sessionmaker(bind=db_engine) - with contextlib.closing(make_session()) as session: - yield session - - # noinspection SpellCheckingInspection EXECUTE_OR_ADD_COMMANDS = ( "antarest.study.business.allocation_management.execute_or_add_commands" diff --git a/tests/study/business/test_correlation_manager.py b/tests/study/business/test_correlation_manager.py index 57050249f3..715f2a91ee 100644 --- a/tests/study/business/test_correlation_manager.py +++ b/tests/study/business/test_correlation_manager.py @@ -1,13 +1,12 @@ -import contextlib import datetime import uuid from unittest.mock import Mock, patch import numpy as np import pytest + from antarest.core.exceptions import AreaNotFound from antarest.core.model import PublicMode -from antarest.dbmodel import Base from antarest.login.model import Group, User from antarest.study.business.area_management import AreaInfoDTO, AreaType from antarest.study.business.correlation_management import ( @@ -36,8 +35,6 @@ from antarest.study.storage.variantstudy.variant_study_service import ( VariantStudyService, ) -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker class TestCorrelationField: @@ -157,21 +154,6 @@ def test_validation__matrix_not_symmetric(self): ) -@pytest.fixture(scope="function", name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(scope="function", name="db_session") -def db_session_fixture(db_engine): - make_session = sessionmaker(bind=db_engine) - with contextlib.closing(make_session()) as session: - yield session - - # noinspection SpellCheckingInspection EXECUTE_OR_ADD_COMMANDS = ( "antarest.study.business.correlation_management.execute_or_add_commands" diff --git a/tests/study/storage/rawstudy/__init__.py b/tests/study/storage/rawstudy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/study/storage/rawstudy/test_raw_study_service.py b/tests/study/storage/rawstudy/test_raw_study_service.py new file mode 100644 index 0000000000..b2f9d36e1e --- /dev/null +++ b/tests/study/storage/rawstudy/test_raw_study_service.py @@ -0,0 +1,243 @@ +import datetime +import zipfile +from pathlib import Path +from typing import List, Optional + +import numpy as np +import pytest +from antarest.core.model import PublicMode +from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group, User +from antarest.matrixstore.service import SimpleMatrixService +from antarest.study.business.utils import execute_or_add_commands +from antarest.study.model import RawStudy, StudyAdditionalData +from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import ( + STStorageConfig, + STStorageGroup, +) +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.storage_service import StudyStorageService +from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( + GeneratorMatrixConstants, +) +from antarest.study.storage.variantstudy.model.command.create_area import ( + CreateArea, +) +from antarest.study.storage.variantstudy.model.command.create_st_storage import ( + CreateSTStorage, +) +from antarest.study.storage.variantstudy.model.command_context import ( + CommandContext, +) +from sqlalchemy import create_engine # type: ignore +from tests.helpers import with_db_context + + +class TestRawStudyService: + # noinspection SpellCheckingInspection + """ + This class uses the `db_middleware` instance which is automatically created + for each test method (the fixture has `autouse=True`). + """ + + @pytest.mark.parametrize( + "outputs", + [ + pytest.param(True, id="outputs_yes"), + pytest.param(False, id="no_outputs"), + ], + ) + @pytest.mark.parametrize( + "output_filter", + [ + # fmt:off + pytest.param(None, id="no_filter"), + pytest.param(["20230802-1425eco"], id="folder"), + pytest.param(["20230802-1628eco"], id="zipped"), + pytest.param(["20230802-1425eco", "20230802-1628eco"], id="both"), + # fmt:on + ], + ) + @pytest.mark.parametrize( + "denormalize", + [ + pytest.param(True, id="denormalize_yes"), + pytest.param(False, id="denormalize_no"), + ], + ) + @with_db_context + def test_export_study_flat( + self, + tmp_path: Path, + raw_study_service: RawStudyService, + simple_matrix_service: SimpleMatrixService, + generator_matrix_constants: GeneratorMatrixConstants, + patch_service: PatchService, + study_storage_service: StudyStorageService, + # pytest parameters + outputs: bool, + output_filter: Optional[List[str]], + denormalize: bool, + ) -> None: + ## Prepare database objects + # noinspection PyArgumentList + user = User(id=0, name="admin") + db.session.add(user) + db.session.commit() + + # noinspection PyArgumentList + group = Group(id="my-group", name="group") + db.session.add(group) + db.session.commit() + + raw_study_path = tmp_path / "My RAW Study" + # noinspection PyArgumentList + raw_study = RawStudy( + id="my_raw_study", + name=raw_study_path.name, + version="860", + author="John Smith", + created_at=datetime.datetime(2023, 7, 15, 16, 45), + updated_at=datetime.datetime(2023, 7, 19, 8, 15), + last_access=datetime.datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + path=str(raw_study_path), + additional_data=StudyAdditionalData(author="John Smith"), + ) + db.session.add(raw_study) + db.session.commit() + + ## Prepare the RAW Study + raw_study_service.create(raw_study) + file_study = raw_study_service.get_raw(raw_study) + + command_context = CommandContext( + generator_matrix_constants=generator_matrix_constants, + matrix_service=simple_matrix_service, + patch_service=patch_service, + ) + + create_area_fr = CreateArea( + command_context=command_context, + area_name="fr", + ) + + # noinspection SpellCheckingInspection + pmax_injection = np.random.rand(8760, 1) + inflows = np.random.uniform(0, 1000, size=(8760, 1)) + + # noinspection PyArgumentList,PyTypeChecker + create_st_storage = CreateSTStorage( + command_context=command_context, + area_id="fr", + parameters=STStorageConfig( + id="", # will be calculated ;-) + name="Storage1", + group=STStorageGroup.BATTERY, + injection_nominal_capacity=1500, + withdrawal_nominal_capacity=1500, + reservoir_capacity=20000, + efficiency=0.94, + initial_level_optim=True, + ), + pmax_injection=pmax_injection.tolist(), + inflows=inflows.tolist(), + ) + + execute_or_add_commands( + raw_study, + file_study, + commands=[create_area_fr, create_st_storage], + storage_service=study_storage_service, + ) + + ## Prepare fake outputs + my_solver_outputs = ["20230802-1425eco", "20230802-1628eco.zip"] + for filename in my_solver_outputs: + output_path = raw_study_path / "output" / filename + # To simplify the checking, there is only one file in each output: + if output_path.suffix.lower() == ".zip": + # Create a fake ZIP file + output_path.parent.mkdir(exist_ok=True, parents=True) + with zipfile.ZipFile( + output_path, + mode="w", + compression=zipfile.ZIP_DEFLATED, + ) as zf: + zf.writestr("simulation.log", data="Simulation done") + else: + # Create a directory + output_path.mkdir(exist_ok=True, parents=True) + (output_path / "simulation.log").write_text("Simulation done") + + ## Collect all files by types to prepare the comparison + src_study_files = set() + src_matrices = set() + src_outputs = set() + for study_file in raw_study_path.rglob("*.*"): + relpath = study_file.relative_to(raw_study_path).as_posix() + if study_file.suffixes == [".txt", ".link"]: + src_matrices.add(relpath.replace(".link", "")) + elif relpath.startswith("output/"): + src_outputs.add(relpath) + else: + src_study_files.add(relpath) + + ## Run the export + target_path = tmp_path / raw_study_path.with_suffix(".exported").name + raw_study_service.export_study_flat( + raw_study, + target_path, + outputs=outputs, + output_list_filter=output_filter, + denormalize=denormalize, + ) + + ## Collect the resulting files + res_study_files = set() + res_matrices = set() + res_outputs = set() + for study_file in target_path.rglob("*.*"): + relpath = study_file.relative_to(target_path).as_posix() + if study_file.suffixes == [".txt", ".link"]: + res_matrices.add(relpath.replace(".link", "")) + elif relpath.startswith("output/"): + res_outputs.add(relpath) + else: + res_study_files.add(relpath) + + ## Check the matrice + # If de-normalization is enabled, the previous loop won't find the matrices + # because the matrix extensions are ".txt" instead of ".txt.link". + # Therefore, it is necessary to move the corresponding ".txt" files + # from `res_study_files` to `res_matrices`. + if denormalize: + assert not res_matrices, "All matrices must be denormalized" + res_matrices = {f for f in res_study_files if f in src_matrices} + res_study_files -= res_matrices + assert res_matrices == src_matrices + + ## Check the outputs + if outputs: + # If `outputs` is True the filtering can occurs + if output_filter is None: + expected_filter = { + f.replace(".zip", "") for f in my_solver_outputs + } + else: + expected_filter = set(output_filter) + expected = { + f"output/{output_name}/simulation.log" + for output_name in expected_filter + } + assert res_outputs == expected + else: + # If `outputs` is False, no output must be exported + # whatever the value of the `output_list_filter` is + assert not res_outputs + + ## Check the study files + assert res_study_files == src_study_files diff --git a/tests/study/storage/test_abstract_storage_service.py b/tests/study/storage/test_abstract_storage_service.py new file mode 100644 index 0000000000..4d3a6f3265 --- /dev/null +++ b/tests/study/storage/test_abstract_storage_service.py @@ -0,0 +1,166 @@ +import datetime +import zipfile +from pathlib import Path +from typing import List, Optional +from unittest.mock import Mock, call + +from antarest.core.config import Config, StorageConfig +from antarest.core.interfaces.cache import ICache +from antarest.core.model import PublicMode +from antarest.core.requests import RequestParameters +from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group, User +from antarest.study.model import Study +from antarest.study.storage.abstract_storage_service import ( + AbstractStorageService, +) +from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.config.model import ( + FileStudyTreeConfigDTO, +) +from antarest.study.storage.rawstudy.model.filesystem.factory import ( + FileStudy, + StudyFactory, +) +from tests.helpers import with_db_context + + +class MyStorageService(AbstractStorageService): + """ + This class is only defined to test `AbstractStorageService` class PUBLIC methods. + Abstract methods are not implemented: there are not used or patched with a Mock object. + """ + + def create(self, metadata: Study) -> Study: + raise NotImplementedError + + def exists(self, metadata: Study) -> bool: + raise NotImplementedError + + # noinspection SpellCheckingInspection + def copy( + self, src_meta: Study, dest_name: str, with_outputs: bool = False + ) -> Study: + raise NotImplementedError + + def get_raw( + self, + metadata: Study, + use_cache: bool = True, + output_dir: Optional[Path] = None, + ) -> FileStudy: + raise NotImplementedError + + def set_reference_output( + self, metadata: Study, output_id: str, status: bool + ) -> None: + raise NotImplementedError + + def delete(self, metadata: Study) -> None: + raise NotImplementedError + + def delete_output(self, metadata: Study, output_id: str) -> None: + raise NotImplementedError + + def get_study_path(self, metadata: Study) -> Path: + raise NotImplementedError + + def export_study_flat( + self, + metadata: Study, + dst_path: Path, + outputs: bool = True, + output_list_filter: Optional[List[str]] = None, + denormalize: bool = True, + ) -> None: + raise NotImplementedError + + def get_synthesis( + self, metadata: Study, params: Optional[RequestParameters] = None + ) -> FileStudyTreeConfigDTO: + raise NotImplementedError + + def initialize_additional_data(self, study: Study) -> bool: + raise NotImplementedError + + +class TmpCopy(object): + """A helper object that compares equal if a folder is a "tmp_copy" folder.""" + + def __init__(self, tmp_path: Path): + self.tmp_path = tmp_path + + def __eq__(self, other: Path): + if isinstance(other, Path) and other.name == "tmp_copy": + # `is_relative_to` is not available for Python < 3.9 + try: + other.relative_to(self.tmp_path) + return True + except ValueError: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return f"" + + +class TestAbstractStorageService: + @with_db_context + def test_export_study(self, tmp_path: Path) -> None: + tmp_dir = tmp_path / "tmp" + tmp_dir.mkdir() + study_path = tmp_path / "My Study" + + service = MyStorageService( + config=Config(storage=StorageConfig(tmp_dir=tmp_dir)), + study_factory=Mock(spec=StudyFactory), + patch_service=Mock(spec=PatchService), + cache=Mock(spec=ICache), + ) + + ## Prepare database objects + + # noinspection PyArgumentList + user = User(id=0, name="admin") + db.session.add(user) + db.session.commit() + + # noinspection PyArgumentList + group = Group(id="my-group", name="group") + db.session.add(group) + db.session.commit() + + # noinspection PyArgumentList + metadata = Study( + name="My Study", + version="860", + author="John Smith", + created_at=datetime.datetime(2023, 7, 19, 16, 45), + updated_at=datetime.datetime(2023, 7, 27, 8, 15), + last_access=datetime.datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + path=str(study_path), + ) + db.session.add(metadata) + db.session.commit() + + ## Check the `export_study` function + service.export_study_flat = Mock(return_value=None) + target_path = tmp_path / "export.zip" + actual = service.export_study(metadata, target_path, outputs=True) + assert actual == target_path + + ## Check the call to export_study_flat + assert service.export_study_flat.mock_calls == [ + call(metadata, TmpCopy(tmp_path), True) + ] + + ## Check that the ZIP file exist and is valid + with zipfile.ZipFile(target_path) as zf: + # Actually, there is nothing is the ZIP file, + # because the Study files doesn't really exist. + assert not zf.namelist() diff --git a/tests/study/storage/variantstudy/test_variant_study_service.py b/tests/study/storage/variantstudy/test_variant_study_service.py new file mode 100644 index 0000000000..555ec7a365 --- /dev/null +++ b/tests/study/storage/variantstudy/test_variant_study_service.py @@ -0,0 +1,249 @@ +import datetime +import re +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +from antarest.core.model import PublicMode +from antarest.core.requests import RequestParameters +from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group, User +from antarest.matrixstore.service import SimpleMatrixService +from antarest.study.business.utils import execute_or_add_commands +from antarest.study.model import RawStudy, StudyAdditionalData +from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import ( + STStorageConfig, + STStorageGroup, +) +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.storage_service import StudyStorageService +from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( + GeneratorMatrixConstants, +) +from antarest.study.storage.variantstudy.model.command.create_area import ( + CreateArea, +) +from antarest.study.storage.variantstudy.model.command.create_st_storage import ( + CreateSTStorage, +) +from antarest.study.storage.variantstudy.model.command_context import ( + CommandContext, +) +from antarest.study.storage.variantstudy.variant_study_service import ( + VariantStudyService, +) +from sqlalchemy import create_engine # type: ignore +from tests.helpers import with_db_context + +# noinspection SpellCheckingInspection +EXPECTED_DENORMALIZED = { + "Desktop.ini", + "input/areas/fr/adequacy_patch.ini", + "input/areas/fr/optimization.ini", + "input/areas/fr/ui.ini", + "input/areas/list.txt", + "input/areas/sets.ini", + "input/bindingconstraints/bindingconstraints.ini", + "input/hydro/allocation/fr.ini", + "input/hydro/common/capacity/creditmodulations_fr.txt.link", + "input/hydro/common/capacity/inflowPattern_fr.txt.link", + "input/hydro/common/capacity/maxpower_fr.txt.link", + "input/hydro/common/capacity/reservoir_fr.txt.link", + "input/hydro/common/capacity/waterValues_fr.txt.link", + "input/hydro/hydro.ini", + "input/hydro/prepro/correlation.ini", + "input/hydro/prepro/fr/energy.txt.link", + "input/hydro/prepro/fr/prepro.ini", + "input/hydro/series/fr/mingen.txt.link", + "input/hydro/series/fr/mod.txt.link", + "input/hydro/series/fr/ror.txt.link", + "input/links/fr/properties.ini", + "input/load/prepro/correlation.ini", + "input/load/prepro/fr/conversion.txt.link", + "input/load/prepro/fr/data.txt.link", + "input/load/prepro/fr/k.txt.link", + "input/load/prepro/fr/settings.ini", + "input/load/prepro/fr/translation.txt.link", + "input/load/series/load_fr.txt.link", + "input/misc-gen/miscgen-fr.txt.link", + "input/renewables/clusters/fr/list.ini", + "input/reserves/fr.txt.link", + "input/solar/prepro/correlation.ini", + "input/solar/prepro/fr/conversion.txt.link", + "input/solar/prepro/fr/data.txt.link", + "input/solar/prepro/fr/k.txt.link", + "input/solar/prepro/fr/settings.ini", + "input/solar/prepro/fr/translation.txt.link", + "input/solar/series/solar_fr.txt.link", + "input/st-storage/clusters/fr/list.ini", + "input/st-storage/series/fr/storage1/PMAX-injection.txt.link", + "input/st-storage/series/fr/storage1/PMAX-withdrawal.txt.link", + "input/st-storage/series/fr/storage1/inflows.txt.link", + "input/st-storage/series/fr/storage1/lower-rule-curve.txt.link", + "input/st-storage/series/fr/storage1/upper-rule-curve.txt.link", + "input/thermal/areas.ini", + "input/thermal/clusters/fr/list.ini", + "input/wind/prepro/correlation.ini", + "input/wind/prepro/fr/conversion.txt.link", + "input/wind/prepro/fr/data.txt.link", + "input/wind/prepro/fr/k.txt.link", + "input/wind/prepro/fr/settings.ini", + "input/wind/prepro/fr/translation.txt.link", + "input/wind/series/wind_fr.txt.link", + "layers/layers.ini", + "settings/comments.txt", + "settings/generaldata.ini", + "settings/resources/study.ico", + "settings/scenariobuilder.dat", + "study.antares", +} + + +class TestVariantStudyService: + @pytest.mark.parametrize( + "denormalize", + [ + pytest.param(True, id="denormalize_yes"), + pytest.param(False, id="denormalize_no"), + ], + ) + @pytest.mark.parametrize( + "from_scratch", + [ + pytest.param(True, id="from_scratch__yes"), + pytest.param(False, id="from_scratch__no"), + ], + ) + @with_db_context + def test_generate_task( + self, + tmp_path: Path, + variant_study_service: VariantStudyService, + raw_study_service: RawStudyService, + simple_matrix_service: SimpleMatrixService, + generator_matrix_constants: GeneratorMatrixConstants, + patch_service: PatchService, + study_storage_service: StudyStorageService, + # pytest parameters + denormalize: bool, + from_scratch: bool, + ) -> None: + ## Prepare database objects + # noinspection PyArgumentList + user = User(id=0, name="admin") + db.session.add(user) + db.session.commit() + + # noinspection PyArgumentList + group = Group(id="my-group", name="group") + db.session.add(group) + db.session.commit() + + ## First create a raw study (root of the variant) + raw_study_path = tmp_path / "My RAW Study" + # noinspection PyArgumentList + raw_study = RawStudy( + id="my_raw_study", + name=raw_study_path.name, + version="860", + author="John Smith", + created_at=datetime.datetime(2023, 7, 15, 16, 45), + updated_at=datetime.datetime(2023, 7, 19, 8, 15), + last_access=datetime.datetime.utcnow(), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + path=str(raw_study_path), + additional_data=StudyAdditionalData(author="John Smith"), + ) + db.session.add(raw_study) + db.session.commit() + + ## Prepare the RAW Study + raw_study_service.create(raw_study) + + variant_study = variant_study_service.create_variant_study( + raw_study.id, + "My Variant Study", + params=Mock( + spec=RequestParameters, + user=Mock( + impersonator=user.id, is_site_admin=Mock(return_value=True) + ), + ), + ) + + ## Prepare the RAW Study + file_study = variant_study_service.get_raw(variant_study) + + command_context = CommandContext( + generator_matrix_constants=generator_matrix_constants, + matrix_service=simple_matrix_service, + patch_service=patch_service, + ) + + create_area_fr = CreateArea( + command_context=command_context, + area_name="fr", + ) + + ## Prepare the Variant Study Data + # noinspection SpellCheckingInspection + pmax_injection = np.random.rand(8760, 1) + inflows = np.random.uniform(0, 1000, size=(8760, 1)) + + # noinspection PyArgumentList,PyTypeChecker + create_st_storage = CreateSTStorage( + command_context=command_context, + area_id="fr", + parameters=STStorageConfig( + id="", # will be calculated ;-) + name="Storage1", + group=STStorageGroup.BATTERY, + injection_nominal_capacity=1500, + withdrawal_nominal_capacity=1500, + reservoir_capacity=20000, + efficiency=0.94, + initial_level_optim=True, + ), + pmax_injection=pmax_injection.tolist(), + inflows=inflows.tolist(), + ) + + execute_or_add_commands( + variant_study, + file_study, + commands=[create_area_fr, create_st_storage], + storage_service=study_storage_service, + ) + + ## Run the "generate" task + actual_uui = variant_study_service.generate_task( + variant_study, + denormalize=denormalize, + from_scratch=from_scratch, + ) + assert re.fullmatch( + r"[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}", + actual_uui, + flags=re.IGNORECASE, + ) + + ## Collect the resulting files + workspaces = variant_study_service.config.storage.workspaces + internal_studies_dir: Path = workspaces["default"].path + snapshot_dir = internal_studies_dir.joinpath( + variant_study.snapshot.id, "snapshot" + ) + res_study_files = { + study_file.relative_to(snapshot_dir).as_posix() + for study_file in snapshot_dir.rglob("*.*") + } + + if denormalize: + expected = {f.replace(".link", "") for f in EXPECTED_DENORMALIZED} + else: + expected = EXPECTED_DENORMALIZED + assert res_study_files == expected diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index 2f6e0d7cc0..ba6a986dc5 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -5,8 +5,6 @@ import numpy as np import pytest -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware -from antarest.dbmodel import Base from antarest.matrixstore.service import MatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.repository import StudyMetadataRepository @@ -28,28 +26,9 @@ from antarest.study.storage.variantstudy.model.command_context import ( CommandContext, ) -from sqlalchemy import create_engine from tests.variantstudy.assets import ASSETS_DIR -@pytest.fixture(name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(name="db_middleware", autouse=True) -def db_middleware_fixture(db_engine): - # noinspection SpellCheckingInspection - yield DBSessionMiddleware( - None, - custom_engine=db_engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - @pytest.fixture(name="matrix_service") def matrix_service_fixture() -> MatrixService: """ diff --git a/tests/worker/test_simulator_worker.py b/tests/worker/test_simulator_worker.py index 8ebcd8b80a..2ad13c9fcf 100644 --- a/tests/worker/test_simulator_worker.py +++ b/tests/worker/test_simulator_worker.py @@ -2,18 +2,18 @@ import platform import stat from pathlib import Path -from unittest.mock import Mock, patch, call +from unittest.mock import Mock, call, patch import pytest from antarest.core.config import Config, LauncherConfig, LocalConfig from antarest.worker.simulator_worker import ( - SimulatorWorker, GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME, GENERATE_TIMESERIES_TASK_NAME, + SimulatorWorker, ) from antarest.worker.worker import WorkerTaskCommand -from tests.conftest import with_db_context +from tests.helpers import with_db_context @with_db_context diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 230396256c..9ce059a841 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -4,13 +4,14 @@ from unittest.mock import MagicMock import pytest + from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, EventType, IEventBus from antarest.core.model import PermissionInfo, PublicMode from antarest.core.tasks.model import TaskResult from antarest.eventbus.main import build_eventbus from antarest.worker.worker import AbstractWorker, WorkerTaskCommand -from tests.conftest import auto_retry_assert +from tests.helpers import auto_retry_assert class DummyWorker(AbstractWorker):