diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0cda24d5..ae833749 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,18 +6,18 @@ jobs: publish-pypi: runs-on: ubuntu-latest steps: - - uses: actions/checkout@master - - name: Set up Python 3.8 - uses: actions/setup-python@v2.2.1 + - uses: actions/checkout@v4 + - name: Set up Python 3.13 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.13 - name: Install dependencies run: pip install -qU setuptools wheel twine - name: Generating distribution archives run: python setup.py sdist bdist_wheel - name: Publish distribution 馃摝 to PyPI if: startsWith(github.event.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6aff0262..c56ca2c7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,11 +6,11 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2.2.1 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.13 - name: Install dependencies run: make install-test - name: Lint @@ -20,11 +20,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2.2.1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -35,19 +35,20 @@ jobs: coverage: runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v2.2.1 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.13 - name: Install dependencies run: make install-test - name: Generate coverage report run: pytest --cov-report=xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v5 with: file: ./coverage.xml flags: unittests name: codecov-umbrella + token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true diff --git a/Makefile b/Makefile index 22a49b21..7daf143b 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ SHELL := bash PATH := ./venv/bin:${PATH} -PYTHON = python3.8 +PYTHON = python3.13 PROJECT = agave isort = isort $(PROJECT) examples tests setup.py -black = black -S -l 79 --target-version py38 $(PROJECT) examples $(PROJECT)/lib/* tests setup.py +black = black -S -l 79 --target-version py313 $(PROJECT) tests setup.py examples .PHONY: all diff --git a/README.md b/README.md index 62fcfdc1..fcd02094 100644 --- a/README.md +++ b/README.md @@ -3,31 +3,136 @@ [![codecov](https://codecov.io/gh/cuenca-mx/agave/branch/main/graph/badge.svg)](https://codecov.io/gh/cuenca-mx/agave) [![PyPI](https://img.shields.io/pypi/v/agave.svg)](https://pypi.org/project/agave/) -Agave is a library that implement rest_api across the use of Blueprints based on Chalice Aws. +Agave is a library for building REST APIs using a Blueprint pattern, with support for both AWS Chalice and FastAPI frameworks. It simplifies the creation of JSON-based endpoints for querying, modifying, and creating resources. -this library allow send and receive JSON data to these endpoints to query, modify and create content. +## Installation -Install agave using pip: +Choose the installation option based on your framework: +### Chalice Installation + +```bash +pip install agave[chalice] +``` + +### FastAPI Installation + +```bash +pip install agave[fastapi] +``` + +### SQS task support: ```bash -pip install agave==0.0.2.dev0 +pip install agave[fastapi,tasks] ``` -You can use agave for blueprint like this: +## Usage + +### Chalice Example + +You can then create a REST API blueprint as follows: ```python +from agave.chalice import RestApiBlueprint + +app = RestApiBlueprint() + +@app.resource('/accounts') +class Account: + model = AccountModel + query_validator = AccountQuery + update_validator = AccountUpdateRequest + get_query_filter = generic_query -from agave.blueprints.rest_api import RestApiBlueprint + @staticmethod + @app.validate(AccountRequest) + def create(request: AccountRequest) -> Response: + account = AccountModel( + name=request.name, + user_id=app.current_user_id, + platform_id=app.current_platform_id, + ) + account.save() + return Response(account.to_dict(), status_code=201) + @staticmethod + def update( + account: AccountModel, request: AccountUpdateRequest + ) -> Response: + account.name = request.name + account.save() + return Response(account.to_dict(), status_code=200) + + @staticmethod + def delete(account: AccountModel) -> Response: + account.deactivated_at = dt.datetime.utcnow().replace(microsecond=0) + account.save() + return Response(account.to_dict(), status_code=200) ``` -agave include helpers for mongoengine, for example: +### FastAPI Example + ```python +from agave.fastapi import RestApiBlueprint + +app = RestApiBlueprint() + +@app.resource('/accounts') +class Account: + model = AccountModel + query_validator = AccountQuery + update_validator = AccountUpdateRequest + get_query_filter = generic_query + response_model = AccountResponse + + @staticmethod + async def create(request: AccountRequest) -> Response: + """This is the description for openapi""" + account = AccountModel( + name=request.name, + user_id=app.current_user_id, + platform_id=app.current_platform_id, + ) + await account.async_save() + return Response(content=account.to_dict(), status_code=201) + + @staticmethod + async def update( + account: AccountModel, + request: AccountUpdateRequest, + ) -> Response: + account.name = request.name + await account.async_save() + return Response(content=account.to_dict(), status_code=200) -from agave.models.helpers import (uuid_field, mongo_to_dict, EnumField, updated_at, list_field_to_dict) + @staticmethod + async def delete(account: AccountModel, _: Request) -> Response: + account.deactivated_at = dt.datetime.utcnow().replace(microsecond=0) + await account.async_save() + return Response(content=account.to_dict(), status_code=200) +``` + +### Async Tasks + +```python +from agave.tasks.sqs_tasks import task +QUEUE_URL = 'https://sqs.region.amazonaws.com/account/queue' +AWS_DEFAULT_REGION = 'us-east-1' +@task( + queue_url=QUEUE_URL, + region_name=AWS_DEFAULT_REGION, + visibility_timeout=30, + max_retries=10, +) +async def process_data(data: dict): + # Async task processing + return {'processed': data} ``` -Correr tests +## Running Tests + +Run the tests using the following command: + ```bash make test ``` diff --git a/agave/chalice/__init__.py b/agave/chalice/__init__.py new file mode 100644 index 00000000..ff6d0c91 --- /dev/null +++ b/agave/chalice/__init__.py @@ -0,0 +1,2 @@ +__all__ = ['RestApiBlueprint'] +from .rest_api import RestApiBlueprint diff --git a/agave/lib/__init__.py b/agave/chalice/models/__init__.py similarity index 100% rename from agave/lib/__init__.py rename to agave/chalice/models/__init__.py diff --git a/agave/models/helpers.py b/agave/chalice/models/helpers.py similarity index 78% rename from agave/models/helpers.py rename to agave/chalice/models/helpers.py index 8863730e..4f0192c9 100644 --- a/agave/models/helpers.py +++ b/agave/chalice/models/helpers.py @@ -2,13 +2,6 @@ from base64 import urlsafe_b64encode -def uuid_field(prefix: str = ''): - def base64_uuid_func() -> str: - return prefix + urlsafe_b64encode(uuid.uuid4().bytes).decode()[:-2] - - return base64_uuid_func - - # This function is used to generate an id composed of a # list of fields in alphabetical order, for example if we want # uuid_field_generic('AC', account_number='bla', user_id='ble') diff --git a/agave/blueprints/rest_api.py b/agave/chalice/rest_api.py similarity index 95% rename from agave/blueprints/rest_api.py rename to agave/chalice/rest_api.py index 688a4819..fdcbd285 100644 --- a/agave/blueprints/rest_api.py +++ b/agave/chalice/rest_api.py @@ -2,12 +2,19 @@ from typing import Any, Optional, Type, cast from urllib.parse import urlencode -from chalice import Blueprint, NotFoundError, Response +try: + from chalice import Blueprint, NotFoundError, Response +except ImportError: + raise ImportError( + "You must install agave with [chalice] option.\n" + "You can install it with: pip install agave[chalice]" + ) + from cuenca_validations.types import QueryParams from mongoengine import DoesNotExist, Q from pydantic import BaseModel, ValidationError -from .decorators import copy_attributes +from ..core.blueprints.decorators import copy_attributes class RestApiBlueprint(Blueprint): @@ -71,8 +78,9 @@ def retrieve_object(self, resource_class: Any, resource_id: str) -> Any: return data def validate(self, validation_type: Type[BaseModel]): - """This decorator validate the request body using a custom pydantyc model - If validation fails return a BadRequest response with details + """This decorator validate the request body using a + custom pydantyc model. If validation fails return a + BadRequest response with details @app.validate(MyPydanticModel) def my_method(request: MyPydanticModel): @@ -284,7 +292,7 @@ def _all(query: QueryParams, filters: Q): if wants_more and has_more: query.created_before = item_dicts[-1]['created_at'] path = self.current_request.context['resourcePath'] - params = query.dict() + params = query.model_dump() if self.user_id_filter_required(): params.pop('user_id') if self.platform_id_filter_required(): diff --git a/agave/lib/mongoengine/__init__.py b/agave/core/__init__.py similarity index 100% rename from agave/lib/mongoengine/__init__.py rename to agave/core/__init__.py diff --git a/examples/chalicelib/__init__.py b/agave/core/blueprints/__init__.py similarity index 100% rename from examples/chalicelib/__init__.py rename to agave/core/blueprints/__init__.py diff --git a/agave/blueprints/decorators.py b/agave/core/blueprints/decorators.py similarity index 87% rename from agave/blueprints/decorators.py rename to agave/core/blueprints/decorators.py index a862a06c..8a214dc9 100644 --- a/agave/blueprints/decorators.py +++ b/agave/core/blueprints/decorators.py @@ -16,7 +16,8 @@ def wrapper(func: Callable): return func for key, val in original_func.__dict__.items(): - setattr(func, key, val) + if not key.startswith('_'): + setattr(func, key, val) return func diff --git a/agave/core/exc.py b/agave/core/exc.py new file mode 100644 index 00000000..76dd511e --- /dev/null +++ b/agave/core/exc.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class AgaveError(Exception): + error: str + status_code: int + + +@dataclass +class BadRequestError(AgaveError): + status_code: int = 400 + + +@dataclass +class UnauthorizedError(AgaveError): + status_code: int = 401 + + +@dataclass +class ForbiddenError(AgaveError): + status_code: int = 403 + + +@dataclass +class NotFoundError(AgaveError): + status_code: int = 404 + + +@dataclass +class MethodNotAllowedError(AgaveError): + status_code: int = 405 + + +@dataclass +class ConflictError(AgaveError): + status_code: int = 409 + + +@dataclass +class UnprocessableEntity(AgaveError): + status_code: int = 422 + + +@dataclass +class TooManyRequests(AgaveError): + status_code: int = 429 + + +@dataclass +class AgaveViewError(AgaveError): + status_code: int = 500 + + +@dataclass +class ServiceUnavailableError(AgaveError): + status_code: int = 503 + + +@dataclass +class RetryTask(Exception): + countdown: Optional[int] = None diff --git a/agave/filters.py b/agave/core/filters.py similarity index 77% rename from agave/filters.py rename to agave/core/filters.py index b44978ed..f2d10669 100644 --- a/agave/filters.py +++ b/agave/core/filters.py @@ -2,7 +2,7 @@ from mongoengine import Q -def generic_query(query: QueryParams) -> Q: +def generic_query(query: QueryParams, excluded: list[str] = []) -> Q: filters = Q() if query.created_before: filters &= Q(created_at__lt=query.created_before) @@ -15,8 +15,9 @@ def generic_query(query: QueryParams) -> Q: 'limit', 'page_size', 'key', + *excluded, } - fields = query.dict(exclude=exclude_fields) + fields = query.model_dump(exclude=exclude_fields) if 'count' in fields: del fields['count'] return filters & Q(**fields) diff --git a/agave/blueprints/__init__.py b/agave/fastapi/__init__.py similarity index 100% rename from agave/blueprints/__init__.py rename to agave/fastapi/__init__.py index 96c4def2..7a8611d9 100644 --- a/agave/blueprints/__init__.py +++ b/agave/fastapi/__init__.py @@ -1,3 +1,3 @@ -__all__ = ['RestApiBlueprint'] - from .rest_api import RestApiBlueprint + +__all__ = ['RestApiBlueprint'] diff --git a/agave/fastapi/middlewares/__init__.py b/agave/fastapi/middlewares/__init__.py new file mode 100644 index 00000000..488eeb64 --- /dev/null +++ b/agave/fastapi/middlewares/__init__.py @@ -0,0 +1,5 @@ +from .error_handlers import AgaveErrorHandler + +__all__ = [ + 'AgaveErrorHandler', +] diff --git a/agave/fastapi/middlewares/error_handlers.py b/agave/fastapi/middlewares/error_handlers.py new file mode 100644 index 00000000..dc035443 --- /dev/null +++ b/agave/fastapi/middlewares/error_handlers.py @@ -0,0 +1,59 @@ +from cuenca_validations.errors import CuencaError +from fastapi import Request, Response +from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute +from starlette.middleware.base import ( + BaseHTTPMiddleware, + RequestResponseEndpoint, +) +from starlette.routing import Match + +from ...core.exc import AgaveError, MethodNotAllowedError, NotFoundError + + +class AgaveErrorHandler(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + try: + request.scope['route_handler'] = get_current_route_handler(request) + return await call_next(request) + except CuencaError as exc: + return JSONResponse( + status_code=exc.status_code, + content=dict( + code=exc.code, + error=str(exc), + ), + ) + except AgaveError as exc: + return JSONResponse( + status_code=exc.status_code, content=dict(error=exc.error) + ) + + +def get_current_route_handler(request: Request) -> APIRoute: + """ + Helper method for getting the route handler of the current request. + + If there is not route handler it raises appropriate status code error + consistent with the `Route.__call__` behavior + https://github.com/encode/starlette/blob/5d768322d6d7adc31df54b1ad306f417e3da2c81/starlette/routing.py#L656-L666 + Args: + request: fastapi request object + + Returns: + APIRoute instance for the current request + """ + partial = None + for route in request.app.routes: + match, _ = route.matches(request.scope) + if match == Match.FULL: + return route + if match == Match.PARTIAL and partial is None: + partial = route + + if partial is not None: + raise MethodNotAllowedError('Method Not Allowed') + else: + raise NotFoundError('Not Found') diff --git a/agave/fastapi/rest_api.py b/agave/fastapi/rest_api.py new file mode 100644 index 00000000..3cbc2df8 --- /dev/null +++ b/agave/fastapi/rest_api.py @@ -0,0 +1,417 @@ +import mimetypes +from typing import Any, Optional +from urllib.parse import urlencode + +from cuenca_validations.types import QueryParams + +try: + from fastapi import APIRouter, BackgroundTasks, Depends, Request, status +except ImportError: + raise ImportError( + "You must install agave with [fastapi] option.\n" + "You can install it with: pip install agave[fastapi]" + ) + + +from fastapi.responses import JSONResponse as Response, StreamingResponse +from mongoengine import DoesNotExist, Q +from pydantic import BaseModel, Field, ValidationError +from starlette_context import context + +from ..core.blueprints.decorators import copy_attributes +from ..core.exc import NotFoundError, UnprocessableEntity + +SAMPLE_404 = { + "summary": "Not found item", + "value": {"error": "Not valid id"}, +} + + +class RestApiBlueprint(APIRouter): + @property + def current_user_id(self) -> str: + return context['user_id'] + + @property + def current_platform_id(self) -> str: + return context['platform_id'] + + def user_id_filter_required(self) -> bool: + return context['user_id_filter_required'] + + def platform_id_filter_required(self) -> bool: + return context['platform_id_filter_required'] + + def custom_filter_required(self, query_params: Any, model: Any) -> None: + """ + Overwrite this method in order to add new context + based on custom filter. + set de value of your filter ex query_params.wallet = self.wallet + """ + pass + + async def retrieve_object( + self, resource_class: Any, resource_id: str + ) -> Any: + resource_id = ( + self.current_user_id if resource_id == 'me' else resource_id + ) + query = Q(id=resource_id) + if self.platform_id_filter_required() and hasattr( + resource_class.model, 'platform_id' + ): + query = query & Q(platform_id=self.current_platform_id) + + if self.user_id_filter_required() and hasattr( + resource_class.model, 'user_id' + ): + query = query & Q(user_id=self.current_user_id) + + try: + data = await resource_class.model.objects.async_get(query) + except DoesNotExist: + raise NotFoundError('Not valid id') + return data + + def resource(self, path: str): + """Decorator to transform a class in FastApi REST endpoints + + @app.resource('/my_resource') + class Items(Resource): + model = MyMongoModel + response_model = MyPydanticModel (Resource Interface) + query_validator = MyPydanticModel + + def create(): ... + def delete(id): ... + def retrieve(id): ... + def get_query_filter(): ... + + This implementation create the following endpoints + + POST /my_resource + PATCH /my_resource + DELETE /my_resource/id + GET /my_resource/id + GET /my_resource + """ + + def wrapper_resource_class(cls): + """Wrapper for resource class + :param cls: Resoucre class + :return: + """ + response_model = Any + response_sample = {} + include_in_schema = getattr(cls, 'include_in_schema', True) + if hasattr(cls, 'response_model'): + response_model = cls.response_model + response_sample = response_model.schema().get('example') + + """ POST /resource + Create a FastApi endpoint using the method "create" + + OR using the method "upload" to enable POST using a + streaming multipart parser to receive files as form data. It + validates form data using `Resource.upload_validator`. + """ + if hasattr(cls, 'create'): + route = self.post( + path, + summary=f'{cls.__name__} - Create', + response_model=response_model, + status_code=status.HTTP_201_CREATED, + include_in_schema=include_in_schema, + ) + route(cls.create) + elif hasattr(cls, 'upload'): + + @self.post( + path, + summary=f'{cls.__name__} - Upload', + response_model=response_model, + include_in_schema=include_in_schema, + openapi_extra={ + "requestBody": { + "content": { + "form-data": { + "schema": cls.upload_validator.schema() + } + } + } + }, + ) + @copy_attributes(cls) + async def upload( + request: Request, background_tasks: BackgroundTasks + ): + form = await request.form() + try: + upload_params = cls.upload_validator(**form) + except ValidationError as exc: + return Response(content=exc.json(), status_code=400) + + return await cls.upload(upload_params, background_tasks) + + """ DELETE /resource/{id} + Use "delete" method (if exists) to create the FastApi endpoint + """ + error_404 = json_openapi(404, 'Item not found', [SAMPLE_404]) + if hasattr(cls, 'delete'): + + @self.delete( + path + '/{id}', + summary=f'{cls.__name__} - Delete', + response_model=response_model, + responses=error_404, + description=( + f'Use id param to delete the {cls.__name__} object' + ), + include_in_schema=include_in_schema, + ) + @copy_attributes(cls) + async def delete(id: str, request: Request): + obj = await self.retrieve_object(cls, id) + return await cls.delete(obj, request) + + """ PATCH /resource/{id} + Enable PATCH method if Resource.update method exist. It validates + body data using `Resource.update_validator` but update logic is + completely your responsibility. + """ + if hasattr(cls, 'update'): + + @self.patch( + path + '/{id}', + summary=f'{cls.__name__} - Update', + response_model=response_model, + responses=error_404, + description=( + f'Use id param to update the {cls.__name__} object' + ), + include_in_schema=include_in_schema, + ) + @copy_attributes(cls) + async def update( + id: str, + update_params: cls.update_validator, # type: ignore + request: Request, + ): + obj = await self.retrieve_object(cls, id) + try: + return await cls.update(obj, update_params, request) + except TypeError: + return await cls.update(obj, update_params) + + """ GET /resource/{id} + By default GET method only fetch object from DB. + If you need extra logic override "retrieve" or "download" methods + """ + + @self.get( + path + '/{id}', + summary=f'{cls.__name__} - Retrieve', + response_model=response_model, + responses=error_404, + description=( + f'Use id param to retrieve the {cls.__name__} object' + ), + include_in_schema=include_in_schema, + ) + @copy_attributes(cls) + async def retrieve(id: str, request: Request): + """GET /resource/{id} + :param id: Object Id + :return: Model object + + If exists "retrieve" method return the result of that, else + use "id" param to retrieve the object of type "model" defined + in the decorated class. + + The most of times this implementation is enough and is not + necessary define a custom "retrieve" method + """ + obj = await self.retrieve_object(cls, id) + + # This case is when the return is not an application/$ + # but can be some type of file such as image, xml, zip or pdf + if hasattr(cls, 'download'): + file = await cls.download(obj) + mimetype = request.headers['accept'] + extension = mimetypes.guess_extension(mimetype) + filename = f'{cls.model._class_name}.{extension}' + result = StreamingResponse( + file, + media_type=mimetype, + headers={ + 'Content-Disposition': ( + 'attachment; ' f'filename={filename}' + ) + }, + ) + elif hasattr(cls, 'retrieve'): + result = await cls.retrieve(obj) + else: + result = obj.to_dict() + + return result + + """ GET /resource?param=value + Use GET method to fetch and count filtered objects + using query params. + To Enable queries you have to define next fields + in decorated class + + query_validator: Pydantic model to validate the params. + get_query_filter: Method to provide the way that + the params are used to filter data. + """ + + if not hasattr(cls, 'query_validator') or not hasattr( + cls, 'get_query_filter' + ): + return cls + + query_description = ( + f'Make queries in resource {cls.__name__} and filter the ' + f'result using query parameters. \n' + f'The items are paginated, to iterate over them use the ' + f'`next_page_uri` included in response. \n' # noqa: W604 + f'If you need only a counter not the data send value `true` ' + f'in `count` param.' + ) + + # Build dynamically types for query response + class QueryResponse(BaseModel): + items: Optional[list[response_model]] = Field( + [], + description=( + f'List of {cls.__name__} that match with query ' + f'filters' + ), + ) + next_page_uri: Optional[str] = Field( + None, description='URL to fetch the next page of results' + ) + count: Optional[int] = Field( + None, + description=( + f'Counter of {cls.__name__} objects that match with ' + f'query filters. \n' + f'If you need only a counter not the data send value ' + f'`true` in `count` param.' # noqa: W604 + ), + ) + + QueryResponse.__name__ = f'QueryResponse{cls.__name__}' + + examples = [ + # If param "count" is False return the list of items + { + 'summary': 'Query objects', + 'value': { + 'items': [response_sample], + 'next_page_uri': f'{path}?param1=value1¶m2=value2', + }, + }, + # If param "count" is True return a counter + { + 'summary': 'Count objects', + 'description': 'Sending `true` value in `count` param', + 'value': {'count': 1}, + }, + ] + + def validate_params(request: Request): + try: + return cls.query_validator(**request.query_params) + except ValidationError as e: + raise UnprocessableEntity(e.json()) + + @self.get( + path, + summary=f'{cls.__name__} - Query', + response_model=QueryResponse, + description=query_description, + responses=json_openapi(200, 'Successful Response', examples), + include_in_schema=include_in_schema, + ) + @copy_attributes(cls) + async def query( + query_params: cls.query_validator = Depends( # type: ignore + validate_params + ), + ): + """GET /resource""" + if self.platform_id_filter_required() and hasattr( + cls.model, 'platform_id' + ): + query_params.platform_id = self.current_platform_id + + if self.user_id_filter_required() and hasattr( + cls.model, 'user_id' + ): + query_params.user_id = self.current_user_id + # Call for custom filter implemented in overwritemethod + self.custom_filter_required(query_params, cls.model) + + filters = cls.get_query_filter(query_params) + if query_params.count: + result = await _count(filters) + elif hasattr(cls, 'query'): + result = await cls.query( + await _all(query_params, filters, path) + ) + else: + result = await _all(query_params, filters, path) + return result + + async def _count(filters: Q): + count = await cls.model.objects.filter(filters).async_count() + return dict(count=count) + + async def _all(query: QueryParams, filters: Q, resource_path: str): + if query.limit: + limit = min(query.limit, query.page_size) + query.limit = max(0, query.limit - limit) + else: + limit = query.page_size + query_set = ( + cls.model.objects.order_by("-created_at") + .filter(filters) + .limit(limit) + ) + items = await query_set.async_to_list() + item_dicts = [i.to_dict() for i in items] + + has_more: Optional[bool] = None + if wants_more := query.limit is None or query.limit > 0: + # only perform this query if it's necessary + has_more = ( + await query_set.limit(limit + 1).async_count() > limit + ) + + next_page_uri: Optional[str] = None + if wants_more and has_more: + query.created_before = item_dicts[-1]['created_at'] + params = query.model_dump() + if self.user_id_filter_required(): + params.pop('user_id') + if self.platform_id_filter_required(): + params.pop('platform_id') + next_page_uri = f'{resource_path}?{urlencode(params)}' + return dict(items=item_dicts, next_page_uri=next_page_uri) + + return cls + + return wrapper_resource_class + + +def json_openapi(code: int, description, samples: list[dict]) -> dict: + examples = {f'example_{i}': ex for i, ex in enumerate(samples)} + return { + code: { + 'description': description, + 'content': {'application/json': {'examples': examples}}, + }, + } diff --git a/agave/lib/mongoengine/enum_field.py b/agave/lib/mongoengine/enum_field.py deleted file mode 100644 index 84374f29..00000000 --- a/agave/lib/mongoengine/enum_field.py +++ /dev/null @@ -1,41 +0,0 @@ -from enum import Enum -from typing import Type - -from mongoengine.base import BaseField - - -class EnumField(BaseField): - """ - https://github.com/MongoEngine/extras-mongoengine/blob/master/ - extras_mongoengine/fields.py - A class to register Enum type (from the package enum34) into mongo - :param choices: must be of :class:`enum.Enum`: type - and will be used as possible choices - """ - - def __init__(self, enum: Type[Enum], *args, **kwargs): - self.enum = enum - kwargs['choices'] = [choice for choice in enum] - super(EnumField, self).__init__(*args, **kwargs) - - def __get_value(self, enum: Enum) -> str: - return enum.value if hasattr(enum, 'value') else enum - - def to_python(self, value: Enum) -> Enum: # pragma: no cover - return self.enum(super(EnumField, self).to_python(value)) - - def to_mongo(self, value: Enum) -> str: - return self.__get_value(value) - - def prepare_query_value(self, op, value: Enum) -> str: - return super(EnumField, self).prepare_query_value( # pragma: no cover - op, self.__get_value(value) - ) - - def validate(self, value: Enum) -> Enum: - return super(EnumField, self).validate(self.__get_value(value)) - - def _validate(self, value: Enum, **kwargs) -> Enum: - return super(EnumField, self)._validate( - self.enum(self.__get_value(value)), **kwargs - ) diff --git a/agave/lib/mongoengine/event_handlers.py b/agave/lib/mongoengine/event_handlers.py deleted file mode 100644 index 83084c41..00000000 --- a/agave/lib/mongoengine/event_handlers.py +++ /dev/null @@ -1,28 +0,0 @@ -import datetime as dt -from typing import Any - -from blinker import NamedSignal -from mongoengine import signals - - -def handler(event: NamedSignal): - """ - http://docs.mongoengine.org/guide/signals.html?highlight=update - Signal decorator to allow use of callback functions as class - decorators - """ - - def decorator(fn: Any): - def apply(cls): - event.connect(fn, sender=cls) - return cls - - fn.apply = apply - return fn - - return decorator - - -@handler(signals.pre_save) -def updated_at(_, document): - document.updated_at = dt.datetime.utcnow() diff --git a/agave/lib/mongoengine/model_helpers.py b/agave/lib/mongoengine/model_helpers.py deleted file mode 100644 index ca3cd27f..00000000 --- a/agave/lib/mongoengine/model_helpers.py +++ /dev/null @@ -1,116 +0,0 @@ -# mypy: ignore-errors -from enum import Enum - -from bson import DBRef -from mongoengine import ( - BooleanField, - ComplexDateTimeField, - DateTimeField, - DecimalField, - DictField, - Document, - EmbeddedDocument, - EmbeddedDocumentField, - FloatField, - GenericLazyReferenceField, - IntField, - LazyReferenceField, - ListField, -) - -from .enum_field import EnumField - - -def mongo_to_dict(obj, exclude_fields: list = None) -> dict: - """ - from: https://gist.github.com/jason-w/4969476 - """ - return_data = {} - - if obj is None: - return return_data - - if isinstance(obj, Document): - return_data['id'] = str(obj.id) - - if exclude_fields is None: - exclude_fields = [] - - for field_name in obj._fields: - - if field_name in exclude_fields: - continue - - if field_name == 'id': - continue - data = obj._data[field_name] - if isinstance(obj._fields[field_name], ListField): - field_name = ( - f'{field_name}_uris' - if isinstance( - obj._fields[field_name].field, LazyReferenceField - ) - else field_name - ) - return_data[field_name] = list_field_to_dict(data) - elif isinstance(obj._fields[field_name], EmbeddedDocumentField): - return_data[field_name] = mongo_to_dict(data, []) - elif isinstance(obj._fields[field_name], DictField): - return_data[field_name] = data - elif isinstance(obj._fields[field_name], EnumField): - return_data[field_name] = data.value if data else None - elif isinstance(obj._fields[field_name], LazyReferenceField): - return_data[f'{field_name}_uri'] = ( - f'/{data._DBRef__collection}/{data.id}' if data else None - ) - elif isinstance(obj._fields[field_name], GenericLazyReferenceField): - return_data[f'{field_name}_uri'] = ( - f'/{data["_ref"]._DBRef__collection}/{data["_ref"].id}' - if data - else None - ) - else: - return_data[field_name] = mongo_to_python_type( - obj._fields[field_name], data - ) - - return return_data - - -def list_field_to_dict(list_field: list) -> list: - return_data = [] - - for item in list_field: - if isinstance(item, EmbeddedDocument): - return_data.append(mongo_to_dict(item)) - elif isinstance(item, Enum): - return_data.append(item.value) - elif isinstance(item, DBRef): # pragma: no cover - return_data.append(f'/{item._DBRef__collection}/{item.id}') - else: - return_data.append(mongo_to_python_type(item, item)) - - return return_data - - -def mongo_to_python_type(field, data): - rv = None - field_type = type(field) - if data is None: - rv = None - elif field_type is DateTimeField: - rv = data.isoformat() - elif field_type is ComplexDateTimeField: - rv = field.to_python(data).isoformat() - elif rv is FloatField: # pragma: no cover - rv = float(data) - elif field_type is IntField: - rv = int(data) - elif field_type is BooleanField: - rv = bool(data) - elif field_type is DecimalField: - rv = data - else: - rv = str(data) - - return rv diff --git a/agave/models/__init__.py b/agave/models/__init__.py deleted file mode 100644 index 5f4f48a7..00000000 --- a/agave/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -__all__ = ['BaseModel'] - -from .base import BaseModel diff --git a/agave/models/base.py b/agave/models/base.py deleted file mode 100644 index 4a370ba3..00000000 --- a/agave/models/base.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import ClassVar, Dict - -from ..lib.mongoengine.model_helpers import mongo_to_dict - - -class BaseModel: - _excluded: ClassVar = [] - _hidden: ClassVar = [] - - def __init__(self, *args, **values): - return super().__init__(*args, **values) - - def to_dict(self) -> Dict: - private_fields = [f for f in dir(self) if f.startswith('_')] - excluded = self._excluded + private_fields - mongo_dict: dict = mongo_to_dict(self, excluded) - - for field in self._hidden: - mongo_dict[field] = '********' - return mongo_dict - - def __repr__(self) -> str: - return str(self.to_dict()) # pragma: no cover diff --git a/tests/lib/__init__.py b/agave/tasks/__init__.py similarity index 100% rename from tests/lib/__init__.py rename to agave/tasks/__init__.py diff --git a/agave/tasks/sqs_celery_client.py b/agave/tasks/sqs_celery_client.py new file mode 100644 index 00000000..33f02bb5 --- /dev/null +++ b/agave/tasks/sqs_celery_client.py @@ -0,0 +1,75 @@ +import asyncio +import json +from base64 import b64encode +from dataclasses import dataclass +from typing import Iterable, Optional +from uuid import uuid4 + +from agave.tasks.sqs_client import SqsClient + + +def _build_celery_message( + task_name: str, args_: Iterable, kwargs_: dict +) -> str: + task_id = str(uuid4()) + # la definici贸n de esta plantila se encuentra en: + # docs.celeryproject.org/en/stable/internals/protocol.html#definition + message = dict( + properties=dict( + correlation_id=task_id, + content_type='application/json', + content_encoding='utf-8', + body_encoding='base64', + delivery_info=dict(exchange='', routing_key='celery'), + ), + headers=dict( + lang='py', + task=task_name, + id=task_id, + root_id=task_id, + parent_id=None, + group=None, + ), + body=_b64_encode( + json.dumps( + ( + args_, + kwargs_, + dict( + callbacks=None, errbacks=None, chain=None, chord=None + ), + ) + ) + ), + ) + message['content-encoding'] = 'utf-8' + message['content-type'] = 'application/json' + + encoded = _b64_encode(json.dumps(message)) + return encoded + + +def _b64_encode(value: str) -> str: + encoded = b64encode(bytes(value, 'utf-8')) + return encoded.decode('utf-8') + + +@dataclass +class SqsCeleryClient(SqsClient): + async def send_task( + self, + name: str, + args: Optional[Iterable] = None, + kwargs: Optional[dict] = None, + ) -> None: + celery_message = _build_celery_message(name, args or (), kwargs or {}) + await super().send_message(celery_message) + + def send_background_task( + self, + name: str, + args: Optional[Iterable] = None, + kwargs: Optional[dict] = None, + ) -> asyncio.Task: + celery_message = _build_celery_message(name, args or (), kwargs or {}) + return super().send_message_async(celery_message) diff --git a/agave/tasks/sqs_client.py b/agave/tasks/sqs_client.py new file mode 100644 index 00000000..cb20d2fe --- /dev/null +++ b/agave/tasks/sqs_client.py @@ -0,0 +1,63 @@ +import asyncio +import json +from dataclasses import dataclass, field +from typing import Optional, Union +from uuid import uuid4 + +try: + from aiobotocore.session import get_session + from types_aiobotocore_sqs import SQSClient +except ImportError: + raise ImportError( + "You must install agave with [fastapi, tasks] option.\n" + "You can install it with: pip install agave[fastapi, tasks]" + ) + + +@dataclass +class SqsClient: + queue_url: str + region_name: str + _sqs: SQSClient = field(init=False) + _background_tasks: set = field(init=False) + + @property + def background_tasks(self) -> set: + return self._background_tasks + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def start(self): + session = get_session() + context = session.create_client('sqs', self.region_name) + self._background_tasks = set() + self._sqs = await context.__aenter__() + + async def close(self): + await self._sqs.__aexit__(None, None, None) + + async def send_message( + self, + data: Union[str, dict], + message_group_id: Optional[str] = None, + ) -> None: + await self._sqs.send_message( + QueueUrl=self.queue_url, + MessageBody=data if type(data) is str else json.dumps(data), + MessageGroupId=message_group_id or str(uuid4()), + ) + + def send_message_async( + self, + data: Union[str, dict], + message_group_id: Optional[str] = None, + ) -> asyncio.Task: + task = asyncio.create_task(self.send_message(data, message_group_id)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task diff --git a/agave/tasks/sqs_tasks.py b/agave/tasks/sqs_tasks.py new file mode 100644 index 00000000..103e4a1d --- /dev/null +++ b/agave/tasks/sqs_tasks.py @@ -0,0 +1,153 @@ +import asyncio +import json +import os +from functools import wraps +from itertools import count +from json import JSONDecodeError +from typing import AsyncGenerator, Callable, Coroutine + +from aiobotocore.httpsession import HTTPClientError +from aiobotocore.session import get_session +from pydantic import validate_call + +from ..core.exc import RetryTask + +AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '') + +BACKGROUND_TASKS = set() + + +async def run_task( + task_func: Callable, + body: dict, + sqs, + queue_url: str, + receipt_handle: str, + message_receive_count: int, + max_retries: int, +) -> None: + delete_message = True + try: + await task_func(body) + except RetryTask as retry: + delete_message = message_receive_count >= max_retries + 1 + if not delete_message and retry.countdown and retry.countdown > 0: + await sqs.change_message_visibility( + QueueUrl=queue_url, + ReceiptHandle=receipt_handle, + VisibilityTimeout=retry.countdown, + ) + finally: + if delete_message: + await sqs.delete_message( + QueueUrl=queue_url, + ReceiptHandle=receipt_handle, + ) + + +async def message_consumer( + queue_url: str, + wait_time_seconds: int, + visibility_timeout: int, + can_read: asyncio.Event, + sqs, +) -> AsyncGenerator: + for _ in count(): + await can_read.wait() + try: + response = await sqs.receive_message( + QueueUrl=queue_url, + WaitTimeSeconds=wait_time_seconds, + VisibilityTimeout=visibility_timeout, + AttributeNames=['ApproximateReceiveCount'], + ) + messages = response['Messages'] + except KeyError: + continue + except HTTPClientError: + await asyncio.sleep(1) + continue + for message in messages: + yield message + + +async def get_running_fast_agave_tasks(): + return [ + t + for t in asyncio.all_tasks() + if t.get_name().startswith('fast-agave-task') + ] + + +def task( + queue_url: str, + region_name: str = AWS_DEFAULT_REGION, + wait_time_seconds: int = 15, + visibility_timeout: int = 3600, + max_retries: int = 1, + max_concurrent_tasks: int = 5, +): + def task_builder(task_func: Callable): + @wraps(task_func) + async def start_task(*args, **kwargs) -> None: + can_read = asyncio.Event() + concurrency_semaphore = asyncio.Semaphore(max_concurrent_tasks) + can_read.set() + + async def concurrency_controller(coro: Coroutine) -> None: + async with concurrency_semaphore: + if concurrency_semaphore.locked(): + can_read.clear() + + try: + await coro + finally: + can_read.set() + + session = get_session() + + task_with_validators = validate_call(task_func) + + async with session.create_client('sqs', region_name) as sqs: + async for message in message_consumer( + queue_url, + wait_time_seconds, + visibility_timeout, + can_read, + sqs, + ): + try: + body = json.loads(message['Body']) + except JSONDecodeError: + continue + + message_receive_count = int( + message['Attributes']['ApproximateReceiveCount'] + ) + bg_task = asyncio.create_task( + concurrency_controller( + run_task( + task_with_validators, + body, + sqs, + queue_url, + message['ReceiptHandle'], + message_receive_count, + max_retries, + ), + ), + name='fast-agave-task', + ) + BACKGROUND_TASKS.add(bg_task) + bg_task.add_done_callback(BACKGROUND_TASKS.discard) + + # Espera a que terminen todos los tasks pendientes creados por + # `asyncio.create_task`. De esta forma los tasks + # podr谩n borrar el mensaje del queue usando la misma instancia + # del cliente de SQS + running_tasks = await get_running_fast_agave_tasks() + await asyncio.gather(*running_tasks) + + return start_task + + return task_builder diff --git a/agave/version.py b/agave/version.py index 020ed73d..1f356cc5 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '0.2.2' +__version__ = '1.0.0' diff --git a/examples/__init__.py b/examples/chalice/__init__.py similarity index 100% rename from examples/__init__.py rename to examples/chalice/__init__.py diff --git a/examples/app.py b/examples/chalice/app.py similarity index 61% rename from examples/app.py rename to examples/chalice/app.py index e40cf3b7..f7866711 100644 --- a/examples/app.py +++ b/examples/chalice/app.py @@ -1,9 +1,10 @@ +import mongomock as mongomock from chalice import Chalice from mongoengine import connect -from .chalicelib.resources import app as resources +from .resources import app as resources -DATABASE_URI = 'mongomock://localhost:27017/db' +DATABASE_URI = 'mongodb://localhost:27017/db' app = Chalice(app_name='test_app') app.register_blueprint(resources) @@ -11,7 +12,10 @@ app.api.binary_types.append('application/pdf') app.api.binary_types.append('application/xml') -connect(host=DATABASE_URI) +connect( + host=DATABASE_URI, + mongo_client_class=mongomock.MongoClient, +) @app.route('/') diff --git a/examples/chalicelib/blueprints/__init__.py b/examples/chalice/blueprints/__init__.py similarity index 71% rename from examples/chalicelib/blueprints/__init__.py rename to examples/chalice/blueprints/__init__.py index c376a4a4..7aa0dde7 100644 --- a/examples/chalicelib/blueprints/__init__.py +++ b/examples/chalice/blueprints/__init__.py @@ -1,9 +1,8 @@ __all__ = ['AuthedRestApiBlueprint'] -from agave.blueprints import RestApiBlueprint +from agave.chalice import RestApiBlueprint from .authed import AuthedBlueprint -class AuthedRestApiBlueprint(AuthedBlueprint, RestApiBlueprint): - ... +class AuthedRestApiBlueprint(AuthedBlueprint, RestApiBlueprint): ... diff --git a/examples/chalicelib/blueprints/authed.py b/examples/chalice/blueprints/authed.py similarity index 100% rename from examples/chalicelib/blueprints/authed.py rename to examples/chalice/blueprints/authed.py diff --git a/examples/chalicelib/resources/__init__.py b/examples/chalice/resources/__init__.py similarity index 100% rename from examples/chalicelib/resources/__init__.py rename to examples/chalice/resources/__init__.py diff --git a/examples/chalicelib/resources/accounts.py b/examples/chalice/resources/accounts.py similarity index 86% rename from examples/chalicelib/resources/accounts.py rename to examples/chalice/resources/accounts.py index cdbb11c1..66a49d3a 100644 --- a/examples/chalicelib/resources/accounts.py +++ b/examples/chalice/resources/accounts.py @@ -2,10 +2,10 @@ from chalice import Response -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models import Account as AccountModel -from ..validators import AccountQuery, AccountRequest, AccountUpdateRequest +from ...models import Account as AccountModel +from ...validators import AccountQuery, AccountRequest, AccountUpdateRequest from .base import app diff --git a/examples/chalicelib/resources/base.py b/examples/chalice/resources/base.py similarity index 75% rename from examples/chalicelib/resources/base.py rename to examples/chalice/resources/base.py index 929b6f4d..482ab7d0 100644 --- a/examples/chalicelib/resources/base.py +++ b/examples/chalice/resources/base.py @@ -1,10 +1,8 @@ -from typing import Dict - from ..blueprints import AuthedRestApiBlueprint app = AuthedRestApiBlueprint(__name__) @app.get('/healthy_auth') -def health_auth_check() -> Dict: +def health_auth_check() -> dict: return dict(greeting="I'm authenticated and healthy !!!") diff --git a/examples/chalicelib/resources/billers.py b/examples/chalice/resources/billers.py similarity index 55% rename from examples/chalicelib/resources/billers.py rename to examples/chalice/resources/billers.py index 03713fe2..e5526496 100644 --- a/examples/chalicelib/resources/billers.py +++ b/examples/chalice/resources/billers.py @@ -1,7 +1,7 @@ -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models import Biller as BillerModel -from ..validators import BillerQuery +from ...models import Biller as BillerModel +from ...validators import BillerQuery from .base import app diff --git a/examples/chalicelib/resources/cards.py b/examples/chalice/resources/cards.py similarity index 73% rename from examples/chalicelib/resources/cards.py rename to examples/chalice/resources/cards.py index 3cc705fa..280c574a 100644 --- a/examples/chalicelib/resources/cards.py +++ b/examples/chalice/resources/cards.py @@ -1,11 +1,9 @@ -from typing import Dict - from chalice import Response -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models import Card as CardModel -from ..validators import CardQuery +from ...models import Card as CardModel +from ...validators import CardQuery from .base import app @@ -22,7 +20,7 @@ def retrieve(card: CardModel) -> Response: return Response(data) @staticmethod - def query(response: Dict): + def query(response: dict) -> dict: for item in response['items']: item['number'] = '*' * 16 return response diff --git a/examples/chalicelib/resources/files.py b/examples/chalice/resources/files.py similarity index 76% rename from examples/chalicelib/resources/files.py rename to examples/chalice/resources/files.py index 26b4202a..6489f850 100644 --- a/examples/chalicelib/resources/files.py +++ b/examples/chalice/resources/files.py @@ -4,10 +4,10 @@ from chalice import NotFoundError, Response from mongoengine import DoesNotExist -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models import File as FileModel -from ..validators import FileQuery +from ...models import File as FileModel +from ...validators import FileQuery from .base import app diff --git a/examples/chalicelib/resources/transactions.py b/examples/chalice/resources/transactions.py similarity index 53% rename from examples/chalicelib/resources/transactions.py rename to examples/chalice/resources/transactions.py index f940510f..11e74202 100644 --- a/examples/chalicelib/resources/transactions.py +++ b/examples/chalice/resources/transactions.py @@ -1,7 +1,7 @@ -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models.transactions import Transaction as TransactionModel -from ..validators import TransactionQuery +from ...models.transactions import Transaction as TransactionModel +from ...validators import TransactionQuery from .base import app diff --git a/examples/chalicelib/resources/users.py b/examples/chalice/resources/users.py similarity index 55% rename from examples/chalicelib/resources/users.py rename to examples/chalice/resources/users.py index 1c32e80b..dba29d20 100644 --- a/examples/chalicelib/resources/users.py +++ b/examples/chalice/resources/users.py @@ -1,7 +1,7 @@ -from agave.filters import generic_query +from agave.core.filters import generic_query -from ..models import User as UserModel -from ..validators import UserQuery +from ...models import User as UserModel +from ...validators import UserQuery from .base import app diff --git a/examples/chalicelib/models/billers.py b/examples/chalicelib/models/billers.py deleted file mode 100644 index ded9bc08..00000000 --- a/examples/chalicelib/models/billers.py +++ /dev/null @@ -1,12 +0,0 @@ -import datetime as dt - -from mongoengine import DateTimeField, Document, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class Biller(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('BL')) - created_at = DateTimeField(default=dt.datetime.utcnow) - name = StringField(required=True) diff --git a/examples/chalicelib/models/cards.py b/examples/chalicelib/models/cards.py deleted file mode 100644 index 3ff73156..00000000 --- a/examples/chalicelib/models/cards.py +++ /dev/null @@ -1,11 +0,0 @@ -from mongoengine import DateTimeField, Document, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class Card(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('CA')) - number = StringField(required=True) - user_id = StringField(required=True) - created_at = DateTimeField() diff --git a/examples/chalicelib/models/files.py b/examples/chalicelib/models/files.py deleted file mode 100644 index 176791f9..00000000 --- a/examples/chalicelib/models/files.py +++ /dev/null @@ -1,10 +0,0 @@ -from mongoengine import Document, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class File(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('TR')) - user_id = StringField(required=True) - name = StringField(required=True) diff --git a/examples/chalicelib/models/transactions.py b/examples/chalicelib/models/transactions.py deleted file mode 100644 index dcde86d4..00000000 --- a/examples/chalicelib/models/transactions.py +++ /dev/null @@ -1,10 +0,0 @@ -from mongoengine import Document, FloatField, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class Transaction(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('TR')) - user_id = StringField(required=True) - amount = FloatField(required=True) diff --git a/examples/chalicelib/models/users.py b/examples/chalicelib/models/users.py deleted file mode 100644 index f0305069..00000000 --- a/examples/chalicelib/models/users.py +++ /dev/null @@ -1,13 +0,0 @@ -import datetime as dt - -from mongoengine import DateTimeField, Document, StringField - -from agave.models import BaseModel -from agave.models.helpers import uuid_field - - -class User(BaseModel, Document): - id = StringField(primary_key=True, default=uuid_field('US')) - created_at = DateTimeField(default=dt.datetime.utcnow) - name = StringField(required=True) - platform_id = StringField(required=True) diff --git a/examples/chalicelib/validators.py b/examples/chalicelib/validators.py deleted file mode 100644 index ab444d31..00000000 --- a/examples/chalicelib/validators.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional - -from cuenca_validations.types import QueryParams -from pydantic import BaseModel - - -class AccountQuery(QueryParams): - name: Optional[str] = None - user_id: Optional[str] = None - platform_id: Optional[str] = None - active: Optional[bool] = None - - -class TransactionQuery(QueryParams): - user_id: Optional[str] = None - - -class BillerQuery(QueryParams): - name: str - - -class UserQuery(QueryParams): - platform_id: str - - -class AccountRequest(BaseModel): - name: str - - -class AccountUpdateRequest(BaseModel): - name: str - - -class FileQuery(QueryParams): - user_id: Optional[str] = None - - -class CardQuery(QueryParams): - number: Optional[str] = None diff --git a/tests/models/__init__.py b/examples/fastapi/__init__.py similarity index 100% rename from tests/models/__init__.py rename to examples/fastapi/__init__.py diff --git a/examples/fastapi/app.py b/examples/fastapi/app.py new file mode 100644 index 00000000..cc18c1b0 --- /dev/null +++ b/examples/fastapi/app.py @@ -0,0 +1,35 @@ +import asyncio + +import mongomock as mongomock +from fastapi import FastAPI +from mongoengine import connect + +from agave.fastapi.middlewares import AgaveErrorHandler + +from ..tasks.task_example import dummy_task, task_validator +from .middlewares import AuthedMiddleware +from .resources import app as resources + +connect( + host='mongodb://localhost:27017/db', + mongo_client_class=mongomock.MongoClient, +) +app = FastAPI(title='example') +app.include_router(resources) + + +app.add_middleware(AuthedMiddleware) +app.add_middleware(AgaveErrorHandler) + + +@app.get('/') +async def iam_healty() -> dict: + return dict(greeting="I'm healthy!!!") + + +@app.on_event('startup') +async def on_startup() -> None: # pragma: no cover + # Inicializa el task que recibe mensajes + # provenientes de SQS + asyncio.create_task(dummy_task()) + asyncio.create_task(task_validator()) diff --git a/examples/fastapi/blueprints/custom_query_blueprint.py b/examples/fastapi/blueprints/custom_query_blueprint.py new file mode 100644 index 00000000..2d30ee09 --- /dev/null +++ b/examples/fastapi/blueprints/custom_query_blueprint.py @@ -0,0 +1,24 @@ +from typing import Any + +from starlette_context import context + +from agave.fastapi import RestApiBlueprint + + +class CustomQueryBlueprint(RestApiBlueprint): + @property + def custom(self) -> str: + return context['custom'] + + def property_filter_required(self) -> bool: + return context.get('custom_filter_required') + + def custom_filter_required(self, query_params: Any, model: any) -> None: + if self.property_filter_required() and hasattr(model, 'custom'): + query_params.custom = self.custom + + def user_id_filter_required(self) -> bool: + return False + + def platform_id_filter_required(self) -> bool: + return False diff --git a/examples/fastapi/middlewares/__init__.py b/examples/fastapi/middlewares/__init__.py new file mode 100644 index 00000000..dc51a537 --- /dev/null +++ b/examples/fastapi/middlewares/__init__.py @@ -0,0 +1,3 @@ +__all__ = ['AuthedMiddleware'] + +from .authed import AuthedMiddleware diff --git a/examples/fastapi/middlewares/authed.py b/examples/fastapi/middlewares/authed.py new file mode 100644 index 00000000..9688df9c --- /dev/null +++ b/examples/fastapi/middlewares/authed.py @@ -0,0 +1,64 @@ +from fastapi import Request +from starlette.middleware.base import RequestResponseEndpoint +from starlette.responses import Response +from starlette_context import _request_scope_context_storage +from starlette_context.middleware import ContextMiddleware + +from ...config import TEST_DEFAULT_PLATFORM_ID, TEST_DEFAULT_USER_ID + + +class AuthedMiddleware(ContextMiddleware): + def __init__( + self, app, plugins=None, default_error_response=None, *args, **kwargs + ): + super().__init__( + app=app, + plugins=plugins, + default_error_response=default_error_response, + *args, + **kwargs, + ) + + def required_user_id(self) -> bool: + """ + Example method so we can easily mock it in tests environment + :return: + """ + return False + + def required_platform_id(self) -> bool: + """ + Example method so we can easily mock it in tests environment + :return: + """ + return False + + async def authenticate(self): + self.token = _request_scope_context_storage.set( + dict( + user_id=TEST_DEFAULT_USER_ID, + platform_id=TEST_DEFAULT_PLATFORM_ID, + ) + ) + + async def authorize(self): + context = _request_scope_context_storage.get() + context['user_id_filter_required'] = self.required_user_id() + context['platform_id_filter_required'] = self.required_platform_id() + self.token = _request_scope_context_storage.set(context) + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + try: + # Authentication and authorization goes here! + await self.authenticate() + await self.authorize() + response = await call_next(request) + for plugin in self.plugins: + await plugin.enrich_response(response) + + finally: + _request_scope_context_storage.reset(self.token) + + return response diff --git a/examples/chalicelib/models/__init__.py b/examples/fastapi/resources/__init__.py similarity index 56% rename from examples/chalicelib/models/__init__.py rename to examples/fastapi/resources/__init__.py index 2435a9f1..65390250 100644 --- a/examples/chalicelib/models/__init__.py +++ b/examples/fastapi/resources/__init__.py @@ -1,6 +1,9 @@ -__all__ = ['Account', 'Biller', 'Card', 'Transaction', 'File', 'User'] +__all__ = ['Account', 'app', 'Biller', 'Card', 'File', 'Transaction', 'ApiKey'] + from .accounts import Account +from .api_keys import ApiKey +from .base import app from .billers import Biller from .cards import Card from .files import File diff --git a/examples/fastapi/resources/accounts.py b/examples/fastapi/resources/accounts.py new file mode 100644 index 00000000..b6965a29 --- /dev/null +++ b/examples/fastapi/resources/accounts.py @@ -0,0 +1,50 @@ +import datetime as dt + +from fastapi import Request +from fastapi.responses import JSONResponse as Response + +from agave.core.filters import generic_query + +from ...models import Account as AccountModel +from ...validators import ( + AccountQuery, + AccountRequest, + AccountResponse, + AccountUpdateRequest, +) +from .base import app + + +@app.resource('/accounts') +class Account: + model = AccountModel + query_validator = AccountQuery + update_validator = AccountUpdateRequest + get_query_filter = generic_query + response_model = AccountResponse + + @staticmethod + async def create(request: AccountRequest) -> Response: + """This is the description for openapi""" + account = AccountModel( + name=request.name, + user_id=app.current_user_id, + platform_id=app.current_platform_id, + ) + await account.async_save() + return Response(content=account.to_dict(), status_code=201) + + @staticmethod + async def update( + account: AccountModel, + request: AccountUpdateRequest, + ) -> Response: + account.name = request.name + await account.async_save() + return Response(content=account.to_dict(), status_code=200) + + @staticmethod + async def delete(account: AccountModel, _: Request) -> Response: + account.deactivated_at = dt.datetime.utcnow().replace(microsecond=0) + await account.async_save() + return Response(content=account.to_dict(), status_code=200) diff --git a/examples/fastapi/resources/api_keys.py b/examples/fastapi/resources/api_keys.py new file mode 100644 index 00000000..1e629c2d --- /dev/null +++ b/examples/fastapi/resources/api_keys.py @@ -0,0 +1,29 @@ +import datetime as dt + +from fastapi import Request +from fastapi.responses import JSONResponse as Response + +from agave.core.filters import generic_query + +from ...models import ApiKey as ApiKeyModel +from ...validators import ApiKeyRequest, ApiKeyResponse +from .base import app + + +@app.resource('/api_keys') +class ApiKey: + model = ApiKeyModel + response_model = ApiKeyResponse + + @staticmethod + async def create(request: ApiKeyRequest) -> Response: + ak = ApiKeyModel( + user=request.user, + password=request.password.get_secret_value(), + user_id=app.current_user_id, + platform_id=app.current_platform_id, + secret='My-super-secret-key', + ) + await ak.async_save() + + return Response(ak.to_dict(), status_code=201) diff --git a/examples/fastapi/resources/base.py b/examples/fastapi/resources/base.py new file mode 100644 index 00000000..a1686496 --- /dev/null +++ b/examples/fastapi/resources/base.py @@ -0,0 +1,29 @@ +from cuenca_validations.errors import WrongCredsError + +from agave.core.exc import UnauthorizedError +from agave.fastapi import RestApiBlueprint + +app = RestApiBlueprint() + + +@app.get('/healthy_auth') +def health_auth_check() -> dict: + return dict(greeting="I'm authenticated and healthy !!!") + + +@app.get('/raise_cuenca_errors') +def raise_cuenca_errors() -> None: + raise WrongCredsError('you are not lucky enough!') + + +@app.get('/raise_fast_agave_errors') +def raise_fast_agave_errors() -> None: + raise UnauthorizedError('nice try!') + + +@app.get('/you_shall_not_pass') +def you_shall_not_pass() -> None: + # Este endpoint nunca ser谩 ejecutado + # La prueba de este endpoint hace un mock a nivel middleware + # para responder con un `UnauthorizedError` + ... diff --git a/examples/fastapi/resources/billers.py b/examples/fastapi/resources/billers.py new file mode 100644 index 00000000..e5526496 --- /dev/null +++ b/examples/fastapi/resources/billers.py @@ -0,0 +1,12 @@ +from agave.core.filters import generic_query + +from ...models import Biller as BillerModel +from ...validators import BillerQuery +from .base import app + + +@app.resource('/billers') +class Biller: + model = BillerModel + query_validator = BillerQuery + get_query_filter = generic_query diff --git a/examples/fastapi/resources/cards.py b/examples/fastapi/resources/cards.py new file mode 100644 index 00000000..311d14b0 --- /dev/null +++ b/examples/fastapi/resources/cards.py @@ -0,0 +1,26 @@ +from fastapi.responses import JSONResponse as Response + +from agave.core.filters import generic_query + +from ...models import Card as CardModel +from ...validators import CardQuery +from .base import app + + +@app.resource('/cards') +class Card: + model = CardModel + query_validator = CardQuery + get_query_filter = generic_query + + @staticmethod + async def retrieve(card: CardModel) -> Response: + data = card.to_dict() + data['number'] = '*' * 16 + return Response(content=data) + + @staticmethod + async def query(response: dict): + for item in response['items']: + item['number'] = '*' * 16 + return response diff --git a/examples/fastapi/resources/files.py b/examples/fastapi/resources/files.py new file mode 100644 index 00000000..44833df9 --- /dev/null +++ b/examples/fastapi/resources/files.py @@ -0,0 +1,38 @@ +from io import BytesIO + +from fastapi import BackgroundTasks +from fastapi.responses import JSONResponse as Response + +from agave.core.filters import generic_query + +from ...models import File as FileModel +from ...validators import FileQuery, FileUploadValidator +from .base import app + + +def save_file_to_disk(file: bytes, name: str) -> None: + with open(name, 'wb') as out_file: + out_file.write(file) + + +@app.resource('/files') +class File: + model = FileModel + query_validator = FileQuery + upload_validator = FileUploadValidator + get_query_filter = generic_query + + @classmethod + async def download(cls, data: FileModel) -> BytesIO: + return BytesIO(bytes('Hello', 'utf-8')) + + @classmethod + async def upload( + cls, request: FileUploadValidator, background_tasks: BackgroundTasks + ) -> Response: + file = request.file + name = request.file_name + background_tasks.add_task(save_file_to_disk, file=file, name=name) + file = FileModel(name=name, user_id='US01') + await file.async_save() + return Response(content=file.to_dict(), status_code=201) diff --git a/examples/fastapi/resources/transactions.py b/examples/fastapi/resources/transactions.py new file mode 100644 index 00000000..86d70365 --- /dev/null +++ b/examples/fastapi/resources/transactions.py @@ -0,0 +1,13 @@ +from fastapi.responses import JSONResponse as Response + +from ...models.transactions import Transaction as TransactionModel +from .base import app + + +@app.resource('/transactions') +class Transaction: + model = TransactionModel + + @staticmethod + async def create() -> Response: + return Response(content={}, status_code=201) diff --git a/examples/fastapi/resources/users.py b/examples/fastapi/resources/users.py new file mode 100644 index 00000000..03d9f30b --- /dev/null +++ b/examples/fastapi/resources/users.py @@ -0,0 +1,25 @@ +from fastapi import Request +from fastapi.responses import JSONResponse as Response + +from agave.core.filters import generic_query + +from ...models import User as UserModel +from ...validators import UserQuery, UserUpdateRequest +from .base import app + + +@app.resource('/users') +class User: + model = UserModel + query_validator = UserQuery + get_query_filter = generic_query + update_validator = UserUpdateRequest + + @staticmethod + async def update( + user: UserModel, request: UserUpdateRequest, api_request: Request + ) -> Response: + user.name = request.name + user.ip = api_request.client.host if api_request.client else None + await user.async_save() + return Response(content=user.to_dict(), status_code=200) diff --git a/examples/models/__init__.py b/examples/models/__init__.py new file mode 100644 index 00000000..35135e58 --- /dev/null +++ b/examples/models/__init__.py @@ -0,0 +1,17 @@ +__all__ = [ + 'Account', + 'Biller', + 'Card', + 'Transaction', + 'File', + 'User', + 'ApiKey', +] + +from .accounts import Account +from .api_keys import ApiKey +from .billers import Biller +from .cards import Card +from .files import File +from .transactions import Transaction +from .users import User diff --git a/examples/chalicelib/models/accounts.py b/examples/models/accounts.py similarity index 52% rename from examples/chalicelib/models/accounts.py rename to examples/models/accounts.py index ea20fe59..99168c45 100644 --- a/examples/chalicelib/models/accounts.py +++ b/examples/models/accounts.py @@ -1,10 +1,10 @@ -from mongoengine import DateTimeField, Document, StringField +from cuenca_validations.types import uuid_field +from mongoengine import DateTimeField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel -from agave.models import BaseModel -from agave.models.helpers import uuid_field - -class Account(BaseModel, Document): +class Account(BaseModel, AsyncDocument): id = StringField(primary_key=True, default=uuid_field('AC')) name = StringField(required=True) user_id = StringField(required=True) diff --git a/examples/models/api_keys.py b/examples/models/api_keys.py new file mode 100644 index 00000000..17317414 --- /dev/null +++ b/examples/models/api_keys.py @@ -0,0 +1,15 @@ +from cuenca_validations.types import uuid_field +from mongoengine import DateTimeField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class ApiKey(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('AK')) + secret = StringField(required=True) + user = StringField(required=True) + password = StringField(required=True) + user_id = StringField(required=True) + platform_id = StringField(required=True) + created_at = DateTimeField() + deactivated_at = DateTimeField() diff --git a/examples/models/billers.py b/examples/models/billers.py new file mode 100644 index 00000000..192b917a --- /dev/null +++ b/examples/models/billers.py @@ -0,0 +1,12 @@ +import datetime as dt + +from cuenca_validations.types import uuid_field +from mongoengine import DateTimeField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class Biller(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('BL')) + created_at = DateTimeField(default=dt.datetime.utcnow) + name = StringField(required=True) diff --git a/examples/models/cards.py b/examples/models/cards.py new file mode 100644 index 00000000..d4cfb2d0 --- /dev/null +++ b/examples/models/cards.py @@ -0,0 +1,11 @@ +from cuenca_validations.types import uuid_field +from mongoengine import DateTimeField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class Card(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('CA')) + number = StringField(required=True) + user_id = StringField(required=True) + created_at = DateTimeField() diff --git a/examples/models/files.py b/examples/models/files.py new file mode 100644 index 00000000..46acdb8c --- /dev/null +++ b/examples/models/files.py @@ -0,0 +1,10 @@ +from cuenca_validations.types import uuid_field +from mongoengine import StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class File(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('TR')) + user_id = StringField(required=True) + name = StringField(required=True) diff --git a/examples/models/transactions.py b/examples/models/transactions.py new file mode 100644 index 00000000..23f17e26 --- /dev/null +++ b/examples/models/transactions.py @@ -0,0 +1,10 @@ +from cuenca_validations.types import uuid_field +from mongoengine import FloatField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class Transaction(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('TR')) + user_id = StringField(required=True) + amount = FloatField(required=True) diff --git a/examples/models/users.py b/examples/models/users.py new file mode 100644 index 00000000..0a859349 --- /dev/null +++ b/examples/models/users.py @@ -0,0 +1,14 @@ +import datetime as dt + +from cuenca_validations.types import uuid_field +from mongoengine import DateTimeField, StringField +from mongoengine_plus.aio import AsyncDocument +from mongoengine_plus.models import BaseModel + + +class User(BaseModel, AsyncDocument): + id = StringField(primary_key=True, default=uuid_field('US')) + created_at = DateTimeField(default=dt.datetime.utcnow) + name = StringField(required=True) + platform_id = StringField(required=True) + ip = StringField() diff --git a/examples/tasks/__init__.py b/examples/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/tasks/retry_task_example.py b/examples/tasks/retry_task_example.py new file mode 100644 index 00000000..d75bf3d4 --- /dev/null +++ b/examples/tasks/retry_task_example.py @@ -0,0 +1,29 @@ +import random + +from agave.core.exc import RetryTask +from agave.tasks.sqs_tasks import task + +# Esta URL es solo un mock de la queue. +# Debes reemplazarla con la URL de tu queue +QUEUE_URL = 'http://127.0.0.1:4000/123456789012/core.fifo' + + +class YouCanTryAgain(Exception): ... + + +def test_your_luck(message): + value = random.uniform(100) + if 0 < value <= 33: + print('you are lucky!', message) + elif 33 < value <= 66: + raise YouCanTryAgain + else: + raise Exception('game over! :(') + + +@task(queue_url=QUEUE_URL, region_name='us-east-1', max_retries=2) +async def dummy_retry_task(message) -> None: + try: + test_your_luck(message) + except YouCanTryAgain: + raise RetryTask diff --git a/examples/tasks/task_example.py b/examples/tasks/task_example.py new file mode 100644 index 00000000..04ec7b52 --- /dev/null +++ b/examples/tasks/task_example.py @@ -0,0 +1,31 @@ +from typing import Optional, Union + +from pydantic import BaseModel + +from agave.tasks.sqs_tasks import task + +# Esta URL es solo un mock de la queue. +# Debes reemplazarla con la URL de tu queue +QUEUE_URL = 'http://127.0.0.1:4000/123456789012/core.fifo' +QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo' + + +class User(BaseModel): + name: str + age: int + nick_name: Optional[str] + + +class Company(BaseModel): + legal_name: str + rfc: str + + +@task(queue_url=QUEUE_URL, region_name='us-east-1') +async def dummy_task(message) -> None: + print(message) + + +@task(queue_url=QUEUE2_URL, region_name='us-east-1') +async def task_validator(message: Union[User, Company]) -> None: + print(message.model_dump()) diff --git a/examples/validators.py b/examples/validators.py new file mode 100644 index 00000000..73b3578c --- /dev/null +++ b/examples/validators.py @@ -0,0 +1,75 @@ +import datetime as dt +from typing import Optional + +from cuenca_validations.types import QueryParams +from pydantic import BaseModel, SecretStr + + +class AccountQuery(QueryParams): + name: Optional[str] = None + user_id: Optional[str] = None + platform_id: Optional[str] = None + active: Optional[bool] = None + + +class TransactionQuery(QueryParams): + user_id: Optional[str] = None + + +class BillerQuery(QueryParams): + name: str + + +class UserQuery(QueryParams): + platform_id: Optional[str] = None + + +class AccountRequest(BaseModel): + name: str + + +class AccountResponse(BaseModel): + id: str + name: str + user_id: str + platform_id: str + created_at: dt.datetime + deactivated_at: Optional[dt.datetime] = None + + +class AccountUpdateRequest(BaseModel): + name: str + + +class ApiKeyRequest(BaseModel): + user: str + password: SecretStr + short_secret: SecretStr + + +class ApiKeyResponse(BaseModel): + id: str + secret: SecretStr + user: str + password: SecretStr + user_id: str + platform_id: str + created_at: dt.datetime + deactivated_at: Optional[dt.datetime] = None + + +class FileQuery(QueryParams): + user_id: Optional[str] = None + + +class CardQuery(QueryParams): + number: Optional[str] = None + + +class FileUploadValidator(BaseModel): + file: bytes + file_name: str + + +class UserUpdateRequest(BaseModel): + name: str diff --git a/requirements-test.txt b/requirements-test.txt index e68fc34c..751866ed 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,11 +1,16 @@ -pytest==6.2.* -pytest-freezegun==0.4.* -pytest-cov==2.11.* -black==20.8b1 -isort==5.7.* -flake8==3.8.* -mypy==0.812 -pytest-chalice==0.0.* -mongomock==3.22.* -mock==4.0.3 -click==8.0.1 \ No newline at end of file +pytest==8.3.4 +pytest-cov==6.0.0 +black==24.10.0 +isort==5.13.2 +flake8==7.1.1 +mypy==1.14.1 +mongomock==4.3.0 +mock==5.1.0 +pytest-freezegun==0.4.2 +pytest-chalice==0.0.5 +click==8.1.8 +moto[server]==5.0.26 +pytest-vcr==1.0.2 +pytest-asyncio==0.18.* +requests==2.32.3 +httpx==0.28.1 diff --git a/requirements.txt b/requirements.txt index e62e5480..3ed31b9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,9 @@ -blinker==1.4 -chalice==1.25.0 -mongoengine==0.22.1 -cuenca-validations==0.11.22 -dnspython==2.1.0 +cuenca-validations==2.0.2 +chalice==1.31.3 +mongoengine==0.29.1 +fastapi==0.115.6 +mongoengine-plus==1.0.0 +python-multipart==0.0.20 +starlette-context==0.3.6 +aiobotocore==2.17.0 +types-aiobotocore-sqs==2.17.0 diff --git a/setup.cfg b/setup.cfg index 8d5ea6b0..5c8cd528 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,7 @@ test=pytest [tool:pytest] addopts = -p no:warnings -v --cov-report term-missing --cov=agave +asyncio_mode = auto [flake8] inline-quotes = ' diff --git a/setup.py b/setup.py index cc8e2ce3..c688ce89 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,6 @@ version = SourceFileLoader('version', 'agave/version.py').load_module() - with open('README.md', 'r') as f: long_description = f.read() @@ -21,16 +20,32 @@ packages=find_packages(), include_package_data=True, package_data=dict(agave=['py.typed']), - python_requires='>=3.8', + python_requires='>=3.9', install_requires=[ - 'chalice>=1.16.0,<1.25.1', - 'cuenca-validations>=0.9.0,<1.0.0', - 'blinker>=1.4,<1.5', - 'mongoengine>=0.20.0,<0.23.0', - 'dnspython>=2.0.0,<2.2.0', + 'cuenca-validations>=2.0.2,<3.0.0', + 'mongoengine>=0.29.0,<0.30.0', + 'mongoengine-plus>=1.0.0,<2.0.0', + 'python-multipart>=0.0.20,<0.0.30', ], + extras_require={ + 'chalice': [ + 'chalice>=1.30.0,<2.0.0', + ], + 'fastapi': [ + 'fastapi>=0.115.0,<1.0.0', + 'starlette-context>=0.3.2,<0.4.0', + ], + 'tasks': [ + 'aiobotocore>=2.0.0,<3.0.0', + 'types-aiobotocore-sqs>=2.1.0,<3.0.0', + ], + }, classifiers=[ - 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', ], diff --git a/test_file.txt b/test_file.txt new file mode 100644 index 00000000..e69de29b diff --git a/tests/blueprint/test_blueprint.py b/tests/blueprint/test_blueprint.py index 9d4d39bc..594b6b37 100644 --- a/tests/blueprint/test_blueprint.py +++ b/tests/blueprint/test_blueprint.py @@ -1,312 +1,643 @@ import datetime as dt -from typing import List +from tempfile import TemporaryFile +from unittest.mock import MagicMock, patch from urllib.parse import urlencode import pytest -from chalice.test import Client -from mock import MagicMock, patch +from fastapi.testclient import TestClient -from examples.chalicelib.models import Account, Card, File from examples.config import ( TEST_DEFAULT_PLATFORM_ID, TEST_DEFAULT_USER_ID, TEST_SECOND_PLATFORM_ID, ) - -PLATFORM_ID_FILTER_REQUIRED = ( - 'examples.chalicelib.blueprints.authed.' - 'AuthedBlueprint.platform_id_filter_required' -) - -USER_ID_FILTER_REQUIRED = ( - 'examples.chalicelib.blueprints.authed.' - 'AuthedBlueprint.user_id_filter_required' +from examples.models import Account, Card, File, User + +# Constants for both frameworks +FRAMEWORK_CONFIGS = { + 'chalice': { + 'platform_id_filter': ( + 'examples.chalice.blueprints.authed.' + 'AuthedBlueprint.platform_id_filter_required' + ), + 'user_id_filter': ( + 'examples.chalice.blueprints.authed.' + 'AuthedBlueprint.user_id_filter_required' + ), + 'bad_request_code': 400, + 'validation_error_code': 400, + 'method_not_allowed_code': 400, + 'not_found_code': 403, + }, + 'fastapi': { + 'platform_id_filter': ( + 'examples.fastapi.middlewares.' + 'AuthedMiddleware.required_platform_id' + ), + 'user_id_filter': ( + 'examples.fastapi.middlewares.AuthedMiddleware.required_user_id' + ), + 'bad_request_code': 422, + 'validation_error_code': 422, + 'method_not_allowed_code': 405, + 'not_found_code': 404, + }, +} + + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] ) - - -def test_create_resource(client: Client) -> None: +def test_create_resource( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) data = dict(name='Doroteo Arango') - resp = client.http.post('/accounts', json=data) - model = Account.objects.get(id=resp.json_body['id']) - assert resp.status_code == 201 - assert model.to_dict() == resp.json_body + resp = client.post("/accounts", json=data) + json_body = resp.json() + status_code = resp.status_code + model = Account.objects.get(id=json_body['id']) + assert status_code == 201 + assert model.to_dict() == json_body model.delete() -def test_create_resource_bad_request(client: Client) -> None: +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) +def test_create_resource_bad_request( + client_fixture: str, framework_config: dict, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) data = dict(invalid_field='some value') - resp = client.http.post('/accounts', json=data) - assert resp.status_code == 400 - - -def test_retrieve_resource(client: Client, account: Account) -> None: - resp = client.http.get(f'/accounts/{account.id}') - assert resp.status_code == 200 - assert resp.json_body == account.to_dict() + resp = client.post('/accounts', json=data) + assert resp.status_code == framework_config['validation_error_code'] -@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_retrieve_resource( + client_fixture: str, request: pytest.FixtureRequest, account: Account +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get(f'/accounts/{account.id}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert json_body == account.to_dict() + + +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) def test_retrieve_resource_platform_id_filter_required( - client: Client, other_account: Account + client_fixture: str, + framework_config: dict, + request: pytest.FixtureRequest, + other_account: Account, ) -> None: - resp = client.http.get(f'/accounts/{other_account.id}') - assert resp.status_code == 404 - - -@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) + patch_target = framework_config["platform_id_filter"] + with patch(patch_target, MagicMock(return_value=True)): + client = request.getfixturevalue(client_fixture) + resp = client.get(f"/accounts/{other_account.id}") + assert resp.status_code == 404 + + +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) def test_retrieve_resource_user_id_filter_required( - client: Client, other_account: Account + client_fixture: str, + framework_config: dict, + request: pytest.FixtureRequest, + other_account: Account, ) -> None: - resp = client.http.get(f'/accounts/{other_account.id}') - assert resp.status_code == 404 - - -@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) -@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) + patch_target = framework_config["user_id_filter"] + with patch(patch_target, MagicMock(return_value=True)): + client = request.getfixturevalue(client_fixture) + resp = client.get(f"/accounts/{other_account.id}") + assert resp.status_code == 404 + + +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) def test_retrieve_resource_user_id_and_platform_id_filter_required( - client: Client, other_account: Account + client_fixture: str, + framework_config: dict, + request: pytest.FixtureRequest, + other_account: Account, ) -> None: - resp = client.http.get(f'/accounts/{other_account.id}') - assert resp.status_code == 404 + platform_id_filter_target = framework_config["platform_id_filter"] + user_id_filter_target = framework_config["user_id_filter"] + + with ( + patch(platform_id_filter_target, MagicMock(return_value=True)), + patch(user_id_filter_target, MagicMock(return_value=True)), + ): + client = request.getfixturevalue(client_fixture) + resp = client.get(f"/accounts/{other_account.id}") + assert resp.status_code == 404 -def test_retrieve_resource_not_found(client: Client) -> None: - resp = client.http.get('/accounts/unknown_id') + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_retrieve_resource_not_found( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get('/accounts/unknown_id') assert resp.status_code == 404 -def test_update_resource_with_invalid_params(client: Client) -> None: +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) +def test_update_resource_with_invalid_params( + client_fixture: str, framework_config: dict, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) wrong_params = dict(wrong_param='wrong_value') - response = client.http.patch( + response = client.patch( '/accounts/NOT_EXISTS', json=wrong_params, ) - assert response.status_code == 400 + assert response.status_code == framework_config['validation_error_code'] -def test_retrieve_custom_method(client: Client, card: Card) -> None: - resp = client.http.get(f'/cards/{card.id}') - assert resp.status_code == 200 - assert resp.json_body['number'] == '*' * 16 +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_retrieve_custom_method( + client_fixture: str, request: pytest.FixtureRequest, card: Card +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get(f'/cards/{card.id}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert json_body['number'] == '*' * 16 -def test_update_resource_that_doesnt_exist(client: Client) -> None: - resp = client.http.patch( +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_update_resource_that_doesnt_exist( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.patch( '/accounts/5f9b4d0ff8d7255e3cc3c128', json=dict(name='Frida'), ) assert resp.status_code == 404 -def test_update_resource(client: Client, account: Account) -> None: - resp = client.http.patch( +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_update_resource( + client_fixture: str, request: pytest.FixtureRequest, account: Account +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.patch( f'/accounts/{account.id}', json=dict(name='Maria Felix'), ) + json_body = resp.json() + status_code = resp.status_code account.reload() - assert resp.json_body['name'] == 'Maria Felix' + assert json_body['name'] == 'Maria Felix' assert account.name == 'Maria Felix' - assert resp.status_code == 200 + assert status_code == 200 -def test_delete_resource(client: Client, account: Account) -> None: - resp = client.http.delete(f'/accounts/{account.id}') +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_delete_resource( + client_fixture: str, request: pytest.FixtureRequest, account: Account +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.delete(f'/accounts/{account.id}') + json_body = resp.json() + status_code = resp.status_code account.reload() - assert resp.status_code == 200 - assert resp.json_body['deactivated_at'] is not None + assert status_code == 200 + assert json_body['deactivated_at'] is not None assert account.deactivated_at is not None -def test_delete_resource_not_exists(client: Client) -> None: - resp = client.http.delete('/accounts/1234') +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_delete_resource_not_exists( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.delete('/accounts/1234') assert resp.status_code == 404 +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('accounts') -def test_query_count_resource(client: Client) -> None: +def test_query_count_resource( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) query_params = dict(count=1, name='Frida Kahlo') - response = client.http.get( - f'/accounts?{urlencode(query_params)}', - ) - assert response.status_code == 200 - assert response.json_body['count'] == 1 + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert json_body['count'] == 1 +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('accounts') -def test_query_all_with_limit(client: Client) -> None: +def test_query_all_with_limit( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) limit = 2 query_params = dict(limit=limit) - response = client.http.get(f'/accounts?{urlencode(query_params)}') - assert response.status_code == 200 - assert len(response.json_body['items']) == limit - assert response.json_body['next_page_uri'] is None + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert len(json_body['items']) == limit + assert json_body['next_page_uri'] is None +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('accounts') -def test_query_all_resource(client: Client, accounts: List[Account]) -> None: +def test_query_all_resource( + client_fixture: str, + request: pytest.FixtureRequest, + accounts: list[Account], +) -> None: + client = request.getfixturevalue(client_fixture) accounts = list(reversed(accounts)) items = [] page_uri = f'/accounts?{urlencode(dict(page_size=2))}' while page_uri: - resp = client.http.get(page_uri) - assert resp.status_code == 200 - items.extend(resp.json_body['items']) - page_uri = resp.json_body['next_page_uri'] + resp = client.get(page_uri) + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items.extend(json_body['items']) + page_uri = json_body['next_page_uri'] assert len(items) == len(accounts) assert all(a.to_dict() == b for a, b in zip(accounts, items)) +@pytest.mark.parametrize("client_fixture", ["chalice_client"]) def test_query_all_filter_active( - client: Client, account: Account, accounts: List[Account] + client_fixture: str, + request: pytest.FixtureRequest, + account: Account, + accounts: list[Account], ) -> None: + client = request.getfixturevalue(client_fixture) query_params = dict(active=True) # Query active items - resp = client.http.get(f'/accounts?{urlencode(query_params)}') - assert resp.status_code == 200 - items = resp.json_body['items'] + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items = json_body['items'] assert len(items) == len(accounts) assert all(item['deactivated_at'] is None for item in items) # Deactivate Item account.deactivated_at = dt.datetime.utcnow() account.save() - resp = client.http.get(f'/accounts?{urlencode(query_params)}') - assert resp.status_code == 200 - items = resp.json_body['items'] + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items = json_body['items'] assert len(items) == len(accounts) - 1 # Query deactivated items query_params = dict(active=False) - resp = client.http.get(f'/accounts?{urlencode(query_params)}') - assert resp.status_code == 200 - items = resp.json_body['items'] + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items = json_body['items'] assert len(items) == 1 assert items[0]['deactivated_at'] is not None +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) def test_query_all_created_after( - client: Client, accounts: List[Account] + client_fixture: str, + request: pytest.FixtureRequest, + accounts: list[Account], ) -> None: + client = request.getfixturevalue(client_fixture) created_at = dt.datetime(2020, 2, 1) expected_length = len([a for a in accounts if a.created_at > created_at]) query_params = dict(created_after=created_at.isoformat()) - resp = client.http.get(f'/accounts?{urlencode(query_params)}') - - assert resp.status_code == 200 - assert len(resp.json_body['items']) == expected_length - - -@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) + resp = client.get(f'/accounts?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert len(json_body['items']) == expected_length + + +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) def test_query_platform_id_filter_required( - client: Client, accounts: List[Account] + client_fixture: str, + framework_config: dict, + request: pytest.FixtureRequest, + accounts: list[Account], ) -> None: - accounts = list( - reversed( - [a for a in accounts if a.platform_id == TEST_DEFAULT_PLATFORM_ID] + client = request.getfixturevalue(client_fixture) + patch_target = framework_config["platform_id_filter"] + with patch(patch_target, MagicMock(return_value=True)): + accounts = list( + reversed( + [ + a + for a in accounts + if a.platform_id == TEST_DEFAULT_PLATFORM_ID + ] + ) ) - ) - items = [] - page_uri = f'/accounts?{urlencode(dict(page_size=2))}' + items = [] + page_uri = f'/accounts?{urlencode(dict(page_size=2))}' - while page_uri: - resp = client.http.get(page_uri) - assert resp.status_code == 200 - json_body = resp.json_body - items.extend(json_body['items']) - page_uri = json_body['next_page_uri'] + while page_uri: + resp = client.get(page_uri) + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items.extend(json_body['items']) + page_uri = json_body['next_page_uri'] - assert len(items) == len(accounts) - assert all(a.to_dict() == b for a, b in zip(accounts, items)) + assert len(items) == len(accounts) + assert all(a.to_dict() == b for a, b in zip(accounts, items)) -@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) def test_query_user_id_filter_required( - client: Client, accounts: List[Account] + client_fixture: str, + framework_config: dict, + request: pytest.FixtureRequest, + accounts: list[Account], ) -> None: - accounts = list( - reversed([a for a in accounts if a.user_id == TEST_DEFAULT_USER_ID]) - ) - items = [] - page_uri = f'/accounts?{urlencode(dict(page_size=2))}' - - while page_uri: - resp = client.http.get(page_uri) - assert resp.status_code == 200 - json_body = resp.json_body - items.extend(json_body['items']) - page_uri = json_body['next_page_uri'] - - assert len(items) == len(accounts) - assert all(a.to_dict() == b for a, b in zip(accounts, items)) - - -def test_query_resource_with_invalid_params(client: Client) -> None: + client = request.getfixturevalue(client_fixture) + patch_target = framework_config["user_id_filter"] + with patch(patch_target, MagicMock(return_value=True)): + accounts = list( + reversed( + [a for a in accounts if a.user_id == TEST_DEFAULT_USER_ID] + ) + ) + items = [] + page_uri = f'/accounts?{urlencode(dict(page_size=2))}' + + while page_uri: + resp = client.get(page_uri) + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + items.extend(json_body['items']) + page_uri = json_body['next_page_uri'] + + assert len(items) == len(accounts) + assert all(a.to_dict() == b for a, b in zip(accounts, items)) + + +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) +def test_query_resource_with_invalid_params( + client_fixture: str, request: pytest.FixtureRequest, framework_config: dict +) -> None: + client = request.getfixturevalue(client_fixture) wrong_params = dict(wrong_param='wrong_value') - response = client.http.get(f'/accounts?{urlencode(wrong_params)}') - assert response.status_code == 400 + resp = client.get(f'/accounts?{urlencode(wrong_params)}') + assert resp.status_code == framework_config['validation_error_code'] +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('cards') -def test_query_custom_method(client: Client) -> None: +def test_query_custom_method( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) query_params = dict(page_size=2) - resp = client.http.get(f'/cards?{urlencode(query_params)}') - assert resp.status_code == 200 - assert len(resp.json_body['items']) == 2 - assert all(card['number'] == '*' * 16 for card in resp.json_body['items']) - - resp = client.http.get(resp.json_body['next_page_uri']) - assert resp.status_code == 200 - assert len(resp.json_body['items']) == 2 - assert all(card['number'] == '*' * 16 for card in resp.json_body['items']) + resp = client.get(f'/cards?{urlencode(query_params)}') + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert len(json_body['items']) == 2 + assert all(card['number'] == '*' * 16 for card in json_body['items']) + + resp = client.get(json_body['next_page_uri']) + json_body = resp.json() + status_code = resp.status_code + assert status_code == 200 + assert len(json_body['items']) == 2 + assert all(card['number'] == '*' * 16 for card in json_body['items']) + + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_cannot_create_resource( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.post('/billers', json=dict()) + assert resp.status_code == 405 -def test_cannot_create_resource(client: Client) -> None: - response = client.http.post('/transactions', json=dict()) - assert response.status_code == 405 +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) +def test_cannot_query_resource( + client_fixture: str, request: pytest.FixtureRequest, framework_config: dict +) -> None: + client = request.getfixturevalue(client_fixture) + query_params = dict(count=1, name='Frida Kahlo') + resp = client.get(f'/transactions?{urlencode(query_params)}') + assert resp.status_code == framework_config['method_not_allowed_code'] -def test_cannot_update_resource(client: Client) -> None: - response = client.http.post('/transactions', json=dict()) - assert response.status_code == 405 +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_cannot_update_resource( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.patch('/transactions/123', json=dict()) + assert resp.status_code == 405 -def test_cannot_delete_resource(client: Client) -> None: - resp = client.http.delete('/transactions/TR1234') +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_cannot_delete_resource( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.delete('/transactions/TR1234') assert resp.status_code == 405 -def test_download_resource(client: Client, file: File) -> None: +@pytest.mark.parametrize( + "client_fixture, framework_config", + [ + ("fastapi_client", FRAMEWORK_CONFIGS["fastapi"]), + ("chalice_client", FRAMEWORK_CONFIGS["chalice"]), + ], +) +def test_not_found( + client_fixture: str, request: pytest.FixtureRequest, framework_config: dict +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get('/non-registered-endpoint') + assert resp.status_code == framework_config['not_found_code'] + + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +def test_download_resource( + client_fixture: str, request: pytest.FixtureRequest, file: File +) -> None: + client = request.getfixturevalue(client_fixture) mimetype = 'application/pdf' - resp = client.http.get(f'/files/{file.id}', headers={'Accept': mimetype}) + resp = client.get(f'/files/{file.id}', headers={'Accept': mimetype}) assert resp.status_code == 200 assert resp.headers.get('Content-Type') == mimetype +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('users') -def test_filter_no_user_id_query(client: Client) -> None: - resp = client.http.get(f'/users?platform_id={TEST_DEFAULT_PLATFORM_ID}') - resp_json = resp.json_body - assert resp.status_code == 200 +def test_filter_no_user_id_query( + client_fixture: str, request: pytest.FixtureRequest +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get(f'/users?platform_id={TEST_DEFAULT_PLATFORM_ID}') + resp_json = resp.json() + status_code = resp.status_code + assert status_code == 200 assert len(resp_json['items']) == 1 user1 = resp_json['items'][0] - resp = client.http.get(f'/users?platform_id={TEST_SECOND_PLATFORM_ID}') - resp_json = resp.json_body - assert resp.status_code == 200 + + resp = client.get(f'/users?platform_id={TEST_SECOND_PLATFORM_ID}') + resp_json = resp.json() + status_code = resp.status_code + assert status_code == 200 assert len(resp_json['items']) == 1 user2 = resp_json['items'][0] assert user1['id'] != user2['id'] +def test_update_user_with_ip(fastapi_client: TestClient, user: User) -> None: + resp = fastapi_client.patch( + f'/users/{user.id}', json={'name': 'Pedrito Sola'} + ) + resp_json = resp.json() + assert resp.status_code == 200 + assert resp_json['ip'] == 'testclient' + assert resp_json['name'] == 'Pedrito Sola' + + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) @pytest.mark.usefixtures('billers') def test_filter_no_user_id_and_no_platform_id_query( - client: Client, + client_fixture: str, request: pytest.FixtureRequest ) -> None: - resp = client.http.get('/billers?name=ATT') - resp_json = resp.json_body - assert resp.status_code == 200 + client = request.getfixturevalue(client_fixture) + resp = client.get('/billers?name=ATT') + resp_json = resp.json() + status_code = resp.status_code + assert status_code == 200 assert len(resp_json['items']) == 1 + + +def test_upload_resource(fastapi_client: TestClient) -> None: + with TemporaryFile(mode='rb') as f: + file_body = f.read() + resp = fastapi_client.post( + '/files', + files=dict(file=(None, file_body), file_name=(None, 'test_file.txt')), + ) + assert resp.status_code == 201 + json = resp.json() + assert json['name'] == 'test_file.txt' + + +def test_upload_resource_with_invalid_form(fastapi_client: TestClient) -> None: + wrong_form = dict(another_file=b'Whasaaaaap') + resp = fastapi_client.post('/files', files=wrong_form) + assert resp.status_code == 400 diff --git a/tests/chalice/__init__.py b/tests/chalice/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/chalice/models/__init__.py b/tests/chalice/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/test_uuid_field_account.py b/tests/chalice/models/test_uuid_field_account.py similarity index 90% rename from tests/models/test_uuid_field_account.py rename to tests/chalice/models/test_uuid_field_account.py index caad9af4..5686b6e2 100644 --- a/tests/models/test_uuid_field_account.py +++ b/tests/chalice/models/test_uuid_field_account.py @@ -1,7 +1,7 @@ import uuid from base64 import urlsafe_b64encode -from agave.models.helpers import uuid_field_generic +from agave.chalice.models.helpers import uuid_field_generic def test_uuid_field_generic(): diff --git a/tests/chalice/test_imports.py b/tests/chalice/test_imports.py new file mode 100644 index 00000000..339fbdb2 --- /dev/null +++ b/tests/chalice/test_imports.py @@ -0,0 +1,20 @@ +import importlib +import sys + +import pytest + + +def test_chalice_import_error(monkeypatch): + for module in ['chalice', 'agave.chalice.rest_api']: + if module in sys.modules: + del sys.modules[module] + + monkeypatch.setitem(sys.modules, 'chalice', None) + + with pytest.raises(ImportError) as exc_info: + importlib.import_module('agave.chalice.rest_api') + + assert "You must install agave with [chalice] option" in str( + exc_info.value + ) + assert "pip install agave[chalice]" in str(exc_info.value) diff --git a/tests/conftest.py b/tests/conftest.py index 927b75f6..5e204b34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,20 @@ import datetime as dt import functools -from typing import Callable, Generator, List +import json +from typing import Callable, Generator import pytest -from chalice.test import Client +from chalice.test import Client as OriginalChaliceClient +from fastapi.testclient import TestClient as FastAPIClient from mongoengine import Document -from examples.chalicelib.models import Account, Biller, Card, File, User from examples.config import ( TEST_DEFAULT_PLATFORM_ID, TEST_DEFAULT_USER_ID, TEST_SECOND_PLATFORM_ID, TEST_SECOND_USER_ID, ) - -from .helpers import accept_json +from examples.models import Account, Biller, Card, File, User FuncDecorator = Callable[..., Generator] @@ -22,7 +22,7 @@ def collection_fixture(model: Document) -> Callable[..., FuncDecorator]: def collection_decorator(func: Callable) -> FuncDecorator: @functools.wraps(func) - def wrapper(*args, **kwargs) -> Generator[List, None, None]: + def wrapper(*args, **kwargs) -> Generator[list, None, None]: items = func(*args, **kwargs) for item in items: item.save() @@ -34,27 +34,9 @@ def wrapper(*args, **kwargs) -> Generator[List, None, None]: return collection_decorator -@pytest.fixture() -def client() -> Generator[Client, None, None]: - from examples import app - - with Client(app) as client: - client.http.post = accept_json( # type: ignore[assignment] - client.http.post - ) - client.http.patch = accept_json( # type: ignore[assignment] - client.http.patch - ) - - client.http.delete = accept_json( # type: ignore[assignment] - client.http.delete - ) - yield client - - @pytest.fixture @collection_fixture(Account) -def accounts() -> List[Account]: +def accounts() -> list[Account]: return [ Account( name='Frida Kahlo', @@ -96,18 +78,23 @@ def accounts() -> List[Account]: @pytest.fixture -def account(accounts: List[Account]) -> Generator[Account, None, None]: +def account(accounts: list[Account]) -> Generator[Account, None, None]: yield accounts[0] @pytest.fixture -def other_account(accounts: List[Account]) -> Generator[Account, None, None]: +def user(users: list[User]) -> Generator[User, None, None]: + yield users[0] + + +@pytest.fixture +def other_account(accounts: list[Account]) -> Generator[Account, None, None]: yield accounts[-1] @pytest.fixture @collection_fixture(File) -def files() -> List[File]: +def files() -> list[File]: return [ File( name='Frida Kahlo', @@ -117,13 +104,13 @@ def files() -> List[File]: @pytest.fixture -def file(files: List[File]) -> Generator[File, None, None]: +def file(files: list[File]) -> Generator[File, None, None]: yield files[0] @pytest.fixture @collection_fixture(Card) -def cards() -> List[Card]: +def cards() -> list[Card]: return [ Card( number='5434000000000001', @@ -149,13 +136,13 @@ def cards() -> List[Card]: @pytest.fixture -def card(cards: List[Card]) -> Generator[Card, None, None]: +def card(cards: list[Card]) -> Generator[Card, None, None]: yield cards[0] @pytest.fixture @collection_fixture(User) -def users() -> List[User]: +def users() -> list[User]: return [ User(name='User1', platform_id=TEST_DEFAULT_PLATFORM_ID), User(name='User2', platform_id=TEST_SECOND_PLATFORM_ID), @@ -164,8 +151,68 @@ def users() -> List[User]: @pytest.fixture @collection_fixture(Biller) -def billers() -> List[Biller]: +def billers() -> list[Biller]: return [ Biller(name='Telcel'), Biller(name='ATT'), ] + + +@pytest.fixture +def fastapi_client() -> Generator[FastAPIClient, None, None]: + from examples.fastapi.app import app + + client = FastAPIClient(app) + yield client + + +class ChaliceResponse: + def __init__(self, chalice_response): + self._response = chalice_response + self._json_body = chalice_response.json_body + self._status_code = chalice_response.status_code + self._headers = chalice_response.headers + + def json(self): + return self._json_body + + @property + def status_code(self): + return self._status_code + + @property + def headers(self): + return self._headers + + +class ChaliceClient(OriginalChaliceClient): + def _request_with_json( + self, method: str, url: str, **kwargs + ) -> ChaliceResponse: + body = json.dumps(kwargs.pop('json')) if 'json' in kwargs else None + headers = {'Content-Type': 'application/json'} + response = getattr(self.http, method)( + url, body=body, headers=headers, **kwargs + ) + return ChaliceResponse(response) + + def post(self, url: str, **kwargs) -> ChaliceResponse: + return self._request_with_json('post', url, **kwargs) + + def get(self, url: str, **kwargs) -> ChaliceResponse: + response = self.http.get(url, **kwargs) + return ChaliceResponse(response) + + def patch(self, url: str, **kwargs) -> ChaliceResponse: + return self._request_with_json('patch', url, **kwargs) + + def delete(self, url: str, **kwargs) -> ChaliceResponse: + return self._request_with_json('delete', url, **kwargs) + + +@pytest.fixture() +def chalice_client() -> Generator[ChaliceClient, None, None]: + from examples.chalice import app + + client = ChaliceClient(app) + yield client diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/blueprint/test_decorators.py b/tests/core/test_decorators.py similarity index 84% rename from tests/blueprint/test_decorators.py rename to tests/core/test_decorators.py index 5d78b08e..052d0c0a 100644 --- a/tests/blueprint/test_decorators.py +++ b/tests/core/test_decorators.py @@ -1,6 +1,6 @@ from functools import wraps -from agave.blueprints.decorators import copy_attributes +from agave.core.blueprints.decorators import copy_attributes def i_am_test(func): @@ -19,8 +19,7 @@ def retrieve(self) -> str: def test_copy_properties_from() -> None: - def retrieve(): - ... + def retrieve(): ... assert not hasattr(retrieve, 'i_am_test') retrieve = copy_attributes(TestResource)(retrieve) diff --git a/tests/blueprint/test_filters.py b/tests/core/test_filters.py similarity index 91% rename from tests/blueprint/test_filters.py rename to tests/core/test_filters.py index b9a58849..2b87d6c9 100644 --- a/tests/blueprint/test_filters.py +++ b/tests/core/test_filters.py @@ -2,7 +2,7 @@ from cuenca_validations.types import QueryParams -from agave.filters import generic_query +from agave.core.filters import generic_query def test_generic_query_before(): diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/test_app.py b/tests/fastapi/test_app.py new file mode 100644 index 00000000..fb67b6c9 --- /dev/null +++ b/tests/fastapi/test_app.py @@ -0,0 +1,38 @@ +from unittest.mock import AsyncMock + +from _pytest.monkeypatch import MonkeyPatch +from fastapi.testclient import TestClient + +from agave.core.exc import UnauthorizedError +from examples.fastapi.middlewares.authed import AuthedMiddleware + + +def test_iam_healthy(fastapi_client: TestClient) -> None: + resp = fastapi_client.get('/') + assert resp.status_code == 200 + assert resp.json() == dict(greeting="I'm healthy!!!") + + +def test_cuenca_error_handler(fastapi_client: TestClient) -> None: + resp = fastapi_client.get('/raise_cuenca_errors') + assert resp.status_code == 401 + assert resp.json() == dict(error='you are not lucky enough!', code=101) + + +def test_fast_agave_error_handler(fastapi_client: TestClient) -> None: + resp = fastapi_client.get('/raise_fast_agave_errors') + assert resp.status_code == 401 + assert resp.json() == dict(error='nice try!') + + +def test_fast_agave_error_handler_from_middleware( + fastapi_client: TestClient, monkeypatch: MonkeyPatch +) -> None: + monkeypatch.setattr( + AuthedMiddleware, + 'authorize', + AsyncMock(side_effect=UnauthorizedError('come back to the shadows!')), + ) + resp = fastapi_client.get('/you_shall_not_pass') + assert resp.status_code == 401 + assert resp.json() == dict(error='come back to the shadows!') diff --git a/tests/fastapi/test_imports.py b/tests/fastapi/test_imports.py new file mode 100644 index 00000000..dfb50430 --- /dev/null +++ b/tests/fastapi/test_imports.py @@ -0,0 +1,20 @@ +import importlib +import sys + +import pytest + + +def test_fast_import_error(monkeypatch): + for module in ['fastapi', 'agave.fastapi.rest_api']: + if module in sys.modules: + del sys.modules[module] + + monkeypatch.setitem(sys.modules, 'fastapi', None) + + with pytest.raises(ImportError) as exc_info: + importlib.import_module('agave.fastapi.rest_api') + + assert "You must install agave with [fastapi] option" in str( + exc_info.value + ) + assert "pip install agave[fastapi]" in str(exc_info.value) diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 106f4d20..00000000 --- a/tests/helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -import functools -import json as jsonlib -from typing import Callable, Dict, Generator - -FuncDecorator = Callable[..., Generator] - - -def auth_header(username: str, password: str) -> Dict: - creds = username + password - return { - 'Authorization': f'Basic {creds}', - } - - -def accept_json(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(path, json=None, **kwargs): - body = jsonlib.dumps(json) if json else None - headers = {'Content-Type': 'application/json'} - return func(path, body=body, headers=headers, **kwargs) - - return wrapper diff --git a/tests/lib/test_model_helpers.py b/tests/lib/test_model_helpers.py deleted file mode 100644 index 9eeb467e..00000000 --- a/tests/lib/test_model_helpers.py +++ /dev/null @@ -1,103 +0,0 @@ -import re -from datetime import datetime as dt -from decimal import Decimal -from enum import Enum -from typing import ClassVar - -import pytest -from mongoengine import ( - BooleanField, - ComplexDateTimeField, - DateTimeField, - DecimalField, - DictField, - Document, - EmbeddedDocument, - EmbeddedDocumentField, - FloatField, - GenericLazyReferenceField, - IntField, - LazyReferenceField, - ListField, - StringField, -) - -from agave.lib.mongoengine.enum_field import EnumField -from agave.lib.mongoengine.model_helpers import mongo_to_dict -from agave.models.base import BaseModel - - -class Reference(Document, BaseModel): - meta: ClassVar = { - 'collection': 'references', - } - - -class EnumType(Enum): - member = 'name' - - -class Embedded(EmbeddedDocument, BaseModel): - name = StringField() - - -class TestModel(Document, BaseModel): - str_field = StringField() - int_field = IntField(default=1) - float_field = FloatField(default=1.1) - decimal_field = DecimalField(default=1.2) - boolean_field = BooleanField(default=True) - dict_field = DictField(default=dict(one=1, two=2)) - date_time_field = DateTimeField(default=dt.now) - complex_date_time_field = ComplexDateTimeField(default=dt.now) - enum_field = EnumField(EnumType, default=EnumType.member) - list_field = ListField(IntField(), default=lambda: [42]) - enum_list_field = ListField(EnumField(EnumType), default=[EnumType.member]) - embedded_list_field = ListField(EmbeddedDocumentField(Embedded)) - embedded_field = EmbeddedDocumentField(Embedded) - lazzy_field = LazyReferenceField(Reference) - lazzy_list_field = ListField(LazyReferenceField(Reference)) - generic_lazzy_field = GenericLazyReferenceField() - - __test__ = False - - -@pytest.fixture -def model(): - reference = Reference() - reference.save() - model = TestModel( - embedded_list_field=[Embedded(name='')], - lazzy_field=reference, - lazzy_list_field=[reference], - ) - model.save() - model.reload() - return model - - -def test_mongo_to_dict(model): - assert not mongo_to_dict(None) - model_dict = mongo_to_dict(model, exclude_fields=['str_field']) - - assert 'id' in model_dict - assert 'date_time_field' in model_dict - assert 'complex_date_time_field' in model_dict - assert model_dict['int_field'] == 1 - assert model_dict['float_field'] == '1.1' - assert model_dict['decimal_field'] == Decimal('1.2') - assert model_dict['dict_field']['one'] == 1 - assert model_dict['enum_field'] == 'name' - assert model_dict['boolean_field'] is True - assert model_dict['list_field'] == ['42'] - assert model_dict['enum_list_field'] == ['name'] - assert model_dict['embedded_list_field'] == [{'name': ''}] - assert model_dict['embedded_field'] == {} - reference_reg = re.compile(r'\/references\/.{24}') - assert reference_reg.match(model_dict['lazzy_field_uri']) - assert len(model_dict['lazzy_list_field_uris']) == 1 - assert all( - reference_reg.match(field) - for field in model_dict['lazzy_list_field_uris'] - ) - assert model_dict['generic_lazzy_field_uri'] is None diff --git a/tests/models/test_base.py b/tests/models/test_base.py deleted file mode 100644 index d86f6815..00000000 --- a/tests/models/test_base.py +++ /dev/null @@ -1,17 +0,0 @@ -from mongoengine import Document, StringField - -from agave.models import BaseModel - - -class TestModel(BaseModel, Document): - id = StringField() - secret_field = StringField() - __test__ = False - _hidden = ['secret_field'] - - -def test_hide_field(): - model = TestModel(id='12345', secret_field='secret') - model_dict = model.to_dict() - assert model_dict['secret_field'] == '********' - assert model_dict['id'] == '12345' diff --git a/tests/models/test_event_handlers.py b/tests/models/test_event_handlers.py deleted file mode 100644 index 9dae63f3..00000000 --- a/tests/models/test_event_handlers.py +++ /dev/null @@ -1,22 +0,0 @@ -import datetime as dt - -import pytest -from mongoengine import Document - -from agave.lib.mongoengine.event_handlers import updated_at - - -@updated_at.apply -class TestModel(Document): - __test__ = False - - -@pytest.mark.freeze_time('2020-10-10') -def test_attach_updated_at_field(): - model = TestModel() - with pytest.raises(AttributeError): - getattr(model, 'updated_at') - - model.save() - assert type(model.updated_at) is dt.datetime - assert model.updated_at == dt.datetime(2020, 10, 10) diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tasks/conftest.py b/tests/tasks/conftest.py new file mode 100644 index 00000000..28bec091 --- /dev/null +++ b/tests/tasks/conftest.py @@ -0,0 +1,84 @@ +import os +from functools import partial +from typing import Generator + +import aiobotocore +import boto3 +import pytest +from _pytest.monkeypatch import MonkeyPatch +from aiobotocore.session import AioSession + +from agave.tasks import sqs_tasks + + +@pytest.fixture(scope='session') +def aws_credentials() -> None: + """Mocked AWS Credentials for moto.""" + os.environ['AWS_ACCESS_KEY_ID'] = 'testing' + os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' + os.environ['AWS_SECURITY_TOKEN'] = 'testing' + os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' + boto3.setup_default_session() + + +@pytest.fixture(scope='session') +def aws_endpoint_urls( + aws_credentials, +) -> Generator[dict[str, str], None, None]: + from moto.server import ThreadedMotoServer + + server = ThreadedMotoServer(port=4000) + server.start() + + endpoints = dict( + sqs='http://127.0.0.1:4000/', + ) + yield endpoints + + server.stop() + + +@pytest.fixture(autouse=True) +def patch_tasks_count(monkeypatch: MonkeyPatch) -> None: + def one_loop(*_, **__): + # Para pruebas solo unos cuantos ciclos + for i in range(5): + yield i + + monkeypatch.setattr(sqs_tasks, 'count', one_loop) + + +@pytest.fixture(autouse=True) +def patch_create_client(aws_endpoint_urls, monkeypatch: MonkeyPatch) -> None: + create_client = AioSession.create_client + + def mock_create_client(*args, **kwargs): + service_name = next(a for a in args if type(a) is str) + kwargs['endpoint_url'] = aws_endpoint_urls[service_name] + + return create_client(*args, **kwargs) + + monkeypatch.setattr(AioSession, 'create_client', mock_create_client) + + +@pytest.fixture +async def sqs_client(): + session = aiobotocore.session.get_session() + async with session.create_client('sqs', 'us-east-1') as sqs: + await sqs.create_queue( + QueueName='core.fifo', + Attributes={ + 'FifoQueue': 'true', + 'ContentBasedDeduplication': 'true', + }, + ) + resp = await sqs.get_queue_url(QueueName='core.fifo') + sqs.send_message = partial(sqs.send_message, QueueUrl=resp['QueueUrl']) + sqs.receive_message = partial( + sqs.receive_message, + QueueUrl=resp['QueueUrl'], + AttributeNames=['ApproximateReceiveCount'], + ) + sqs.queue_url = resp['QueueUrl'] + yield sqs + await sqs.purge_queue(QueueUrl=resp['QueueUrl']) diff --git a/tests/tasks/test_imports.py b/tests/tasks/test_imports.py new file mode 100644 index 00000000..ac59dbb2 --- /dev/null +++ b/tests/tasks/test_imports.py @@ -0,0 +1,20 @@ +import importlib +import sys + +import pytest + + +def test_tasks_import_error(monkeypatch): + for module in ['types_aiobotocore_sqs', 'agave.tasks.sqs_client']: + if module in sys.modules: + del sys.modules[module] + + monkeypatch.setitem(sys.modules, 'types_aiobotocore_sqs', None) + + with pytest.raises(ImportError) as exc_info: + importlib.import_module('agave.tasks.sqs_client') + + assert "You must install agave with [fastapi, tasks] option" in str( + exc_info.value + ) + assert "pip install agave[fastapi, tasks]" in str(exc_info.value) diff --git a/tests/tasks/test_sqs_celery_client.py b/tests/tasks/test_sqs_celery_client.py new file mode 100644 index 00000000..bb0df9a0 --- /dev/null +++ b/tests/tasks/test_sqs_celery_client.py @@ -0,0 +1,56 @@ +import base64 +import json + +from agave.tasks.sqs_celery_client import SqsCeleryClient + +CORE_QUEUE_REGION = 'us-east-1' + + +async def test_send_task(sqs_client) -> None: + args = [10, 'foo'] + kwargs = dict(hola='mundo') + queue = SqsCeleryClient(sqs_client.queue_url, CORE_QUEUE_REGION) + await queue.start() + + await queue.send_task('some.task', args=args, kwargs=kwargs) + sqs_message = await sqs_client.receive_message() + encoded_body = sqs_message['Messages'][0]['Body'] + message = json.loads( + base64.b64decode(encoded_body.encode('utf-8')).decode() + ) + body_json = json.loads( + base64.b64decode(message['body'].encode('utf-8')).decode() + ) + + assert body_json[0] == args + assert body_json[1] == kwargs + assert message['headers']['lang'] == 'py' + assert message['headers']['task'] == 'some.task' + await queue.close() + + +async def test_send_background_task(sqs_client) -> None: + args = [10, 'foo'] + kwargs = dict(hola='mundo') + queue = SqsCeleryClient(sqs_client.queue_url, CORE_QUEUE_REGION) + await queue.start() + + assert len(queue.background_tasks) == 0 + + task = queue.send_background_task('some.task', args=args, kwargs=kwargs) + await task + sqs_message = await sqs_client.receive_message() + encoded_body = sqs_message['Messages'][0]['Body'] + message = json.loads( + base64.b64decode(encoded_body.encode('utf-8')).decode() + ) + body_json = json.loads( + base64.b64decode(message['body'].encode('utf-8')).decode() + ) + + assert body_json[0] == args + assert body_json[1] == kwargs + assert message['headers']['lang'] == 'py' + assert message['headers']['task'] == 'some.task' + await queue.close() + assert len(queue.background_tasks) == 0 diff --git a/tests/tasks/test_sqs_client.py b/tests/tasks/test_sqs_client.py new file mode 100644 index 00000000..f5849dac --- /dev/null +++ b/tests/tasks/test_sqs_client.py @@ -0,0 +1,35 @@ +import json + +from agave.tasks.sqs_client import SqsClient + +CORE_QUEUE_REGION = 'us-east-1' + + +async def test_send_message(sqs_client) -> None: + data1 = dict(hola='mundo') + data2 = dict(foo='bar') + + async with SqsClient(sqs_client.queue_url, CORE_QUEUE_REGION) as sqs: + await sqs.send_message(data1) + await sqs.send_message(data2, message_group_id='12345') + + sqs_message = await sqs_client.receive_message() + message = json.loads(sqs_message['Messages'][0]['Body']) + assert message == data1 + + sqs_message = await sqs_client.receive_message() + message = json.loads(sqs_message['Messages'][0]['Body']) + assert message == data2 + + +async def test_send_message_async(sqs_client) -> None: + data1 = dict(hola='mundo') + + async with SqsClient(sqs_client.queue_url, CORE_QUEUE_REGION) as sqs: + task = sqs.send_message_async(data1) + await task + + sqs_message = await sqs_client.receive_message() + message = json.loads(sqs_message['Messages'][0]['Body']) + + assert message == data1 diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py new file mode 100644 index 00000000..3d6676a5 --- /dev/null +++ b/tests/tasks/test_sqs_tasks.py @@ -0,0 +1,429 @@ +import asyncio +import datetime as dt +import json +import uuid +from typing import Union +from unittest.mock import AsyncMock, call, patch + +import aiobotocore.client +from aiobotocore.httpsession import HTTPClientError +from pydantic import BaseModel + +from agave.core.exc import RetryTask +from agave.tasks.sqs_tasks import ( + BACKGROUND_TASKS, + get_running_fast_agave_tasks, + task, +) + +CORE_QUEUE_REGION = 'us-east-1' + + +async def test_execute_tasks(sqs_client) -> None: + """ + Happy path: Se obtiene el mensaje y se ejecuta el task exitosamente. + El mensaje debe ser eliminado autom谩ticamente del queue + """ + test_message = dict(id='abc123', name='fast-agave') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_execute_tasks_with_validator(sqs_client) -> None: + class Validator(BaseModel): + id: str + name: str + + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Validator) -> None: + await async_mock_function(data) + + task_params = dict( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + ) + # Invalid body, not execute function + await sqs_client.send_message( + MessageBody=json.dumps(dict(foo='bar')), + MessageGroupId='4321', + ) + await task(**task_params)(my_task)() + assert async_mock_function.call_count == 0 + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + # Body approve validator, function receive Validator + test_message = Validator(id='abc123', name='fast-agave') + await sqs_client.send_message( + MessageBody=test_message.json(), + MessageGroupId='1234', + ) + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_execute_tasks_with_union_validator(sqs_client) -> None: + class User(BaseModel): + id: str + name: str + + class Company(BaseModel): + id: str + legal_name: str + rfc: str + + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Union[User, Company]) -> None: + await async_mock_function(data.model_dump()) + + task_params = dict( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + ) + # Invalid body, not execute function + test_message = dict(id='ID123', name='Sor Juana In茅s de la Cruz') + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='4321', + ) + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + async_mock_function.reset_mock() + test_message = dict(id='ID123', legal_name='FastAgave', rfc='FA') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='54321', + ) + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_not_execute_tasks(sqs_client) -> None: + """ + Este caso es cuando el queue est谩 vac铆o. No hay nada que ejecutar + """ + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + # No escribimos un mensaje en el queue + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + async_mock_function.assert_not_called() + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_http_client_error_tasks(sqs_client) -> None: + """ + Este test prueba el caso cuando hay un error de conexi贸n al intentar + obtener recibir el mensaje del queue. Se maneja correctamente la + excepci贸n `HTTPClientError` para evitar que el loop que consume mensajes + se rompe inesperadamente. + """ + + test_message = dict(id='abc123', name='fast-agave') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + original_create_client = aiobotocore.client.AioClientCreator.create_client + + # Esta funci贸n hace un patch de la funci贸n `receive_message` para simular + # un error de conexi贸n, la recuperaci贸n de la conexi贸n y posteriores + # recepciones de mensajes sin body del queue. + async def mock_create_client(*args, **kwargs): + client = await original_create_client(*args, **kwargs) + client.receive_message = AsyncMock( + side_effect=[ + HTTPClientError(error='[Errno 104] Connection reset by peer'), + await sqs_client.receive_message(), + dict(), + dict(), + dict(), + ] + ) + return client + + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + with patch( + 'aiobotocore.client.AioClientCreator.create_client', mock_create_client + ): + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=3, + max_retries=1, + )(my_task)() + async_mock_function.assert_called_once() + + +async def test_retry_tasks_default_max_retries(sqs_client) -> None: + """ + Este test prueba la l贸gica de reintentos con la configuraci贸n default, + es decir `max_retries=1` + + En este caso el task debe ejecutarse 2 veces + (la ejecuci贸n normal + max_retries) + + Se ejecuta este n煤mero de veces para ser consistentes con la l贸gica + de reintentos de Celery + """ + test_message = dict(id='abc123', name='fast-agave') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock(side_effect=RetryTask) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + expected_calls = [call(test_message)] * 2 + async_mock_function.assert_has_calls(expected_calls) + assert async_mock_function.call_count == len(expected_calls) + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + +async def test_retry_tasks_custom_max_retries(sqs_client) -> None: + """ + Este test prueba la l贸gica de reintentos con la configuraci贸n default, + es decir `max_retries=2` + + En este caso el task debe ejecutarse 3 veces + (la ejecuci贸n normal + max_retries) + """ + test_message = dict(id='abc123', name='fast-agave') + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock(side_effect=RetryTask) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=2, + )(my_task)() + + expected_calls = [call(test_message)] * 3 + async_mock_function.assert_has_calls(expected_calls) + assert async_mock_function.call_count == len(expected_calls) + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_does_not_retry_on_unhandled_exceptions(sqs_client) -> None: + """ + Este caso prueba que las excepciones no controladas no se reintentan por + default (comportamiento consistente con Celery) + + Dentro de task deben manejarse las excepciones esperadas (como desconexi贸n + de la red). V茅ase los ejemplos de c贸mo aplicar este tipo de reintentos + """ + test_message = dict(id='abc123', name='fast-agave') + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock( + side_effect=Exception('something went wrong :(') + ) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=3, + )(my_task)() + + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +async def test_retry_tasks_with_countdown(sqs_client) -> None: + """ + Este test prueba la l贸gica de reintentos con un countdown, + es decir, se modifica el visibility timeout del mensaje para que pueda + simularse un delay en la recepci贸n del mensaje por el siguiente + `receive_message` + + En este caso el task debe ejecutarse 2 veces + (la ejecuci贸n normal + 1 intento), sin embargo, + despu茅s de ejecutarse por primera vez deben pasar aprox 2 segundos + para que se ejecute el primer intento + + El par谩metro es similar a `self.retry(exc, countdown=10)` en celery + """ + test_message = dict(id='abc123', name='fast-agave') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock(side_effect=RetryTask(countdown=2)) + + async def countdown_tester(data: dict): + await async_mock_function(data, dt.datetime.now()) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(countdown_tester)() + + call_times = [arg[1] for arg, _ in async_mock_function.call_args_list] + assert async_mock_function.call_count == 2 + assert call_times[1] - call_times[0] >= dt.timedelta(seconds=2) + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + +async def test_concurrency_controller( + sqs_client, +) -> None: + message_id = str(uuid.uuid4()) + test_message = dict(id=message_id, name='fast-agave') + for i in range(5): + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId=message_id, + ) + + async_mock_function = AsyncMock() + + async def task_counter(data: dict) -> None: + await asyncio.sleep(5) + running_tasks = len(await get_running_fast_agave_tasks()) + await async_mock_function(running_tasks) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=3, + max_concurrent_tasks=2, + )(task_counter)() + + running_tasks = [call[0] for call, _ in async_mock_function.call_args_list] + assert max(running_tasks) == 2 + + +async def test_invalid_json_message(sqs_client) -> None: + """ + Este test verifica que los mensajes con JSON inv谩lido son ignorados + y el mensaje es eliminado del queue sin ejecutar el task + """ + # Enviamos un mensaje con JSON inv谩lido + await sqs_client.send_message( + MessageBody='{invalid_json', + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + # Verificamos que el task nunca fue ejecutado + async_mock_function.assert_not_called() + + # Verificamos que el mensaje fue eliminado del queue + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0