diff --git a/docs/source/conf.py b/docs/source/conf.py index c02a333..47ce565 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,6 +11,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import os +import re from sphinx.ext import apidoc @@ -62,14 +63,63 @@ # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ["_static"] + +# Specify which special members have to be kept +special_members_dict = { + "Document": {"init"}, + "ResponseError": {"init", "or"}, + "PipelineBooleanDict": {"init", "or", "and"}, + "PipelineAttribute": {"init", "or", "and", "eq", "gt", "ge", "lt", "le"}, + "Pipelines": {"init", "add", "iadd"} +} + +# Add trailing and leading "__" to all the aforementioned members +for cls, methods in special_members_dict.items(): + special_members_dict[cls] = {f"__{method}__" for method in methods} + +# Make a set of all allowed special members +all_special_members = set() +for methods in special_members_dict.values(): + all_special_members |= methods + autodoc_default_options = { "members": True, "member-order": "bysource", "private-members": True, "special-members": True, - "undoc-members": True, + "undoc-members": False, } + +def is_special_member(member_name: str) -> bool: + """Checks if the given member is special, i.e. its name has the following format ``____``.""" + return bool(re.compile(r"^__\w+__$").match(member_name)) + + +def skip(app, typ, member_name, obj, flag, options): + """The filter function to determine whether to keep the member in the documentation. + + ``True`` means skip the member. 
+ """ + if is_special_member(member_name): + + if member_name not in all_special_members: + return True + + obj_name = obj.__qualname__.split(".")[0] + if methods_set := special_members_dict.get(obj_name, None): + if member_name in methods_set: + return False # Keep the member + return True + + return None + + +def setup(app): + """Sets up the sphinx app.""" + app.connect("autodoc-skip-member", skip) + + root_doc = "index" output_dir = os.path.join(".") diff --git a/docs/source/index.rst b/docs/source/index.rst index 7091703..4aec776 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Welcome to Pytroll documentation! +Pytroll-db Documentation =========================================== .. toctree:: diff --git a/trolldb/__init__.py b/trolldb/__init__.py index f054e81..8a78256 100644 --- a/trolldb/__init__.py +++ b/trolldb/__init__.py @@ -1 +1 @@ -"""trolldb package.""" +"""The database interface of the Pytroll package.""" diff --git a/trolldb/api/__init__.py b/trolldb/api/__init__.py index 6590a31..4114977 100644 --- a/trolldb/api/__init__.py +++ b/trolldb/api/__init__.py @@ -5,5 +5,5 @@ For more information and documentation, please refer to the following sub-packages and modules: - :obj:`trolldb.api.routes`: The package which defines the API routes. - - :obj:`trollddb.api.api`: The module which defines the API server and how it is run via the given configuration. + - :obj:`trolldb.api.api`: The module which defines the API server and how it is run via the given configuration. 
""" diff --git a/trolldb/api/api.py b/trolldb/api/api.py index 85272a1..a960ffe 100644 --- a/trolldb/api/api.py +++ b/trolldb/api/api.py @@ -14,19 +14,20 @@ """ import asyncio +import sys import time from contextlib import contextmanager from multiprocessing import Process -from typing import Union +from typing import Any, Generator, NoReturn import uvicorn from fastapi import FastAPI, status from fastapi.responses import PlainTextResponse from loguru import logger -from pydantic import FilePath, validate_call +from pydantic import ValidationError from trolldb.api.routes import api_router -from trolldb.config.config import AppConfig, Timeout, parse_config_yaml_file +from trolldb.config.config import AppConfig, Timeout from trolldb.database.mongodb import mongodb_context from trolldb.errors.errors import ResponseError @@ -34,7 +35,7 @@ title="pytroll-db", summary="The database API of Pytroll", description= - "The API allows you to perform CRUD operations as well as querying the database" + "The API allows you to perform CRUD operations as well as querying the database" "At the moment only MongoDB is supported. It is based on the following Python packages" "\n * **PyMongo** (https://github.com/mongodb/mongo-python-driver)" "\n * **motor** (https://github.com/mongodb/motor)", @@ -43,11 +44,11 @@ url="https://www.gnu.org/licenses/gpl-3.0.en.html" ) ) -"""These will appear int the auto-generated documentation and are passed to the ``FastAPI`` class as keyword args.""" +"""These will appear in the auto-generated documentation and are passed to the ``FastAPI`` class as keyword args.""" -@validate_call -def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: +@logger.catch(onerror=lambda _: sys.exit(1)) +def run_server(config: AppConfig, **kwargs) -> None: """Runs the API server with all the routes and connection to the database. 
It first creates a FastAPI application and runs it using `uvicorn `_ which is @@ -56,32 +57,26 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: Args: config: - The configuration of the application which includes both the server and database configurations. Its type - should be a :class:`FilePath`, which is a valid path to an existing config file which will parsed as a - ``.YAML`` file. + The configuration of the application which includes both the server and database configurations. **kwargs: The keyword arguments are the same as those accepted by the `FastAPI class `_ and are directly passed to it. These keyword arguments will be first concatenated with the configurations of the API server which are read from the ``config`` argument. The keyword arguments which are passed explicitly to the function - take precedence over ``config``. Finally, ``API_INFO``, which are hard-coded information for the API server, - will be concatenated and takes precedence over all. - - Raises: - ValidationError: - If the function is not called with arguments of valid type. + take precedence over ``config``. Finally, :obj:`API_INFO`, which are hard-coded information for the API + server, will be concatenated and takes precedence over all. Example: .. code-block:: python - from api.api import run_server + from trolldb.api.api import run_server + from trolldb.config.config import parse_config + if __name__ == "__main__": - run_server("config.yaml") + run_server(parse_config("config.yaml")) """ logger.info("Attempt to run the API server ...") - if not isinstance(config, AppConfig): - config = parse_config_yaml_file(config) # Concatenate the keyword arguments for the API server in the order of precedence (lower to higher). 
app = FastAPI(**(config.api_server._asdict() | kwargs | API_INFO)) @@ -89,7 +84,7 @@ def run_server(config: Union[AppConfig, FilePath], **kwargs) -> None: app.include_router(api_router) @app.exception_handler(ResponseError) - async def auto_exception_handler(_, exc: ResponseError): + async def auto_handler_response_errors(_, exc: ResponseError) -> PlainTextResponse: """Catches all the exceptions raised as a ResponseError, e.g. accessing non-existing databases/collections.""" status_code, message = exc.get_error_details() info = dict( @@ -99,7 +94,13 @@ async def auto_exception_handler(_, exc: ResponseError): logger.error(f"Response error caught by the API auto exception handler: {info}") return PlainTextResponse(**info) - async def _serve(): + @app.exception_handler(ValidationError) + async def auto_handler_pydantic_validation_errors(_, exc: ValidationError) -> PlainTextResponse: + """Catches all the exceptions raised as a Pydantic ValidationError.""" + logger.error(f"Response error caught by the API auto exception handler: {exc}") + return PlainTextResponse(str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + async def _serve() -> NoReturn: """An auxiliary coroutine to be used in the asynchronous execution of the FastAPI application.""" async with mongodb_context(config.database): logger.info("Attempt to start the uvicorn server ...") @@ -116,7 +117,7 @@ async def _serve(): @contextmanager -def api_server_process_context(config: Union[AppConfig, FilePath], startup_time: Timeout = 2): +def api_server_process_context(config: AppConfig, startup_time: Timeout = 2) -> Generator[Process, Any, None]: """A synchronous context manager to run the API server in a separate process (non-blocking). It uses the `multiprocessing `_ package. The main use case @@ -132,9 +133,6 @@ def api_server_process_context(config: Union[AppConfig, FilePath], startup_time: large so that the tests will not time out. 
""" logger.info("Attempt to run the API server process in a context manager ...") - if not isinstance(config, AppConfig): - config = parse_config_yaml_file(config) - process = Process(target=run_server, args=(config,)) try: process.start() diff --git a/trolldb/api/routes/__init__.py b/trolldb/api/routes/__init__.py index 4c69061..7378306 100644 --- a/trolldb/api/routes/__init__.py +++ b/trolldb/api/routes/__init__.py @@ -1,4 +1,4 @@ -"""routes package.""" +"""The routes package of the API.""" from .router import api_router diff --git a/trolldb/api/routes/common.py b/trolldb/api/routes/common.py index b05c86b..86bb4f8 100644 --- a/trolldb/api/routes/common.py +++ b/trolldb/api/routes/common.py @@ -2,19 +2,11 @@ from typing import Annotated, Union -from fastapi import Depends, Query, Response +from fastapi import Depends, Response from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase from trolldb.database.mongodb import MongoDB -exclude_defaults_query = Query( - True, - title="Query string", - description= - "A boolean to exclude default databases from a MongoDB instance. Refer to " - "`trolldb.database.mongodb.MongoDB.default_database_names` for more information." -) - async def check_database(database_name: str | None = None) -> AsyncIOMotorDatabase: """A dependency for route handlers to check for the existence of a database given its name. diff --git a/trolldb/api/routes/databases.py b/trolldb/api/routes/databases.py index e114ef4..046d57a 100644 --- a/trolldb/api/routes/databases.py +++ b/trolldb/api/routes/databases.py @@ -4,10 +4,12 @@ For more information on the API server, see the automatically generated documentation by FastAPI. 
""" -from fastapi import APIRouter +from typing import Annotated + +from fastapi import APIRouter, Query from pymongo.collection import _DocumentType -from trolldb.api.routes.common import CheckCollectionDependency, CheckDataBaseDependency, exclude_defaults_query +from trolldb.api.routes.common import CheckCollectionDependency, CheckDataBaseDependency from trolldb.config.config import MongoObjectId from trolldb.database.errors import ( Databases, @@ -23,7 +25,12 @@ @router.get("/", response_model=list[str], summary="Gets the list of all database names") -async def database_names(exclude_defaults: bool = exclude_defaults_query) -> list[str]: +async def database_names( + exclude_defaults: Annotated[bool, Query( + title="Query parameter", + description="A boolean to exclude default databases from a MongoDB instance. Refer to " + "`trolldb.database.mongodb.MongoDB.default_database_names` for more information." + )] = True) -> list[str]: """Please consult the auto-generated documentation by FastAPI.""" db_names = await MongoDB.list_database_names() diff --git a/trolldb/api/routes/queries.py b/trolldb/api/routes/queries.py index e585101..4542b57 100644 --- a/trolldb/api/routes/queries.py +++ b/trolldb/api/routes/queries.py @@ -5,13 +5,14 @@ """ import datetime +from typing import Annotated from fastapi import APIRouter, Query from trolldb.api.routes.common import CheckCollectionDependency from trolldb.database.errors import database_collection_error_descriptor from trolldb.database.mongodb import get_ids -from trolldb.database.piplines import PipelineAttribute, Pipelines +from trolldb.database.pipelines import PipelineAttribute, Pipelines router = APIRouter() @@ -22,14 +23,11 @@ summary="Gets the database UUIDs of the documents that match specifications determined by the query string") async def queries( collection: CheckCollectionDependency, - # We suppress ruff for the following four lines with `Query(default=None)`. 
-    # Reason: This is the FastAPI way of defining optional queries and ruff is not happy about it!
-    platform: list[str] = Query(default=None),  # noqa: B008
-    sensor: list[str] = Query(default=None),  # noqa: B008
-    time_min: datetime.datetime = Query(default=None),  # noqa: B008
-    time_max: datetime.datetime = Query(default=None)) -> list[str]:  # noqa: B008
+    platform: Annotated[list[str] | None, Query()] = None,
+    sensor: Annotated[list[str] | None, Query()] = None,
+    time_min: Annotated[datetime.datetime | None, Query()] = None,
+    time_max: Annotated[datetime.datetime | None, Query()] = None) -> list[str]:
     """Please consult the auto-generated documentation by FastAPI."""
-    # We
     pipelines = Pipelines()
 
     if platform:
@@ -42,10 +40,7 @@ async def queries(
     start_time = PipelineAttribute("start_time")
     end_time = PipelineAttribute("end_time")
     pipelines += (
-        (start_time >= time_min) |
-        (start_time <= time_max) |
-        (end_time >= time_min) |
-        (end_time <= time_max)
+        ((start_time >= time_min) & (start_time <= time_max)) |
+        ((end_time >= time_min) & (end_time <= time_max))
     )
-
     return await get_ids(collection.aggregate(pipelines))
diff --git a/trolldb/cli.py b/trolldb/cli.py
index 294aab9..dd5f0c2 100644
--- a/trolldb/cli.py
+++ b/trolldb/cli.py
@@ -4,15 +4,40 @@
 import asyncio
 
 from loguru import logger
+from motor.motor_asyncio import AsyncIOMotorCollection
 from posttroll.message import Message
 from posttroll.subscriber import create_subscriber_from_dict_config
 from pydantic import FilePath
 
-from trolldb.config.config import AppConfig, parse_config_yaml_file
+from trolldb.config.config import AppConfig, parse_config
 from trolldb.database.mongodb import MongoDB, mongodb_context
 
 
-async def record_messages(config: AppConfig):
+async def delete_uri_from_collection(collection: AsyncIOMotorCollection, uri: str) -> int:
+    """Deletes a document from the collection and logs the deletion.
+
+    Args:
+        collection:
+            The collection object which includes the document to delete.
+        uri:
+            The URI used to query the collection. It can be either a URI of a previously recorded file message or
+            a dataset message.
+
+    Returns:
+        Number of deleted documents.
+    """
+    del_result_file = await collection.delete_many({"uri": uri})
+    if del_result_file.deleted_count == 1:
+        logger.info(f"Deleted one document (file) with uri: {uri}")
+
+    del_result_dataset = await collection.delete_many({"dataset.uri": uri})
+    if del_result_dataset.deleted_count == 1:
+        logger.info(f"Deleted one document (dataset) with uri: {uri}")
+
+    return del_result_file.deleted_count + del_result_dataset.deleted_count
+
+
+async def record_messages(config: AppConfig) -> None:
     """Record the metadata of messages into the database."""
     async with mongodb_context(config.database):
         collection = await MongoDB.get_collection(
@@ -20,23 +45,27 @@
         )
         for m in create_subscriber_from_dict_config(config.subscriber).recv():
             msg = Message.decode(str(m))
-            if msg.type in ["file", "dataset"]:
-                await collection.insert_one(msg.data)
-            elif msg.type == "del":
-                deletion_result = await collection.delete_many({"uri": msg.data["uri"]})
-                if deletion_result.deleted_count != 1:
-                    logger.error("Recorder found multiple deletions!")  # TODO: Log some data related to the msg
-            else:
-                logger.debug(f"Don't know what to do with {msg.type} message.")
+            match msg.type:
+                case "file":
+                    await collection.insert_one(msg.data)
+                    logger.info(f"Inserted file with uri: {msg.data['uri']}")
+                case "dataset":
+                    await collection.insert_one(msg.data)
+                    logger.info(f"Inserted dataset with {len(msg.data['dataset'])} elements: {msg.data['dataset']}")
+                case "del":
+                    deletion_count = await delete_uri_from_collection(collection, msg.data["uri"])
+                    if deletion_count > 1:
+                        logger.error(f"Recorder found multiple deletions for uri: {msg.data['uri']}!")
+                case _:
+                    logger.debug(f"Don't know what to do with {msg.type} message.")
 
 
-async def record_messages_from_config(config_file: FilePath):
+async def record_messages_from_config(config_file: FilePath) -> None: """Record messages into the database, getting the configuration from a file.""" - config = parse_config_yaml_file(config_file) - await record_messages(config) + await record_messages(parse_config(config_file)) -async def record_messages_from_command_line(args=None): +async def record_messages_from_command_line(args=None) -> None: """Record messages into the database, command-line interface.""" parser = argparse.ArgumentParser() parser.add_argument( @@ -47,6 +76,6 @@ async def record_messages_from_command_line(args=None): await record_messages_from_config(cmd_args.configuration_file) -def run_sync(): +def run_sync() -> None: """Runs the interface synchronously.""" asyncio.run(record_messages_from_command_line()) diff --git a/trolldb/config/config.py b/trolldb/config/config.py index b43d7fe..99ea99e 100644 --- a/trolldb/config/config.py +++ b/trolldb/config/config.py @@ -3,9 +3,9 @@ The validation is performed using `Pydantic `_. Note: - Functions in this module are decorated with - `pydantic.validate_call `_ - so that their arguments can be validated using the corresponding type hints, when calling the function at runtime. + Some functions/methods in this module are decorated with the Pydantic + `@validate_call `_ which checks the arguments during the + function calls. 
""" import errno @@ -15,15 +15,16 @@ from bson import ObjectId from bson.errors import InvalidId from loguru import logger -from pydantic import AnyUrl, BaseModel, Field, FilePath, MongoDsn, ValidationError +from pydantic import AnyUrl, BaseModel, MongoDsn, PositiveFloat, ValidationError, validate_call from pydantic.functional_validators import AfterValidator from typing_extensions import Annotated from yaml import safe_load -Timeout = Annotated[float, Field(ge=0)] +Timeout = PositiveFloat """A type hint for the timeout in seconds (non-negative float).""" +@validate_call def id_must_be_valid(id_like_string: str) -> ObjectId: """Checks that the given string can be converted to a valid MongoDB ObjectId. @@ -32,9 +33,12 @@ def id_must_be_valid(id_like_string: str) -> ObjectId: The string to be converted to an ObjectId. Returns: - The ObjectId object if successfully. + The ObjectId object if successful. Raises: + ValidationError: + If the given argument is not of type ``str``. + ValueError: If the given string cannot be converted to a valid ObjectId. This will ultimately turn into a pydantic validation error. @@ -46,7 +50,7 @@ def id_must_be_valid(id_like_string: str) -> ObjectId: MongoObjectId = Annotated[str, AfterValidator(id_must_be_valid)] -"""Type hint validator for object IDs.""" +"""The type hint validator for object IDs.""" class MongoDocument(BaseModel): @@ -60,7 +64,7 @@ class APIServerConfig(NamedTuple): Note: The attributes herein are a subset of the keyword arguments accepted by `FastAPI class `_ and are directly passed - to the FastAPI class. + to the FastAPI class. Consult :func:`trolldb.api.api.run_server` on how these configurations are treated. """ url: AnyUrl @@ -79,7 +83,7 @@ class DatabaseConfig(NamedTuple): """ url: MongoDsn - """The URL of the MongoDB server excluding the port part, e.g. ``"mongodb://localhost:27017"``""" + """The URL of the MongoDB server including the port part, e.g. 
``"mongodb://localhost:27017"``""" timeout: Timeout """The timeout in seconds (non-negative float), after which an exception is raised if a connection with the @@ -95,7 +99,7 @@ class DatabaseConfig(NamedTuple): class AppConfig(BaseModel): - """A model to hold all the configurations of the application including both the API server and the database. + """A model to hold all the configurations of the application, i.e. the API server, the database, and the subscriber. This will be used by Pydantic to validate the parsed YAML file. """ @@ -104,36 +108,30 @@ class AppConfig(BaseModel): subscriber: SubscriberConfig -def parse_config_yaml_file(filename: FilePath) -> AppConfig: - """Parses and validates the configurations from a YAML file. +@logger.catch(onerror=lambda _: sys.exit(1)) +def parse_config(file) -> AppConfig: + """Parses and validates the configurations from a YAML file (descriptor). Args: - filename: - The filename of a valid YAML file which holds the configurations. + file: + A `path-like object `_ or an integer file + descriptor. This will be directly passed to the ``open()`` function. For example, it can be the filename + (absolute or relative) of a valid YAML file which holds the configurations. Returns: An instance of :class:`AppConfig`. - - Raises: - ParserError: - If the file cannot be properly parsed - - ValidationError: - If the successfully parsed file fails the validation, i.e. its schema or the content does not conform to - :class:`AppConfig`. - - ValidationError: - If the function is not called with arguments of valid type. 
""" logger.info("Attempt to parse the YAML file ...") - with open(filename, "r") as file: - config = safe_load(file) + with open(file, "r") as f: + config = safe_load(f) logger.info("Parsing YAML file is successful.") + try: logger.info("Attempt to validate the parsed YAML file ...") config = AppConfig(**config) - logger.info("Validation of the parsed YAML file is successful.") - return config except ValidationError as e: logger.error(e) sys.exit(errno.EIO) + + logger.info("Validation of the parsed YAML file is successful.") + return config diff --git a/trolldb/database/errors.py b/trolldb/database/errors.py index b3830e3..7a14608 100644 --- a/trolldb/database/errors.py +++ b/trolldb/database/errors.py @@ -6,70 +6,72 @@ are (expected to be) self-explanatory and require no additional documentation. """ +from typing import ClassVar + from fastapi import status from trolldb.errors.errors import ResponseError, ResponsesErrorGroup class Client(ResponsesErrorGroup): - """Client error responses, e.g. if something goes wrong with initialization or closing the client.""" - CloseNotAllowedError = ResponseError({ + """Database client error responses, e.g. if something goes wrong with initialization or closing the client.""" + CloseNotAllowedError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "Calling `close()` on a client which has not been initialized is not allowed!" }) - ReinitializeConfigError = ResponseError({ + ReinitializeConfigError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "The client is already initialized with a different database configuration!" }) - AlreadyOpenError = ResponseError({ + AlreadyOpenError: ClassVar[ResponseError] = ResponseError({ status.HTTP_100_CONTINUE: "The client has been already initialized with the same configuration." 
}) - InconsistencyError = ResponseError({ + InconsistencyError: ClassVar[ResponseError] = ResponseError({ status.HTTP_405_METHOD_NOT_ALLOWED: "Something must have been wrong as we are in an inconsistent state. " "The internal database configuration is not empty and is the same as what we just " "received but the client is `None` or has been already closed!" }) - ConnectionError = ResponseError({ + ConnectionError: ClassVar[ResponseError] = ResponseError({ status.HTTP_400_BAD_REQUEST: - "Could not connect to the database with URL." + "Could not connect to the database with the given URL." }) class Collections(ResponsesErrorGroup): - """Collections error responses, e.g. if a requested collection cannot be found.""" - NotFoundError = ResponseError({ + """Collections error responses, e.g. if the requested collection cannot be found.""" + NotFoundError: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given collection name inside the specified database." }) - WrongTypeError = ResponseError({ + WrongTypeError: ClassVar[ResponseError] = ResponseError({ status.HTTP_422_UNPROCESSABLE_ENTITY: - "Both the Database and collection name must be `None` if one of them is `None`." + "Both the database and collection name must be `None` if either one is `None`." }) class Databases(ResponsesErrorGroup): - """Databases error responses, e.g. if a requested database cannot be found.""" - NotFoundError = ResponseError({ + """Databases error responses, e.g. if the requested database cannot be found.""" + NotFoundError: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find the given database name." }) - WrongTypeError = ResponseError({ + WrongTypeError: ClassVar[ResponseError] = ResponseError({ status.HTTP_422_UNPROCESSABLE_ENTITY: "Database name must be either of type `str` or `None.`" }) class Documents(ResponsesErrorGroup): - """Documents error responses, e.g. 
if a requested document cannot be found.""" - NotFound = ResponseError({ + """Documents error responses, e.g. if the requested document cannot be found.""" + NotFound: ClassVar[ResponseError] = ResponseError({ status.HTTP_404_NOT_FOUND: "Could not find any document with the given object id." }) diff --git a/trolldb/database/mongodb.py b/trolldb/database/mongodb.py index 3ec5a79..4f8f3bc 100644 --- a/trolldb/database/mongodb.py +++ b/trolldb/database/mongodb.py @@ -3,11 +3,16 @@ It is based on the following libraries: - `PyMongo `_ - `motor `_. + +Note: + Some functions/methods in this module are decorated with the Pydantic + `@validate_call `_ which checks the arguments during the + function calls. """ import errno from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Coroutine, Optional, TypeVar, Union +from typing import Any, AsyncGenerator, ClassVar, Coroutine, Optional, TypeVar, Union from loguru import logger from motor.motor_asyncio import ( @@ -17,13 +22,12 @@ AsyncIOMotorCursor, AsyncIOMotorDatabase, ) -from pydantic import BaseModel +from pydantic import validate_call from pymongo.collection import _DocumentType from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError from trolldb.config.config import DatabaseConfig from trolldb.database.errors import Client, Collections, Databases -from trolldb.errors.errors import ResponseError T = TypeVar("T") CoroutineLike = Coroutine[Any, Any, T] @@ -36,22 +40,12 @@ """Coroutine type hint for a list of strings.""" -class DatabaseName(BaseModel): - """Pydantic model for a database name.""" - name: str | None - - -class CollectionName(BaseModel): - """Pydantic model for a collection name.""" - name: str | None - - async def get_id(doc: CoroutineDocument) -> str: """Retrieves the ID of a document as a simple flat string. Note: The rationale behind this method is as follows. In MongoDB, each document has a unique ID which is of type - :class:`~bson.objectid.ObjectId`. 
This is not suitable for purposes when a simple string is needed, hence + :class:`bson.objectid.ObjectId`. This is not suitable for purposes when a simple string is needed, hence the need for this method. Args: @@ -84,7 +78,7 @@ async def get_ids(docs: Union[AsyncIOMotorCommandCursor, AsyncIOMotorCursor]) -> class MongoDB: """A wrapper class around the `motor async driver `_ for Mongo DB. - It includes convenience methods tailored to our specific needs. As such, the :func:`~MongoDB.initialize()`` method + It includes convenience methods tailored to our specific needs. As such, the :func:`~MongoDB.initialize()` method returns a coroutine which needs to be awaited. Note: @@ -103,12 +97,12 @@ class MongoDB: us, we would like to fail early! """ - __client: Optional[AsyncIOMotorClient] = None - __database_config: Optional[DatabaseConfig] = None - __main_collection: AsyncIOMotorCollection = None - __main_database: AsyncIOMotorDatabase = None + __client: ClassVar[Optional[AsyncIOMotorClient]] = None + __database_config: ClassVar[Optional[DatabaseConfig]] = None + __main_collection: ClassVar[Optional[AsyncIOMotorCollection]] = None + __main_database: ClassVar[Optional[AsyncIOMotorDatabase]] = None - default_database_names = ["admin", "config", "local"] + default_database_names: ClassVar[list[str]] = ["admin", "config", "local"] """MongoDB creates these databases by default for self usage.""" @classmethod @@ -117,25 +111,29 @@ async def initialize(cls, database_config: DatabaseConfig): Args: database_config: - A named tuple which includes the database configurations. + An object of type :class:`~trolldb.config.config.DatabaseConfig` which includes the database + configurations. + + Warning: + The timeout is given in seconds in the configurations, while the MongoDB uses milliseconds. Returns: On success ``None``. Raises: SystemExit(errno.EIO): - If connection is not established (``ConnectionFailure``) + If connection is not established, i.e. ``ConnectionFailure``. 
SystemExit(errno.EIO): - If the attempt times out (``ServerSelectionTimeoutError``) + If the attempt times out, i.e. ``ServerSelectionTimeoutError``. SystemExit(errno.EIO): If one attempts reinitializing the class with new (different) database configurations without calling :func:`~close()` first. SystemExit(errno.EIO): If the state is not consistent, i.e. the client is closed or ``None`` but the internal database configurations still exist and are different from the new ones which have been just provided. - SystemExit(errno.ENODATA): - If either ``database_config.main_database`` or ``database_config.main_collection`` does not exist. + If either ``database_config.main_database_name`` or ``database_config.main_collection_name`` does not + exist. """ logger.info("Attempt to initialize the MongoDB client ...") logger.info("Checking the database configs ...") @@ -220,10 +218,11 @@ def main_database(cls) -> AsyncIOMotorDatabase: return cls.__main_database @classmethod + @validate_call async def get_collection( cls, - database_name: str, - collection_name: str) -> Union[AsyncIOMotorCollection, ResponseError]: + database_name: str | None, + collection_name: str | None) -> AsyncIOMotorCollection: """Gets the collection object given its name and the database name in which it resides. Args: @@ -238,7 +237,7 @@ async def get_collection( Raises: ValidationError: - If input args are invalid according to the pydantic. + If the method is not called with arguments of valid type. KeyError: If the database name exists, but it does not include any collection with the given name. @@ -250,9 +249,6 @@ async def get_collection( This method relies on :func:`get_database` to check for the existence of the database which can raise exceptions. Check its documentation for more information. 
""" - database_name = DatabaseName(name=database_name).name - collection_name = CollectionName(name=collection_name).name - match database_name, collection_name: case None, None: return cls.main_collection() @@ -266,7 +262,8 @@ async def get_collection( raise Collections.WrongTypeError @classmethod - async def get_database(cls, database_name: str) -> Union[AsyncIOMotorDatabase, ResponseError]: + @validate_call + async def get_database(cls, database_name: str | None) -> AsyncIOMotorDatabase: """Gets the database object given its name. Args: @@ -277,11 +274,12 @@ async def get_database(cls, database_name: str) -> Union[AsyncIOMotorDatabase, R The database object. Raises: - KeyError: + ValidationError: + If the method is not called with arguments of valid type. + + KeyError: If the database name does not exist in the list of database names. """ - database_name = DatabaseName(name=database_name).name - match database_name: case None: return cls.main_database() diff --git a/trolldb/database/piplines.py b/trolldb/database/pipelines.py similarity index 97% rename from trolldb/database/piplines.py rename to trolldb/database/pipelines.py index f85fa15..237637c 100644 --- a/trolldb/database/piplines.py +++ b/trolldb/database/pipelines.py @@ -30,11 +30,11 @@ class PipelineBooleanDict(dict): pd_or == pd_or_literal """ - def __or__(self, other: Self): + def __or__(self, other: Self) -> Self: """Implements the bitwise or operator, i.e. ``|``.""" return PipelineBooleanDict({"$or": [self, other]}) - def __and__(self, other: Self): + def __and__(self, other: Self) -> Self: """Implements the bitwise and operator, i.e. ``&``.""" return PipelineBooleanDict({"$and": [self, other]}) @@ -112,7 +112,7 @@ class Pipelines(list): Each item in the list is a dictionary with its key being the literal string ``"$match"`` and its corresponding value being of type :class:`PipelineBooleanDict`. 
The ``"$match"`` key is what actually triggers the matching operation in the MongoDB aggregation pipeline. The
    condition against which the matching will be performed is given by the value
-    which is a simply a boolean pipeline dictionary which has a hierarchical structure.
+    which is simply a boolean pipeline dictionary and has a hierarchical structure.
 
     Example:
         .. code-block:: python
 
diff --git a/trolldb/errors/errors.py b/trolldb/errors/errors.py
index 95fd6bf..e0ea80b 100644
--- a/trolldb/errors/errors.py
+++ b/trolldb/errors/errors.py
@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from sys import exit
-from typing import Self
+from typing import ClassVar, NoReturn, Self
 
 from fastapi import Response
 from fastapi.responses import PlainTextResponse
@@ -31,9 +31,9 @@ def _listify(item: str | list[str]) -> list[str]:
     .. code-block:: python
 
         # The following evaluate to True
-        __listify("test") == ["test"]
-        __listify(["a", "b"]) = ["a", "b"]
-        __listify([]) == []
+        _listify("test") == ["test"]
+        _listify(["a", "b"]) == ["a", "b"]
+        _listify([]) == []
     """
     return item if isinstance(item, list) else [item]
 
@@ -64,7 +64,7 @@ class ResponseError(Exception):
     messages.
     """
 
-    descriptor_delimiter: str = " |OR| "
+    descriptor_delimiter: ClassVar[str] = " |OR| "
    """A delimiter to divide the message part of several error responses which have been combined into a single one.
 
    This will be shown in textual format for the response descriptors of the Fast API routes. 
@@ -76,12 +76,11 @@ class ResponseError(Exception): error_b = ResponseError({404: "Not Found"}) errors = error_a | error_b - # When used in a FastAPI response descriptor, - # the following string will be generated for errors + # When used in a FastAPI response descriptor, the following string is generated "Bad Request |OR| Not Found" """ - DefaultResponseClass: Response = PlainTextResponse + DefaultResponseClass: ClassVar[Response] = PlainTextResponse """The default type of the response which will be returned when an error occurs. This must be a valid member (class) of ``fastapi.responses``. @@ -101,12 +100,12 @@ def __init__(self, args_dict: OrderedDict[StatusCode, str | list[str]] | dict) - error_b = ResponseError({404: "Not Found"}) errors = error_a | error_b errors_a_or_b = ResponseError({400: "Bad Request", 404: "Not Found"}) - errors_list = ResponseError({404: ["Not Found", "Still Not Found"]}) + errors_list = ResponseError({404: ["Not Found", "Yet Not Found"]}) """ self.__dict: OrderedDict = OrderedDict(args_dict) self.extra_information: dict | None = None - def __or__(self, other: Self): + def __or__(self, other: Self) -> Self: """Implements the bitwise `or` ``|`` which combines the error objects into a single error response. Args: @@ -142,7 +141,7 @@ def __or__(self, other: Self): def __retrieve_one_from_some( self, - status_code: StatusCode | None = None) -> (StatusCode, str): + status_code: StatusCode | None = None) -> tuple[StatusCode, str]: """Retrieves a tuple ``(, )`` from the internal dictionary :obj:`ResponseError.__dict`. Args: @@ -183,12 +182,12 @@ def __retrieve_one_from_some( def get_error_details( self, extra_information: dict | None = None, - status_code: int | None = None) -> (StatusCode, str): + status_code: int | None = None) -> tuple[StatusCode, str]: """Gets the details of the error response. Args: extra_information (Optional, default ``None``): - More information (if any) that wants to be added to the message string. 
+ More information (if any) that needs to be added to the message string. status_code (Optional, default ``None``): The status code to retrieve. This is useful when there are several error items in the internal dictionary. In case of ``None``, the internal dictionary must include a single entry, otherwise an error @@ -203,7 +202,7 @@ def get_error_details( def log_as_warning( self, extra_information: dict | None = None, - status_code: int | None = None): + status_code: int | None = None) -> None: """Same as :func:`~ResponseError.get_error_details` but logs the error as a warning and returns ``None``.""" msg, _ = self.get_error_details(extra_information, status_code) logger.warning(msg) @@ -212,7 +211,7 @@ def sys_exit_log( self, exit_code: int = -1, extra_information: dict | None = None, - status_code: int | None = None) -> None: + status_code: int | None = None) -> NoReturn: """Same as :func:`~ResponseError.get_error_details` but logs the error and calls the ``sys.exit``. The arguments are the same as :func:`~ResponseError.get_error_details` with the addition of ``exit_code`` @@ -233,6 +232,11 @@ def sys_exit_log( def fastapi_descriptor(self) -> dict[StatusCode, dict[str, str]]: """Gets the FastAPI descriptor (dictionary) of the error items stored in :obj:`ResponseError.__dict`. + Note: + Consult the FastAPI documentation for + `additional responses `_ to see why and how + descriptors are used. + Example: .. 
code-block:: python diff --git a/trolldb/test_utils/__init__.py b/trolldb/test_utils/__init__.py index e1fa351..9f3e45a 100644 --- a/trolldb/test_utils/__init__.py +++ b/trolldb/test_utils/__init__.py @@ -1 +1 @@ -"""This package provide tools to test the database and api packages.""" +"""This package provides tools to test the database and api packages.""" diff --git a/trolldb/test_utils/common.py b/trolldb/test_utils/common.py index 6e3fb21..2007977 100644 --- a/trolldb/test_utils/common.py +++ b/trolldb/test_utils/common.py @@ -10,7 +10,7 @@ from trolldb.config.config import AppConfig -def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict: +def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict[str, dict]: """Makes the app configuration when used in testing. Args: @@ -19,15 +19,15 @@ def make_test_app_config(subscriber_address: Optional[FilePath] = None) -> dict: config will be an empty dictionary. Returns: - A dictionary which resembles an object of type :obj:`AppConfig`. + A dictionary which resembles an object of type :obj:`~trolldb.config.config.AppConfig`. 
""" app_config = dict( api_server=dict( url="http://localhost:8080" ), database=dict( - main_database_name="mock_database", - main_collection_name="mock_collection", + main_database_name="test_database", + main_collection_name="test_collection", url="mongodb://localhost:28017", timeout=1 ), diff --git a/trolldb/test_utils/mongodb_database.py b/trolldb/test_utils/mongodb_database.py index d8060e9..3856348 100644 --- a/trolldb/test_utils/mongodb_database.py +++ b/trolldb/test_utils/mongodb_database.py @@ -1,9 +1,9 @@ """The module which provides testing utilities to make MongoDB databases/collections and fill them with test data.""" - from contextlib import contextmanager +from copy import deepcopy from datetime import datetime, timedelta from random import choices, randint, shuffle -from typing import Iterator +from typing import Any, ClassVar, Generator from pymongo import MongoClient @@ -12,7 +12,8 @@ @contextmanager -def mongodb_for_test_context(database_config: DatabaseConfig = test_app_config.database) -> Iterator[MongoClient]: +def mongodb_for_test_context( + database_config: DatabaseConfig = test_app_config.database) -> Generator[MongoClient, Any, None]: """A context manager for the MongoDB client given test configurations. 
Note: @@ -37,15 +38,15 @@ def mongodb_for_test_context(database_config: DatabaseConfig = test_app_config.d class Time: - """A static class to enclose functionalities for generating random time stamps.""" + """A static class to enclose functionalities for generating random timestamps.""" - min_start_time = datetime(2019, 1, 1, 0, 0, 0) + min_start_time: ClassVar[datetime] = datetime(2019, 1, 1, 0, 0, 0) """The minimum timestamp which is allowed to appear in our data.""" - max_end_time = datetime(2024, 1, 1, 0, 0, 0) + max_end_time: ClassVar[datetime] = datetime(2024, 1, 1, 0, 0, 0) """The maximum timestamp which is allowed to appear in our data.""" - delta_time = int((max_end_time - min_start_time).total_seconds()) + delta_time: ClassVar[int] = int((max_end_time - min_start_time).total_seconds()) """The difference between the maximum and minimum timestamps in seconds.""" @staticmethod @@ -86,8 +87,8 @@ def __init__(self, platform_name: str, sensor: str) -> None: def generate_dataset(self, max_count: int) -> list[dict]: """Generates the dataset for a given document. - This corresponds to the list of files which are stored in each document. The number of datasets is randomly - chosen from 1 to ``max_count`` for each document. + This corresponds to the list of files which are stored in each document. The number of items in a dataset is + randomly chosen from 1 to ``max_count`` for each document. """ dataset = [] # We suppress ruff (S311) here as we are not generating anything cryptographic here! 
@@ -113,54 +114,71 @@ def like_mongodb_document(self) -> dict: class TestDatabase: - """A static class which encloses functionalities to prepare and fill the test database with mock data.""" + """A static class which encloses functionalities to prepare and fill the test database with test data.""" + + unique_platform_names: ClassVar[list[str]] = ["PA", "PB", "PC"] + """The unique platform names that will be used to generate the sample of all platform names.""" # We suppress ruff (S311) here as we are not generating anything cryptographic here! - platform_names = choices(["PA", "PB", "PC"], k=10) # noqa: S311 - """Example platform names.""" + platform_names: ClassVar[list[str]] = choices(["PA", "PB", "PC"], k=20) # noqa: S311 + """Example platform names. + + Warning: + The value of this variable changes randomly every time. What you see above is just an example which has been + generated as a result of building the documentation! + """ + + unique_sensors: ClassVar[list[str]] = ["SA", "SB", "SC"] + """The unique sensor names that will be used to generate the sample of all sensor names.""" # We suppress ruff (S311) here as we are not generating anything cryptographic here! - sensors = choices(["SA", "SB", "SC"], k=10) # noqa: S311 - """Example sensor names.""" + sensors: ClassVar[list[str]] = choices(["SA", "SB", "SC"], k=20) # noqa: S311 + """Example sensor names. - database_names = [test_app_config.database.main_database_name, "another_mock_database"] + Warning: + The value of this variable changes randomly every time. What you see above is just an example which has been + generated as a result of building the documentation! + """ + + database_names: ClassVar[list[str]] = [test_app_config.database.main_database_name, "another_test_database"] """List of all database names. - The first element is the main database that will be queried by the API and includes the mock data. 
The second + The first element is the main database that will be queried by the API and includes the test data. The second database is for testing scenarios when one attempts to access another existing database or collection. """ - collection_names = [test_app_config.database.main_collection_name, "another_mock_collection"] + collection_names: ClassVar[list[str]] = [test_app_config.database.main_collection_name, "another_test_collection"] """List of all collection names. - The first element is the main collection that will be queried by the API and includes the mock data. The second + The first element is the main collection that will be queried by the API and includes the test data. The second collection is for testing scenarios when one attempts to access another existing collection. """ - all_database_names = ["admin", "config", "local", *database_names] + all_database_names: ClassVar[list[str]] = ["admin", "config", "local", *database_names] """All database names including the default ones which are automatically created by MongoDB.""" - documents: list[dict] = [] - """The list of documents which include mock data.""" + documents: ClassVar[list[dict]] = [] + """The list of documents which include test data.""" @classmethod def generate_documents(cls, random_shuffle: bool = True) -> None: """Generates test documents which for practical purposes resemble real data. Warning: - This method is not pure! The side effect is that the :obj:`TestDatabase.documents` is filled. + This method is not pure! The side effect is that the :obj:`TestDatabase.documents` is reset to new values. """ cls.documents = [ - Document(p, s).like_mongodb_document() for p, s in zip(cls.platform_names, cls.sensors, strict=False)] + Document(p, s).like_mongodb_document() for p, s in zip(cls.platform_names, cls.sensors, strict=False) + ] if random_shuffle: shuffle(cls.documents) @classmethod - def reset(cls): + def reset(cls) -> None: """Resets all the databases/collections. 
- This is done by deleting all documents in the collections and then inserting a single empty ``{}`` document - in them. + This is done by deleting all documents in the collections and then inserting a single empty document, i.e. + ``{}``, in them. """ with mongodb_for_test_context() as client: for db_name, coll_name in zip(cls.database_names, cls.collection_names, strict=False): @@ -170,8 +188,8 @@ def reset(cls): collection.insert_one({}) @classmethod - def write_mock_date(cls): - """Fills databases/collections with mock data.""" + def write_test_data(cls) -> None: + """Fills databases/collections with test data.""" with mongodb_for_test_context() as client: # The following function call has side effects! cls.generate_documents() @@ -180,10 +198,112 @@ def write_mock_date(cls): ][ test_app_config.database.main_collection_name ] + collection.delete_many({}) collection.insert_many(cls.documents) @classmethod - def prepare(cls): - """Prepares the MongoDB instance by first resetting the database and then filling it with mock data.""" + def get_all_documents_from_database(cls) -> list[dict]: + """Retrieves all the documents from the database. + + Returns: + A list of all documents from the database. This matches the content of :obj:`~TestDatabase.documents` with + the addition of `IDs` which are assigned by the MongoDB. + """ + with mongodb_for_test_context() as client: + collection = client[ + test_app_config.database.main_database_name + ][ + test_app_config.database.main_collection_name + ] + documents = list(collection.find({})) + return documents + + @classmethod + def find_min_max_datetime(cls) -> dict[str, dict]: + """Finds the minimum and the maximum for both the ``start_time`` and the ``end_time``. + + We use `brute force` for this purpose. We set the minimum to a large value (year 2100) and the maximum to a + small value (year 1900). We then iterate through all documents and update the extrema. 
+ + Returns: + A dictionary whose schema matches the response returned by the ``/datetime`` route of the API. + """ + result = dict( + start_time=dict( + _min=dict(_id=None, _time="2100-01-01T00:00:00"), + _max=dict(_id=None, _time="1900-01-01T00:00:00") + ), + end_time=dict( + _min=dict(_id=None, _time="2100-01-01T00:00:00"), + _max=dict(_id=None, _time="1900-01-01T00:00:00")) + ) + + documents = cls.get_all_documents_from_database() + + for document in documents: + for k in ["start_time", "end_time"]: + dt = document[k].isoformat() + if dt > result[k]["_max"]["_time"]: + result[k]["_max"]["_time"] = dt + result[k]["_max"]["_id"] = str(document["_id"]) + + if dt < result[k]["_min"]["_time"]: + result[k]["_min"]["_time"] = dt + result[k]["_min"]["_id"] = str(document["_id"]) + + return result + + @classmethod + def _query_platform_sensor(cls, document, platform=None, sensor=None) -> bool: + """An auxiliary method to the :func:`TestDatabase.match_query`.""" + should_remove = False + + if platform: + should_remove = platform and document["platform_name"] not in platform + + if sensor and not should_remove: + should_remove = document["sensor"] not in sensor + + return should_remove + + @classmethod + def _query_time(cls, document, time_min=None, time_max=None) -> bool: + """An auxiliary method to the :func:`TestDatabase.match_query`.""" + should_remove = False + + if time_min and time_max and not should_remove: + should_remove = document["end_time"] < time_min or document["start_time"] > time_max + + if time_min and not time_max and not should_remove: + should_remove = document["end_time"] < time_min + + if time_max and not time_min and not should_remove: + should_remove = document["end_time"] > time_max + + return should_remove + + @classmethod + def match_query(cls, platform=None, sensor=None, time_min=None, time_max=None) -> list[str]: + """Matches the given query. 
+ + We first take all the documents and then progressively remove all that do not match the given queries until + we end up with those that match. When a query is ``None``, it does not have any effect on the results. + This method will be used in testing the ``/queries`` route of the API. + """ + documents = cls.get_all_documents_from_database() + + buffer = deepcopy(documents) + for document in documents: + should_remove = cls._query_platform_sensor(document, platform, sensor) + if not should_remove: + should_remove = cls._query_time(document, time_min, time_max) + if should_remove and document in buffer: + buffer.remove(document) + + return [str(item["_id"]) for item in buffer] + + @classmethod + def prepare(cls) -> None: + """Prepares the MongoDB instance by first resetting the database and filling it with generated test data.""" cls.reset() - cls.write_mock_date() + cls.write_test_data() diff --git a/trolldb/test_utils/mongodb_instance.py b/trolldb/test_utils/mongodb_instance.py index 1b16f04..433bf91 100644 --- a/trolldb/test_utils/mongodb_instance.py +++ b/trolldb/test_utils/mongodb_instance.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from os import mkdir, path from shutil import rmtree +from typing import Any, AnyStr, ClassVar, Generator, Optional from loguru import logger @@ -18,32 +19,34 @@ class TestMongoInstance: """A static class to enclose functionalities for running a MongoDB instance.""" - log_dir: str = tempfile.mkdtemp("__pytroll_db_temp_test_log") + log_dir: ClassVar[str] = tempfile.mkdtemp("__pytroll_db_temp_test_log") """Temp directory for logging messages by the MongoDB instance. Warning: - The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way! + The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way + every time! 
""" - storage_dir: str = tempfile.mkdtemp("__pytroll_db_temp_test_storage") + storage_dir: ClassVar[str] = tempfile.mkdtemp("__pytroll_db_temp_test_storage") """Temp directory for storing database files by the MongoDB instance. Warning: - The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way! + The value of this attribute as shown above is just an example and will change in an unpredictable (secure) way + every time! """ - port: int = 28017 + port: ClassVar[int] = 28017 """The port on which the instance will run. Warning: This must be always hard-coded. """ - process: subprocess.Popen | None = None + process: ClassVar[Optional[subprocess.Popen]] = None """The process which is used to run the MongoDB instance.""" @classmethod - def __prepare_dir(cls, directory: str): + def __prepare_dir(cls, directory: str) -> None: """An auxiliary function to prepare a single directory. It creates a directory if it does not exist, or removes it first if it exists and then recreates it. @@ -52,13 +55,13 @@ def __prepare_dir(cls, directory: str): mkdir(directory) @classmethod - def __remove_dir(cls, directory: str): + def __remove_dir(cls, directory: str) -> None: """An auxiliary function to remove a directory and all its content recursively.""" if path.exists(directory) and path.isdir(directory): rmtree(directory) @classmethod - def run_subprocess(cls, args: list[str], wait=True): + def run_subprocess(cls, args: list[str], wait=True) -> tuple[AnyStr, AnyStr] | None: """Runs the subprocess in shell given its arguments.""" # We suppress ruff (S603) here as we are not receiving any args from outside, e.g. port is hard-coded. # Therefore, sanitization of arguments is not required. @@ -83,14 +86,14 @@ def prepare_dirs(cls) -> None: cls.__prepare_dir(d) @classmethod - def run_instance(cls): + def run_instance(cls) -> None: """Runs the MongoDB instance and does not wait for it, i.e. 
the process runs in the background.""" cls.run_subprocess( ["mongod", "--dbpath", cls.storage_dir, "--logpath", f"{cls.log_dir}/mongod.log", "--port", f"{cls.port}"] , wait=False) @classmethod - def shutdown_instance(cls): + def shutdown_instance(cls) -> None: """Shuts down the MongoDB instance by terminating its process.""" cls.process.terminate() cls.process.wait() @@ -101,7 +104,7 @@ def shutdown_instance(cls): @contextmanager def mongodb_instance_server_process_context( database_config: DatabaseConfig = test_app_config.database, - startup_time: Timeout = 2): + startup_time: Timeout = 2) -> Generator[Any, Any, None]: """A synchronous context manager to run the MongoDB instance in a separate process (non-blocking). It uses the `subprocess `_ package. The main use case is @@ -131,7 +134,7 @@ def mongodb_instance_server_process_context( @contextmanager -def running_prepared_database_context(): +def running_prepared_database_context() -> Generator[Any, Any, None]: """A synchronous context manager to start and prepare a database instance for tests.""" with mongodb_instance_server_process_context(): TestDatabase.prepare() diff --git a/trolldb/tests/test_recorder.py b/trolldb/tests/test_recorder.py index 99546c6..7bacd3d 100644 --- a/trolldb/tests/test_recorder.py +++ b/trolldb/tests/test_recorder.py @@ -5,7 +5,12 @@ from posttroll.testing import patched_subscriber_recv from pytest_lazy_fixtures import lf -from trolldb.cli import record_messages, record_messages_from_command_line, record_messages_from_config +from trolldb.cli import ( + delete_uri_from_collection, + record_messages, + record_messages_from_command_line, + record_messages_from_config, +) from trolldb.database.mongodb import MongoDB, mongodb_context from trolldb.test_utils.common import AppConfig, create_config_file, make_test_app_config, test_app_config from trolldb.test_utils.mongodb_instance import running_prepared_database_context @@ -57,16 +62,6 @@ def config_file(tmp_path): return 
create_config_file(tmp_path) -async def message_in_database_and_delete_count_is_one(msg) -> bool: - """Checks if there is exactly one item in the database which matches the data of the message.""" - async with mongodb_context(test_app_config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") - result = await collection.find_one(dict(scan_mode="EW")) - result.pop("_id") - deletion_result = await collection.delete_many({"uri": msg.data["uri"]}) - return result == msg.data and deletion_result.deleted_count == 1 - - @pytest.mark.parametrize(("function", "args"), [ (record_messages_from_config, lf("config_file")), (record_messages_from_command_line, [lf("config_file")]) @@ -80,6 +75,20 @@ async def test_record_from_cli_and_config(tmp_path, file_message, tmp_data_filen assert await message_in_database_and_delete_count_is_one(msg) +async def message_in_database_and_delete_count_is_one(msg: Message) -> bool: + """Checks if there is exactly one item in the database which matches the data of the message.""" + async with mongodb_context(test_app_config.database): + collection = await MongoDB.get_collection("test_database", "test_collection") + result = await collection.find_one(dict(scan_mode="EW")) + result.pop("_id") + uri = msg.data.get("uri") + if not uri: + uri = msg.data["dataset"][0]["uri"] + deletion_count = await delete_uri_from_collection(collection, uri) + + return result == msg.data and deletion_count == 1 + + async def test_record_messages(config_file, tmp_path, file_message, tmp_data_filename): """Tests that message recording adds a message to the database.""" config = AppConfig(**make_test_app_config(tmp_path)) @@ -97,17 +106,16 @@ async def test_record_deletes_message(tmp_path, file_message, del_message): with patched_subscriber_recv([file_message, del_message]): await record_messages(config) async with mongodb_context(config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") + 
collection = await MongoDB.get_collection("test_database", "test_collection") result = await collection.find_one(dict(scan_mode="EW")) assert result is None + async def test_record_dataset_messages(tmp_path, dataset_message): - """Test recording a dataset message and deleting the file.""" + """Tests recording a dataset message and deleting the file.""" config = AppConfig(**make_test_app_config(tmp_path)) + msg = Message.decode(dataset_message) with running_prepared_database_context(): with patched_subscriber_recv([dataset_message]): await record_messages(config) - async with mongodb_context(config.database): - collection = await MongoDB.get_collection("mock_database", "mock_collection") - result = await collection.find_one(dict(scan_mode="EW")) - assert result is not None + assert await message_in_database_and_delete_count_is_one(msg) diff --git a/trolldb/tests/tests_api/test_api.py b/trolldb/tests/tests_api/test_api.py index c721345..5e47cf2 100644 --- a/trolldb/tests/tests_api/test_api.py +++ b/trolldb/tests/tests_api/test_api.py @@ -8,22 +8,16 @@ """ from collections import Counter +from datetime import datetime import pytest from fastapi import status -from trolldb.test_utils.common import http_get +from trolldb.test_utils.common import http_get, test_app_config from trolldb.test_utils.mongodb_database import TestDatabase, mongodb_for_test_context - -def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool: - """Checks if the test and expected list of collection names match.""" - return Counter(test_collection_names) == Counter(expected_collection_name) - - -def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bool: - """Checks if the test (retrieved from the API) and expected list of (document) ids match.""" - return Counter(test_ids) == Counter(expected_ids) +main_database_name = test_app_config.database.main_database_name +main_collection_name = 
test_app_config.database.main_collection_name @pytest.mark.usefixtures("_test_server_fixture") @@ -74,8 +68,120 @@ def test_collections(): ) +def collections_exists(test_collection_names: list[str], expected_collection_name: list[str]) -> bool: + """Checks if the test and expected list of collection names match.""" + return Counter(test_collection_names) == Counter(expected_collection_name) + + +def document_ids_are_correct(test_ids: list[str], expected_ids: list[str]) -> bool: + """Checks if the test (retrieved from the API) and expected list of (document) ids match.""" + return Counter(test_ids) == Counter(expected_ids) + + @pytest.mark.usefixtures("_test_server_fixture") def test_collections_negative(): """Checks that the non-existing collections cannot be found.""" for database_name in TestDatabase.database_names: assert http_get(f"databases/{database_name}/non_existing_collection").status == status.HTTP_404_NOT_FOUND + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_datetime(): + """Checks that the datetime route works properly.""" + assert http_get("datetime").json() == TestDatabase.find_min_max_datetime() + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_all(): + """Tests that the queries route returns all documents when no actual queries are given.""" + assert document_ids_are_correct( + http_get("queries").json(), + [str(doc["_id"]) for doc in TestDatabase.get_all_documents_from_database()] + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +@pytest.mark.parametrize(("key", "values"), [ + ("platform", TestDatabase.unique_platform_names), + ("sensor", TestDatabase.unique_sensors) +]) +def test_queries_platform_or_sensor(key: str, values: list[str]): + """Tests the platform and sensor queries, one at a time. + + There is only a single key in the query, but it has multiple corresponding values. 
+ """ + for i in range(len(values)): + assert query_results_are_correct( + [key], + [values[:i]] + ) + + +def make_query_string(keys: list[str], values_list: list[list[str] | datetime]) -> str: + """Makes a single query string for all the given queries.""" + query_buffer = [] + for key, value_list in zip(keys, values_list, strict=True): + query_buffer += [f"{key}={value}" for value in value_list] + return "&".join(query_buffer) + + +def query_results_are_correct(keys: list[str], values_list: list[list[str] | datetime]) -> bool: + """Checks if the retrieved result from querying the database via the API matches the expected result. + + There can be more than one query `key/value` pair. + + Args: + keys: + A list of all query keys, e.g. ``keys=["platform", "sensor"]`` + + values_list: + A list in which each element is a list of values itself. The `nth` element corresponds to the `nth` key in + the ``keys``. + + Returns: + A boolean flag indicating whether the retrieved result matches the expected result. 
+ """ + query_string = make_query_string(keys, values_list) + + return ( + Counter(http_get(f"queries?{query_string}").json()) == + Counter(TestDatabase.match_query( + **{label: value_list for label, value_list in zip(keys, values_list, strict=True)} + )) + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_mix_platform_sensor(): + """Tests a mix of platform and sensor queries.""" + for n_plt, n_sns in zip([1, 1, 2, 3, 3], [1, 3, 2, 1, 3], strict=False): + assert query_results_are_correct( + ["platform", "sensor"], + [TestDatabase.unique_platform_names[:n_plt], TestDatabase.unique_sensors[:n_sns]] + ) + + +@pytest.mark.usefixtures("_test_server_fixture") +def test_queries_time(): + """Checks that a single time query works properly.""" + res = http_get("datetime").json() + time_min = datetime.fromisoformat(res["start_time"]["_min"]["_time"]) + time_max = datetime.fromisoformat(res["end_time"]["_max"]["_time"]) + + assert single_query_is_correct( + "time_min", + time_min + ) + + assert single_query_is_correct( + "time_max", + time_max + ) + + +def single_query_is_correct(key: str, value: str | datetime) -> bool: + """Checks if the given single query, denoted by ``key`` matches correctly against the ``value``.""" + return ( + Counter(http_get(f"queries?{key}={value}").json()) == + Counter(TestDatabase.match_query(**{key: value})) + ) diff --git a/trolldb/tests/tests_database/test_pipelines.py b/trolldb/tests/tests_database/test_pipelines.py index b993aff..df83b01 100644 --- a/trolldb/tests/tests_database/test_pipelines.py +++ b/trolldb/tests/tests_database/test_pipelines.py @@ -1,5 +1,5 @@ """Tests for the pipelines and applying comparison operations on them.""" -from trolldb.database.piplines import PipelineAttribute, PipelineBooleanDict, Pipelines +from trolldb.database.pipelines import PipelineAttribute, PipelineBooleanDict, Pipelines from trolldb.test_utils.common import compare_by_operator_name