diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5bfbfd0..ab2b234 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,6 +57,6 @@ jobs: pip install -U setuptools python -m pip install -U pip - - run: pip install openapi-spec-validator . + - run: pip install openapi-spec-validator pydantic . - run: cd compliance && QUART_APP=todo:app quart schema | openapi-spec-validator - diff --git a/README.rst b/README.rst index 7fd0c16..4d45dbf 100644 --- a/README.rst +++ b/README.rst @@ -7,6 +7,10 @@ Quart-Schema is a Quart extension that provides schema validation and auto-generated API documentation. This is particularly useful when writing RESTful APIs. +Quart-Schema can use either `msgspec +`_ or `pydantic +`_ to validate. + Quickstart ---------- diff --git a/docs/conf.py b/docs/conf.py index 1569c78..7ccae52 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,7 +34,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_tabs.tabs'] source_suffix = '.rst' diff --git a/docs/discussion/dataclass_or_basemodel.rst b/docs/discussion/dataclass_or_basemodel.rst deleted file mode 100644 index 33a8101..0000000 --- a/docs/discussion/dataclass_or_basemodel.rst +++ /dev/null @@ -1,46 +0,0 @@ -Dataclass or BaseModel -====================== - -Pydantic's documentation primarily adopts the ``BaseModel`` approach, -i.e. - -.. code-block:: python - - from pydantic import BaseModel - - class Item(BaseModel): - ... - -rather than the pydantic-dataclass approach, - -.. code-block:: python - - from pydantic.dataclasses import dataclass - - @dataclass - class Item: - ... - -or the stdlib-dataclass approach, - -.. code-block:: python - - from dataclasses import dataclass - - @dataclass - class Item: - ... - -and whilst Quart-Schema supports all this documentation primarily -adopts the stdlib-dataclass approach. This is because I find this -approach to be cleaner and clearer. I think if pydantic had started -when after ``dataclass`` was added to the Python stdlib it would have -done the same. - -.. warning:: - - Just a caveat, that these two approaches lead to potentially - subtle differences which you can read about `here - `_. Should - you have issues with the stdlib dataclass try switching to the - pydantic dataclass. diff --git a/docs/discussion/index.rst b/docs/discussion/index.rst index f591e81..19d9b88 100644 --- a/docs/discussion/index.rst +++ b/docs/discussion/index.rst @@ -7,5 +7,4 @@ Discussions casing.rst class_or_function.rst - dataclass_or_basemodel.rst json_encoding.rst diff --git a/docs/how_to_guides/configuration.rst b/docs/how_to_guides/configuration.rst index 32bb549..c47f92e 100644 --- a/docs/how_to_guides/configuration.rst +++ b/docs/how_to_guides/configuration.rst @@ -5,15 +5,16 @@ The following configuration options are used by Quart-Schema. They should be set as part of the standard `Quart configuration `_. -============================= ===== -Configuration key type ------------------------------ ----- -QUART_SCHEMA_SWAGGER_JS_URL str -QUART_SCHEMA_SWAGGER_CSS_URL str -QUART_SCHEMA_REDOC_JS_URL str -QUART_SCHEMA_BY_ALIAS bool -QUART_SCHEMA_CONVERT_CASING bool -============================= ===== +================================== ===== +Configuration key type +---------------------------------- ----- +QUART_SCHEMA_CONVERSION_PREFERENCE str +QUART_SCHEMA_SWAGGER_JS_URL str +QUART_SCHEMA_SWAGGER_CSS_URL str +QUART_SCHEMA_REDOC_JS_URL str +QUART_SCHEMA_BY_ALIAS bool +QUART_SCHEMA_CONVERT_CASING bool +================================== ===== which allow the js and css for the documentation UI to be changed and configured and specifies that responses that are Pydantic models diff --git a/docs/how_to_guides/error_handling.rst b/docs/how_to_guides/error_handling.rst index 0c5fc7e..f4f66c0 100644 --- a/docs/how_to_guides/error_handling.rst +++ b/docs/how_to_guides/error_handling.rst @@ -15,7 +15,8 @@ an error handler, for example for a JSON error response, or if you prefer to let the requestor know exactly why the validation failed you can utilise the ``validation_error`` attribute which is a -either Pydantic ``ValidationError`` or a ``TypeError``, +either Pydantic ``ValidationError``, a msgspec ``ValidationError`` or +a ``TypeError``, .. code-block:: python @@ -23,11 +24,6 @@ either Pydantic ``ValidationError`` or a ``TypeError``, @app.errorhandler(RequestSchemaValidationError) async def handle_request_validation_error(error): - if isinstance(error.validation_error, TypeError): - return { - "errors": str(error.validation_error), - }, 400 - else: - return { - "errors": error.validation_error.json(), - }, 400 + return { + "errors": str(error.validation_error), + }, 400 diff --git a/docs/how_to_guides/headers_validation.rst b/docs/how_to_guides/headers_validation.rst index 4a35f24..cc603e1 100644 --- a/docs/how_to_guides/headers_validation.rst +++ b/docs/how_to_guides/headers_validation.rst @@ -12,21 +12,74 @@ Request headers Request headers can be validated against a schema you define by decorating the route handler, as so, -.. code-block:: python +.. tabs:: + + .. tab:: attrs + + .. code-block:: python + + from attrs import define + from quart_schema import validate_headers + + @define + class Headers: + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_headers(Headers) + async def index(headers: Headers): + ... + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import validate_headers + + @dataclass + class Headers: + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_headers(Headers) + async def index(headers: Headers): + ... + + .. tab:: msgspec + + .. code-block:: python + + from msgspec import Struct + from quart_schema import validate_headers + + class Headers(Struct): + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_headers(Headers) + async def index(headers: Headers): + ... + + .. tab:: pydantic - from dataclasses import dataclass + .. code-block:: python - from quart_schema import validate_headers + from pydantic import BaseModel + from quart_schema import validate_headers - @dataclass - class Headers: - x_required: str - x_optional: int | None = None + class Headers(BaseModel): + x_required: str + x_optional: int | None = None - @app.route("/") - @validate_headers(Headers) - async def index(headers: Headers): - ... + @app.route("/") + @validate_headers(Headers) + async def index(headers: Headers): + ... this will require the client adds a ``X-Required`` header to the request and optionally ``X-Optional`` of type int. @@ -65,22 +118,78 @@ Response headers Request headers can be validated alongside the response body bt decorating the route handler with a relevant schema, as so, -.. code-block:: python +.. tabs:: + + .. tab:: attrs + + .. code-block:: python + + from attrs import define + from quart_schema import validate_response + + @define + class Headers: + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_response(Body, 200, Headers) + async def index(): + ... + return body, 200, headers + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import validate_response + + @dataclass + class Headers: + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_response(Body, 200, Headers) + async def index(): + ... + return body, 200, headers + + .. tab:: msgspec + + .. code-block:: python + + from msgspec import Struct + from quart_schema import validate_response + + class Headers(Struct): + x_required: str + x_optional: int | None = None + + @app.route("/") + @validate_response(Body, 200, Headers) + async def index(): + ... + return body, 200, headers + + .. tab:: pydantic - from dataclasses import dataclass + .. code-block:: python - from quart_schema import validate_response + from pydantic import BaseModel + from quart_schema import validate_response - @dataclass - class Headers: - x_required: str - x_optional: int | None = None + class Headers(BaseModel): + x_required: str + x_optional: int | None = None - @app.route("/") - @validate_response(Body, 200, Headers) - async def index(): - ... - return body, 200, headers + @app.route("/") + @validate_response(Body, 200, Headers) + async def index(): + ... + return body, 200, headers this will require that the headers variable adds a ``X-Required`` header to the response and optionally ``X-Optional`` of type int. The diff --git a/docs/how_to_guides/querystring_validation.rst b/docs/how_to_guides/querystring_validation.rst index d0ebcec..0234347 100644 --- a/docs/how_to_guides/querystring_validation.rst +++ b/docs/how_to_guides/querystring_validation.rst @@ -7,21 +7,74 @@ it in a format you understand. This is done by validating them against a schema you define. Quart-Schema allows validation via decorating the route handler, as so, -.. code-block:: python +.. tabs:: - from dataclasses import dataclass + .. tab:: attrs - from quart_schema import validate_querystring + .. code-block:: python - @dataclass - class Query: - count_le: int | None = None - count_gt: int | None = None + from attrs import define + from quart_schema import validate_querystring - @app.route("/") - @validate_querystring(Query) - async def index(query_args: Query): - ... + @define + class Query: + count_le: int | None = None + count_gt: int | None = None + + @app.route("/") + @validate_querystring(Query) + async def index(query_args: Query): + ... + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import validate_querystring + + @dataclass + class Query: + count_le: int | None = None + count_gt: int | None = None + + @app.route("/") + @validate_querystring(Query) + async def index(query_args: Query): + ... + + .. tab:: msgspec + + .. code-block:: python + + from msgspec import Struct + from quart_schema import validate_querystring + + class Query(Struct): + count_le: int | None = None + count_gt: int | None = None + + @app.route("/") + @validate_querystring(Query) + async def index(query_args: Query): + ... + + .. tab:: pydantic + + .. code-block:: python + + from pydantic import BaseModel + from quart_schema import validate_querystring + + class Query(BaseModel): + count_le: int | None = None + count_gt: int | None = None + + @app.route("/") + @validate_querystring(Query) + async def index(query_args: Query): + ... this will allow the client to add a ``count_le``, ``count_gt``, or both parameters to the URL e,g. ``/?count_le=2&count_gt=0``. @@ -72,10 +125,10 @@ to a list using a ``BeforeValidator``, from quart_schema import validate_querystring def _to_list(value: str | list[str]) -> list[str]: - if isinstance(value, list): - return value - else: - return [value] + if isinstance(value, list): + return value + else: + return [value] class Query(BaseModel): keys: Annotated[Optional[List[str]], BeforeValidator(_to_list)] = Non @@ -84,3 +137,7 @@ to a list using a ``BeforeValidator``, @validate_querystring(Query) async def index(query_args: Query): ... + +.. warning:: + + This currently only works with Pydantic types and validation. diff --git a/docs/how_to_guides/request_validation.rst b/docs/how_to_guides/request_validation.rst index 22b3716..e645f47 100644 --- a/docs/how_to_guides/request_validation.rst +++ b/docs/how_to_guides/request_validation.rst @@ -7,21 +7,74 @@ request data is correct against a schema you define. Quart-Schema allows validation of JSON data via decorating the route handler, as so, -.. code-block:: python +.. tabs:: - from dataclasses import dataclass + .. tab:: attrs - from quart_schema import validate_request + .. code-block:: python - @dataclass - class Todo: - effort: int - task: str + from attrs import define + from quart_schema import validate_request - @app.route("/", methods=["POST']) - @validate_request(Todo) - async def index(data: Todo): - ... + @define + class Todo: + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo) + async def index(data: Todo): + ... + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import validate_request + + @dataclass + class Todo: + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo) + async def index(data: Todo): + ... + + .. tab:: msgspec + + .. code-block:: python + + from msgspec import Struct + from quart_schema import validate_request + + class Todo(Struct): + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo) + async def index(data: Todo): + ... + + .. tab:: pydantic + + .. code-block:: python + + from pydantic import BaseModel + from quart_schema import validate_request + + class Todo(BaseModel): + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo) + async def index(data: Todo): + ... this will expect the client to send a body with JSON structured to match the Todo class, for example, @@ -46,22 +99,74 @@ decorator assumes the request body is JSON encoded. If the request body is form (application/x-www-form-urlencoded) encoded the ``source`` argument can be changed to validate the form data, -.. code-block:: python +.. tabs:: - from dataclasses import dataclass + .. tab:: attrs - from quart_schema import DataSource, validate_request + .. code-block:: python - @dataclass - class Todo: - effort: int - task: str + from attrs import define + from quart_schema import DataSource, validate_request - @app.route("/", methods=["POST']) - @validate_request(Todo, source=DataSource.FORM) - async def index(data: Todo): - ... + @define + class Todo: + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo, source=DataSource.FORM) + async def index(data: Todo): + ... + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import DataSource, validate_request + + @dataclass + class Todo: + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo, source=DataSource.FORM) + async def index(data: Todo): + ... + + .. tab:: msgspec + + .. code-block:: python + from msgspec import Struct + from quart_schema import DataSource, validate_request + + class Todo(Struct): + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo, source=DataSource.FORM) + async def index(data: Todo): + ... + + .. tab:: pydantic + + .. code-block:: python + + from pydantic import BaseModel + from quart_schema import DataSource, validate_request + + class Todo(BaseModel): + effort: int + task: str + + @app.route("/", methods=["POST']) + @validate_request(Todo, source=DataSource.FORM) + async def index(data: Todo): + ... with everything working as in the JSON example above. .. note:: @@ -93,3 +198,8 @@ the ``source`` argument must be changed to validate the form data, @validate_request(Upload, source=DataSource.FORM_MULTIPART) async def index(data: Upload): file_content = data.file.read() + + +.. warning:: + + This currently only works with Pydantic types and validation. diff --git a/docs/how_to_guides/response_validation.rst b/docs/how_to_guides/response_validation.rst index accf413..ac8c415 100644 --- a/docs/how_to_guides/response_validation.rst +++ b/docs/how_to_guides/response_validation.rst @@ -10,21 +10,74 @@ response data is correct against a schema you define. Quart-Schema allows validation of JSON data via decorating the route handler, as so, -.. code-block:: python +.. tabs:: - from dataclasses import dataclass + .. tab:: attrs - from quart_schema import validate_response + .. code-block:: python - @dataclass - class Todo: - effort: int - task: str + from attrs import define + from quart_schema import validate_response - @app.route("/") - @validate_response(Todo) - async def index(): - return data + @define + class Todo: + effort: int + task: str + + @app.route("/") + @validate_response(Todo) + async def index(): + return data + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + from quart_schema import validate_response + + @dataclass + class Todo: + effort: int + task: str + + @app.route("/") + @validate_response(Todo) + async def index(): + return data + + .. tab:: attrs + + .. code-block:: python + + from msgspec import Struct + from quart_schema import validate_response + + class Todo(Struct): + effort: int + task: str + + @app.route("/") + @validate_response(Todo) + async def index(): + return data + + .. tab:: pydantic + + .. code-block:: python + + from pydantic import BaseModel + from quart_schema import validate_response + + class Todo(BaseModel): + effort: int + task: str + + @app.route("/") + @validate_response(Todo) + async def index(): + return data will ensure that ``data`` represents or is a ``Todo`` object, e.g. these responses are allowed. Note that the typical Quart response diff --git a/docs/how_to_guides/testing.rst b/docs/how_to_guides/testing.rst index 4a10e75..9ef786e 100644 --- a/docs/how_to_guides/testing.rst +++ b/docs/how_to_guides/testing.rst @@ -6,20 +6,81 @@ normal Quart routes (everything works the same). In addition Quart-Schema allows Pydantic models, or dataclasses to be sent via the test client as ``json`` or ``form`` data, for example, -.. code-block:: python +.. tabs:: - @dataclass - class DCDetails: - name: str - age: int | None = None + .. tab:: attrs - @pytest.mark.asyncio - async def test_send_dataclass() -> None: - ... - test_client = app.test_client() - response = await test_client.post("/", json=DCDetails(name="bob", age=2)) - # Or - response = await test_client.post("/", form=DCDetails(name="bob", age=2)) + .. code-block:: python + + from attrs import define + + @define + class Details: + name: str + age: int | None = None + + @pytest.mark.asyncio + async def test_send() -> None: + ... + test_client = app.test_client() + response = await test_client.post("/", json=Details(name="bob", age=2)) + # Or + response = await test_client.post("/", form=Details(name="bob", age=2)) + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + @dataclass + class Details: + name: str + age: int | None = None + + @pytest.mark.asyncio + async def test_send() -> None: + ... + test_client = app.test_client() + response = await test_client.post("/", json=Details(name="bob", age=2)) + # Or + response = await test_client.post("/", form=Details(name="bob", age=2)) + + .. tab:: attrs + + .. code-block:: python + + from msgspec import Struct + + class Details(Struct): + name: str + age: int | None = None + + @pytest.mark.asyncio + async def test_send() -> None: + ... + test_client = app.test_client() + response = await test_client.post("/", json=Details(name="bob", age=2)) + # Or + response = await test_client.post("/", form=Details(name="bob", age=2)) + + .. tab:: pydantic + + .. code-block:: python + + from pydantic import BaseModel + + class Details(BaseModel): + name: str + age: int | None = None + + @pytest.mark.asyncio + async def test_send() -> None: + ... + test_client = app.test_client() + response = await test_client.post("/", json=Details(name="bob", age=2)) + # Or + response = await test_client.post("/", form=Details(name="bob", age=2)) Hypothesis testing @@ -40,19 +101,19 @@ example, # Other imports not shown @dataclass - class DCDetails: + class Details: name: str age: int | None = None - @given(st.builds(DCDetails)) + @given(st.builds(Details)) @pytest.mark.asyncio - async def test_index(data: DCDetails) -> None: + async def test_index(data: Details) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/", methods=["POST"]) - @validate_request(DCDetails) - async def index(data: DCDetails) -> Any: + @validate_request(Details) + async def index(data: Details) -> Any: return data test_client = app.test_client() diff --git a/docs/index.rst b/docs/index.rst index 37b75a6..c238a7f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,7 +15,8 @@ documentation. Using Quart-Schema you can, * generate OpenAPI documentation from the validation, * on top of everything Quart can do. -with Quart-Schema's validation based on the excellent `pydantic +with Quart-Schema's validation based on either the `msgspec +`_ or `pydantic `_ library. If you are, @@ -23,8 +24,11 @@ If you are, * new to Quart-Schema then try the :ref:`quickstart`, * new to Quart then try the `Quart documentation `_, + * new to msgspec then try the `msgspec documentation + `_, * new to Pydantic then try the `pydantic documentation `_, + * unsure which library to use, try the :ref:`validation_library`. Quart-Schema is developed on `GitHub `_. If you come across an diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index fc273f4..27c6e69 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -7,3 +7,4 @@ Tutorials installation.rst quickstart.rst + validation_library.rst diff --git a/docs/tutorials/installation.rst b/docs/tutorials/installation.rst index 2784b39..00563ed 100644 --- a/docs/tutorials/installation.rst +++ b/docs/tutorials/installation.rst @@ -4,11 +4,14 @@ Installation ============ Quart-Schema is only compatible with Python 3.8 or higher and can be -installed using pip or your favorite python package manager. +installed using pip or your favorite python package manager. At +installation please specify if you want to use Quart-Schema with +msgspec or pydantic. For example when using pip choose one of, .. code-block:: sh - pip install quart-schema + pip install quart-schema[msgspec] + pip install quart-schema[pydantic] Installing quart-schema will install Quart if it is not present in your environment. diff --git a/docs/tutorials/validation_library.rst b/docs/tutorials/validation_library.rst new file mode 100644 index 0000000..4384291 --- /dev/null +++ b/docs/tutorials/validation_library.rst @@ -0,0 +1,106 @@ +.. _validation_library: + +Validation library choice +========================= + +Quart-Schema is agnostic to your choice of validation library with +msgspec and Pydantic both supported. This choice must be made at +installation, with either the ``msgspec`` or ``pydantic`` extra used. + +.. note:: + + If you install both msgspec and Pydantic you can control which is + used for builtin types by setting the + ``QUART_SCHEMA_CONVERSION_PREFERENCE`` to either ``msgspec`` or + ``pydantic``. + +This documentation will show examples for both msgspec and Pydantic. + +msgspec +------- + +If you choose msgspec you can contruct the validation classes as +dataclasses, attrs definitions, or msgspec structs. + +.. tabs:: + + .. tab:: attrs + + .. code-block:: python + + from attrs import define + + @define + class Person: + name: str + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + @dataclass + class Person: + name: str + + .. tab:: msgspec + + .. code-block:: python + + from msgspec import Struct + + class Person(Struct): + name: str + +Pydantic +-------- + +If you choose Pydantic you can contruct the validation classes as +dataclasses, Pydantic dataclasses, or BaseModels. + +.. tabs:: + + .. tab:: dataclasses + + .. code-block:: python + + from dataclasses import dataclass + + @dataclass + class Person: + name: str + + .. tab:: Pydantic dataclasses + + .. code-block:: python + + from pydantic.dataclasses import dataclass + + @dataclass + class Person: + name: str + + .. tab:: Pydantic BaseModel + + .. code-block:: python + + from pydantic import BaseModel + + class Person(BaseModel): + name: str + +Lists +----- + +Note that lists are valid validation models i.e. the following is +valid for any of the above ``Person`` defintions, + +.. code-block:: python + + from typing import List + + @validate_request(List[Person]) + @app.post("/") + async def index(): + ... diff --git a/pyproject.toml b/pyproject.toml index af3796a..044d10e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,18 +58,22 @@ warn_unused_configs = true warn_unused_ignores = true [tool.poetry.dependencies] +msgspec = { version = ">=0.18", optional = true } pydata_sphinx_theme = { version = "*", optional = true } pyhumps = ">=1.6.1" python = ">=3.8" -pydantic=">=2" +pydantic = { version = ">=2", optional = true } quart = ">=0.19.0" +sphinx-tabs = { version = ">=3.4.4", optional = true } typing_extensions = { version = "*", python = "<3.9" } [tool.poetry.dev-dependencies] tox = "*" [tool.poetry.extras] -docs = ["pydata_sphinx_theme"] +docs = ["pydata_sphinx_theme", "sphinx-tabs"] +msgspec = ["msgspec"] +pydantic = ["pydantic"] [tool.pytest.ini_options] addopts = "--no-cov-on-fail --showlocals --strict-markers" diff --git a/src/quart_schema/conversion.py b/src/quart_schema/conversion.py index 0805d27..7552a9a 100644 --- a/src/quart_schema/conversion.py +++ b/src/quart_schema/conversion.py @@ -1,34 +1,96 @@ from __future__ import annotations -from dataclasses import asdict, fields, is_dataclass -from typing import Optional, Type, TypeVar, Union +from dataclasses import fields, is_dataclass +from typing import Any, Optional, Type, TypeVar, Union import humps -from pydantic import BaseModel, RootModel, TypeAdapter, ValidationError -from pydantic.dataclasses import is_pydantic_dataclass from quart import current_app from quart.typing import HeadersValue, ResponseReturnValue as QuartResponseReturnValue, StatusCode from werkzeug.datastructures import Headers +from werkzeug.exceptions import HTTPException -from .typing import BM, DC, Model, ResponseReturnValue, ResponseValue +from .typing import Model, ResponseReturnValue, ResponseValue -REF_TEMPLATE = "#/components/schemas/{model}" +try: + from pydantic import ( + BaseModel, + RootModel, + TypeAdapter, + ValidationError as PydanticValidationError, + ) + from pydantic.dataclasses import is_pydantic_dataclass +except ImportError: + PYDANTIC_INSTALLED = False + + class BaseModel: # type: ignore + pass + + class RootModel: # type: ignore + pass + + class TypeAdapter: # type: ignore + pass + + def is_pydantic_dataclass(object_: Any) -> bool: # type: ignore + return False + + class PydanticValidationError(Exception): # type: ignore + pass + +else: + PYDANTIC_INSTALLED = True + + +try: + from attrs import fields as attrs_fields, has as is_attrs + from msgspec import convert, Struct, to_builtins, ValidationError as MsgSpecValidationError + from msgspec.json import schema_components +except ImportError: + MSGSPEC_INSTALLED = False + + class Struct: # type: ignore + pass + + def is_attrs(object_: Any) -> bool: # type: ignore + return False -T = TypeVar("T") + def convert(object_: Any, type_: Any) -> Any: # type: ignore + raise RuntimeError("Cannot convert, msgspec not installed") + def to_builtins(object_: Any) -> Any: # type: ignore + return object_ -def convert_response_return_value(result: ResponseReturnValue) -> QuartResponseReturnValue: + class MsgSpecValidationError(Exception): # type: ignore + pass + +else: + MSGSPEC_INSTALLED = True + + +PYDANTIC_REF_TEMPLATE = "#/components/schemas/{model}" +MSGSPEC_REF_TEMPLATE = "#/components/schemas/{name}" + +T = TypeVar("T", bound=Model) + + +def convert_response_return_value( + result: ResponseReturnValue | HTTPException, +) -> QuartResponseReturnValue | HTTPException: value: ResponseValue headers: Optional[HeadersValue] = None status: Optional[StatusCode] = None - if isinstance(result, tuple): + if isinstance(result, HTTPException): + return result + elif isinstance(result, tuple): if len(result) == 3: value, status, headers = result # type: ignore elif len(result) == 2: value, status_or_headers = result if isinstance(status_or_headers, int): status = status_or_headers + else: + headers = status_or_headers # type: ignore else: value = result @@ -36,11 +98,13 @@ def convert_response_return_value(result: ResponseReturnValue) -> QuartResponseR value, camelize=current_app.config["QUART_SCHEMA_CONVERT_CASING"], by_alias=current_app.config["QUART_SCHEMA_BY_ALIAS"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) headers = model_dump( headers, # type: ignore kebabize=True, by_alias=current_app.config["QUART_SCHEMA_BY_ALIAS"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) new_result: ResponseReturnValue @@ -59,14 +123,31 @@ def convert_response_return_value(result: ResponseReturnValue) -> QuartResponseR def model_dump( - raw: ResponseValue, *, by_alias: bool, camelize: bool = False, kebabize: bool = False -) -> dict: + raw: ResponseValue, + *, + by_alias: bool, + camelize: bool = False, + kebabize: bool = False, + preference: Optional[str] = None, +) -> dict | list: if is_pydantic_dataclass(raw): # type: ignore value = RootModel[type(raw)](raw).model_dump() # type: ignore - elif is_dataclass(raw): - value = asdict(raw) # type: ignore elif isinstance(raw, BaseModel): value = raw.model_dump(by_alias=by_alias) + elif isinstance(raw, Struct) or is_attrs(raw): # type: ignore + value = to_builtins(raw) + elif ( + (isinstance(raw, (list, dict)) or is_dataclass(raw)) + and PYDANTIC_INSTALLED + and preference != "msgspec" + ): + value = TypeAdapter(type(raw)).dump_python(raw) + elif ( + (isinstance(raw, (list, dict)) or is_dataclass(raw)) + and MSGSPEC_INSTALLED + and preference != "pydantic" + ): + value = to_builtins(raw) else: return raw # type: ignore @@ -79,30 +160,76 @@ def model_dump( def model_load( - data: dict, model_class: Model, exception_class: Type[Exception], *, decamelize: bool = False -) -> Union[BM, DC]: + data: dict, + model_class: Type[T], + exception_class: Type[Exception], + *, + decamelize: bool = False, + preference: Optional[str] = None, +) -> T: if decamelize: data = humps.decamelize(data) try: - return TypeAdapter(model_class).validate_python(data) - except (TypeError, ValidationError, ValueError) as error: + if ( + is_pydantic_dataclass(model_class) + or issubclass(model_class, BaseModel) + or ( + (isinstance(model_class, (list, dict)) or is_dataclass(model_class)) + and PYDANTIC_INSTALLED + and preference != "msgspec" + ) + ): + return TypeAdapter(model_class).validate_python(data) # type: ignore + elif ( + issubclass(model_class, Struct) + or is_attrs(model_class) + or ( + (isinstance(model_class, (list, dict)) or is_dataclass(model_class)) + and MSGSPEC_INSTALLED + and preference != "pydantic" + ) + ): + return convert(data, model_class, strict=False) # type: ignore + else: + raise TypeError(f"Cannot load {model_class}") + except (TypeError, MsgSpecValidationError, PydanticValidationError, ValueError) as error: raise exception_class(error) -def model_schema(model_class: Model) -> dict: - return TypeAdapter(model_class).json_schema(ref_template=REF_TEMPLATE) +def model_schema(model_class: Type[Model], *, preference: Optional[str] = None) -> dict: + if ( + is_pydantic_dataclass(model_class) + or issubclass(model_class, BaseModel) + or (isinstance(model_class, (list, dict)) and preference != "msgspec") + or (is_dataclass(model_class) and preference != "msgspec") + ): + return TypeAdapter(model_class).json_schema(ref_template=PYDANTIC_REF_TEMPLATE) + elif ( + issubclass(model_class, Struct) + or is_attrs(model_class) + or (isinstance(model_class, (list, dict)) and preference != "pydantic") + or (is_dataclass(model_class) and preference != "pydantic") + ): + _, schema = schema_components([model_class], ref_template=MSGSPEC_REF_TEMPLATE) + return list(schema.values())[0] + else: + raise TypeError(f"Cannot create schema for {model_class}") def convert_headers( raw: Union[Headers, dict], model_class: Type[T], exception_class: Type[Exception] ) -> T: if is_pydantic_dataclass(model_class): - fields_ = model_class.__pydantic_fields__.keys() + fields_ = set(model_class.__pydantic_fields__.keys()) elif is_dataclass(model_class): - fields_ = {field.name for field in fields(model_class)} # type: ignore + fields_ = {field.name for field in fields(model_class)} elif issubclass(model_class, BaseModel): - fields_ = model_class.model_fields.keys() + fields_ = set(model_class.model_fields.keys()) + elif is_attrs(model_class): + fields_ = {field.name for field in attrs_fields(model_class)} + elif issubclass(model_class, Struct): + fields_ = set(model_class.__struct_fields__) else: raise TypeError(f"Cannot convert to {model_class}") @@ -117,5 +244,5 @@ def convert_headers( try: return model_class(**result) - except (TypeError, ValidationError, ValueError) as error: + except (TypeError, MsgSpecValidationError, ValueError) as error: raise exception_class(error) diff --git a/src/quart_schema/extension.py b/src/quart_schema/extension.py index f084e1e..d94280b 100644 --- a/src/quart_schema/extension.py +++ b/src/quart_schema/extension.py @@ -5,12 +5,11 @@ from collections import defaultdict from functools import wraps from types import new_class -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union import click import humps -from pydantic_core import to_jsonable_python -from quart import current_app, Quart, render_template_string, ResponseReturnValue +from quart import current_app, jsonify, Quart, render_template_string, ResponseReturnValue from quart.cli import pass_script_info, ScriptInfo from quart.json.provider import DefaultJSONProvider from quart.typing import ResponseReturnValue as QuartResponseReturnValue @@ -38,6 +37,12 @@ QUART_SCHEMA_RESPONSE_ATTRIBUTE, ) +try: + from pydantic_core import to_jsonable_python +except ImportError: + from msgspec import to_builtins as to_jsonable_python # type: ignore + + SecurityScheme = Union[ APIKeySecurityScheme, HttpSecurityScheme, @@ -179,12 +184,14 @@ def __init__( security_schemes: Optional[Dict[str, SecuritySchemeInput]] = None, security: Optional[List[Dict[str, List[str]]]] = None, external_docs: Optional[Union[ExternalDocumentation, dict]] = None, + conversion_preference: Literal["msgspec", "pydantic", None] = None, ) -> None: self.openapi_path = openapi_path self.redoc_ui_path = redoc_ui_path self.swagger_ui_path = swagger_ui_path self.convert_casing = convert_casing + self.conversion_preference = conversion_preference self.info: Optional[Info] = None if info is not None: @@ -258,6 +265,7 @@ def init_app(self, app: Quart) -> None: ) app.config.setdefault("QUART_SCHEMA_BY_ALIAS", False) app.config.setdefault("QUART_SCHEMA_CONVERT_CASING", self.convert_casing) + app.config.setdefault("QUART_SCHEMA_CONVERSION_PREFERENCE", self.conversion_preference) if self.openapi_path is not None: hide(app.send_static_file.__func__) # type: ignore app.add_url_rule(self.openapi_path, "openapi", self.openapi) @@ -270,7 +278,7 @@ def init_app(self, app: Quart) -> None: @hide async def openapi(self) -> ResponseReturnValue: - return _build_openapi_schema(current_app, self) + return jsonify(_build_openapi_schema(current_app, self)) @hide async def swagger_ui(self) -> str: @@ -430,7 +438,9 @@ def _build_path(func: Callable, rule: Rule, app: Quart) -> Tuple[dict, dict]: response_models = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) for status_code in response_models.keys(): model_class, headers_model_class = response_models[status_code] - schema = model_schema(model_class) + schema = model_schema( + model_class, preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"] + ) definitions, schema = _split_convert_definitions( schema, app.config["QUART_SCHEMA_CONVERT_CASING"] ) @@ -447,7 +457,10 @@ def _build_path(func: Callable, rule: Rule, app: Quart) -> Tuple[dict, dict]: response_object["description"] = inspect.getdoc(model_class) if headers_model_class is not None: - schema = model_schema(headers_model_class) + schema = model_schema( + headers_model_class, + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], + ) definitions, schema = _split_definitions(schema) components.update(definitions) response_object["content"]["headers"] = { # type: ignore @@ -460,7 +473,9 @@ def _build_path(func: Callable, rule: Rule, app: Quart) -> Tuple[dict, dict]: request_data = getattr(func, QUART_SCHEMA_REQUEST_ATTRIBUTE, None) if request_data is not None: - schema = model_schema(request_data[0]) + schema = model_schema( + request_data[0], preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"] + ) definitions, schema = _split_convert_definitions( schema, app.config["QUART_SCHEMA_CONVERT_CASING"] ) @@ -483,7 +498,9 @@ def _build_path(func: Callable, rule: Rule, app: Quart) -> Tuple[dict, dict]: querystring_model = getattr(func, QUART_SCHEMA_QUERYSTRING_ATTRIBUTE, None) if querystring_model is not None: - schema = model_schema(querystring_model) + schema = model_schema( + querystring_model, preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"] + ) definitions, schema = _split_convert_definitions( schema, app.config["QUART_SCHEMA_CONVERT_CASING"] ) @@ -499,7 +516,9 @@ def _build_path(func: Callable, rule: Rule, app: Quart) -> Tuple[dict, dict]: headers_model = getattr(func, QUART_SCHEMA_HEADERS_ATTRIBUTE, None) if headers_model is not None: - schema = model_schema(headers_model) + schema = model_schema( + headers_model, preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"] + ) definitions, schema = _split_definitions(schema) components.update(definitions) for name, type_ in schema["properties"].items(): diff --git a/src/quart_schema/mixins.py b/src/quart_schema/mixins.py index e9463e7..3fc01ae 100644 --- a/src/quart_schema/mixins.py +++ b/src/quart_schema/mixins.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, AnyStr, Dict, Optional, overload, Tuple, Type, Union +from typing import Any, AnyStr, Dict, Optional, Tuple, Type, Union from quart import current_app, Response from quart.datastructures import FileStorage @@ -8,7 +8,7 @@ from werkzeug.datastructures import Authorization, Headers from .conversion import model_dump, model_load -from .typing import BM, DC, TestClientProtocol, WebsocketProtocol +from .typing import Model, TestClientProtocol, WebsocketProtocol class SchemaValidationError(Exception): @@ -18,36 +18,31 @@ def __init__(self, validation_error: Optional[Exception] = None) -> None: class WebsocketMixin: - @overload - async def receive_as(self: WebsocketProtocol, model_class: Type[BM]) -> BM: - ... - - @overload - async def receive_as(self: WebsocketProtocol, model_class: Type[DC]) -> DC: - ... - - async def receive_as( # type: ignore[misc] - self: WebsocketProtocol, model_class: Union[Type[BM], Type[DC]] - ) -> Union[BM, DC]: + async def receive_as(self: WebsocketProtocol, model_class: Type[Model]) -> Model: data = await self.receive_json() return model_load( data, model_class, SchemaValidationError, decamelize=current_app.config["QUART_SCHEMA_CONVERT_CASING"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) - async def send_as( - self: WebsocketProtocol, value: Any, model_class: Union[Type[BM], Type[DC]] - ) -> None: + async def send_as(self: WebsocketProtocol, value: Any, model_class: Type[Model]) -> None: if type(value) != model_class: # noqa: E721 - value = model_load(value, model_class, SchemaValidationError) + value = model_load( + value, + model_class, + SchemaValidationError, + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], + ) data = model_dump( value, camelize=current_app.config["QUART_SCHEMA_CONVERT_CASING"], by_alias=current_app.config["QUART_SCHEMA_BY_ALIAS"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) - await self.send_json(data) + await self.send_json(data) # type: ignore class TestClientMixin: @@ -73,20 +68,22 @@ async def _make_request( json, camelize=self.app.config["QUART_SCHEMA_CONVERT_CASING"], by_alias=self.app.config["QUART_SCHEMA_BY_ALIAS"], + preference=self.app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) if form is not None: - form = model_dump( + form = model_dump( # type: ignore form, camelize=self.app.config["QUART_SCHEMA_CONVERT_CASING"], by_alias=self.app.config["QUART_SCHEMA_BY_ALIAS"], + preference=self.app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) - if query_string is not None: - query_string = model_dump( + query_string = model_dump( # type: ignore query_string, camelize=self.app.config["QUART_SCHEMA_CONVERT_CASING"], by_alias=self.app.config["QUART_SCHEMA_BY_ALIAS"], + preference=self.app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) return await super()._make_request( # type: ignore diff --git a/src/quart_schema/openapi.py b/src/quart_schema/openapi.py index 34c74ae..d0dfc83 100644 --- a/src/quart_schema/openapi.py +++ b/src/quart_schema/openapi.py @@ -1,13 +1,8 @@ from dataclasses import dataclass, field, fields -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional import humps -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - class _SchemaBase: def schema(self, *, camelize: bool = False) -> Dict: diff --git a/src/quart_schema/typing.py b/src/quart_schema/typing.py index 9c39563..00b0ba5 100644 --- a/src/quart_schema/typing.py +++ b/src/quart_schema/typing.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import Any, AnyStr, Dict, Optional, Tuple, Type, TYPE_CHECKING, TypeVar, Union +from typing import Any, AnyStr, Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union -from pydantic import BaseModel from quart import Quart from quart.datastructures import FileStorage from quart.typing import ( @@ -14,20 +13,28 @@ from quart.wrappers import Response from werkzeug.datastructures import Headers -if TYPE_CHECKING: - from pydantic.dataclasses import Dataclass - try: from typing import Protocol except ImportError: from typing_extensions import Protocol # type: ignore +if TYPE_CHECKING: + from attrs import AttrsInstance + from msgspec import Struct + from pydantic import BaseModel + from pydantic.dataclasses import Dataclass -Model = Union[Type[BaseModel], Type["Dataclass"], Type] -PydanticModel = Union[Type[BaseModel], Type["Dataclass"]] -ResponseValue = Union[QuartResponseValue, PydanticModel] -HeadersValue = Union[QuartHeadersValue, PydanticModel] +class DataclassProtocol(Protocol): + __dataclass_fields__: Dict + __dataclass_params__: Dict + __post_init__: Optional[Callable] + + +ModelTypes = Union["AttrsInstance", "BaseModel", "Dataclass", "DataclassProtocol", "Struct"] +Model = Union[ModelTypes, List[ModelTypes], Dict[str, ModelTypes]] +ResponseValue = Union[QuartResponseValue, Type[Model]] +HeadersValue = Union[QuartHeadersValue, Model] ResponseReturnValue = Union[ QuartResponseReturnValue, @@ -65,7 +72,3 @@ async def _make_request( scope_base: Optional[dict], ) -> Response: ... - - -BM = TypeVar("BM", bound=BaseModel) -DC = TypeVar("DC", bound="Dataclass") diff --git a/src/quart_schema/validation.py b/src/quart_schema/validation.py index d2b27da..9d14e80 100644 --- a/src/quart_schema/validation.py +++ b/src/quart_schema/validation.py @@ -2,7 +2,7 @@ from enum import auto, Enum from functools import wraps -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Type from quart import current_app, request, Response from werkzeug.exceptions import BadRequest @@ -50,7 +50,7 @@ class DataSource(Enum): JSON = auto() -def validate_querystring(model_class: Model) -> Callable: +def validate_querystring(model_class: Type[Model]) -> Callable: """Validate the request querystring arguments. This ensures that the query string arguments can be converted to @@ -74,11 +74,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: else request.args[key] for key in request.args } - model = model_load( # type: ignore + model = model_load( request_args, model_class, QuerystringValidationError, decamelize=current_app.config["QUART_SCHEMA_CONVERT_CASING"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) return await current_app.ensure_async(func)(*args, query_args=model, **kwargs) @@ -87,7 +88,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator -def validate_headers(model_class: Model) -> Callable: +def validate_headers(model_class: Type[Model]) -> Callable: """Validate the request headers. This ensures that the headers can be converted to the @@ -115,7 +116,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: def validate_request( - model_class: Model, + model_class: Type[Model], *, source: DataSource = DataSource.JSON, ) -> Callable: @@ -151,11 +152,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: else: data[key] = files[key] - model = model_load( # type: ignore + model = model_load( data, model_class, RequestSchemaValidationError, decamelize=current_app.config["QUART_SCHEMA_CONVERT_CASING"], + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) return await current_app.ensure_async(func)(*args, data=model, **kwargs) @@ -165,9 +167,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: def validate_response( - model_class: Model, + model_class: Type[Model], status_code: int = 200, - headers_model_class: Optional[Model] = None, + headers_model_class: Optional[Type[Model]] = None, ) -> Callable: """Validate the response data. @@ -225,6 +227,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: value, # type: ignore model_class, ResponseSchemaValidationError, + preference=current_app.config["QUART_SCHEMA_CONVERSION_PREFERENCE"], ) if headers_model_class is not None: @@ -250,11 +253,11 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: def validate( *, - querystring: Optional[Model] = None, - request: Optional[Model] = None, + querystring: Optional[Type[Model]] = None, + request: Optional[Type[Model]] = None, request_source: DataSource = DataSource.JSON, - headers: Optional[Model] = None, - responses: Dict[int, Tuple[Model, Optional[Model]]], + headers: Optional[Type[Model]] = None, + responses: Dict[int, Tuple[Type[Model], Optional[Type[Model]]]], ) -> Callable: """Validate the route. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..7dec840 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import Optional + +from attrs import define +from msgspec import Struct +from pydantic import BaseModel +from pydantic.dataclasses import dataclass as pydantic_dataclass + + +@define +class ADetails: + name: str + age: Optional[int] = None + + +class MDetails(Struct): + name: str + age: Optional[int] = None + + +@dataclass +class DCDetails: + name: str + age: Optional[int] = None + + +class PyDetails(BaseModel): + name: str + age: Optional[int] = None + + +@pydantic_dataclass +class PyDCDetails: + name: str + age: Optional[int] = None diff --git a/tests/test_basic.py b/tests/test_basic.py index 3373a5e..f80e6d3 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,36 +1,19 @@ -from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Type, Union from uuid import UUID import pytest from pydantic import BaseModel -from pydantic.dataclasses import dataclass as pydantic_dataclass from quart import Quart, ResponseReturnValue from quart_schema import QuartSchema -from quart_schema.typing import PydanticModel +from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails -@dataclass -class DCDetails: - name: str - age: Optional[int] = None - - -class Details(BaseModel): - name: str - age: Optional[int] - - -@pydantic_dataclass -class PyDCDetails: - name: str - age: Optional[int] = None - - -@pytest.mark.parametrize("type_", [DCDetails, Details, PyDCDetails]) -async def test_make_response(type_: PydanticModel) -> None: +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +async def test_make_response( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]] +) -> None: app = Quart(__name__) QuartSchema(app) @@ -49,11 +32,13 @@ async def test_make_response_no_model() -> None: @app.route("/") async def index() -> ResponseReturnValue: - return {"name": "bob", "age": 2}, {"Content-Type": "application/json"} + return {"name": "bob", "age": 2}, {"Content-Type": "application/json", "X-1": "2"} test_client = app.test_client() response = await test_client.get("/") assert (await response.get_json()) == {"name": "bob", "age": 2} + assert response.headers["Content-Type"] == "application/json" + assert response.headers["X-1"] == "2" class PydanticEncoded(BaseModel): diff --git a/tests/test_conversion.py b/tests/test_conversion.py new file mode 100644 index 0000000..e32f7e6 --- /dev/null +++ b/tests/test_conversion.py @@ -0,0 +1,129 @@ +from dataclasses import dataclass +from typing import Type, Union + +import pytest +from attrs import define +from msgspec import Struct +from pydantic import BaseModel +from pydantic.dataclasses import dataclass as pydantic_dataclass + +from quart_schema.conversion import convert_headers, model_dump, model_load, model_schema +from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails + + +class ValidationError(Exception): + pass + + +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +def test_model_dump( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]] +) -> None: + assert model_dump(type_(name="bob", age=2), by_alias=False) == { # type: ignore + "name": "bob", + "age": 2, + } + + +@pytest.mark.parametrize( + "type_, preference", + [ + (ADetails, "msgspec"), + (DCDetails, "msgspec"), + (DCDetails, "pydantic"), + (MDetails, "msgspec"), + (PyDetails, "pydantic"), + (PyDCDetails, "pydantic"), + ], +) +def test_model_dump_list( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]], + preference: str, +) -> None: + assert model_dump( + [type_(name="bob", age=2), type_(name="jim", age=3)], by_alias=False, preference=preference + ) == [{"name": "bob", "age": 2}, {"name": "jim", "age": 3}] + + +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +def test_model_load( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]] +) -> None: + assert model_load({"name": "bob", "age": 2}, type_, exception_class=ValidationError) == type_( + name="bob", age=2 + ) + + +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +def test_model_load_error( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]] +) -> None: + with pytest.raises(ValidationError): + model_load({"name": "bob", "age": "two"}, type_, exception_class=ValidationError) + + +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails]) +def test_model_schema_msgspec(type_: Type[Union[ADetails, DCDetails, MDetails]]) -> None: + assert model_schema(type_, preference="msgspec") == { + "title": type_.__name__, + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None}, + }, + "required": ["name"], + } + + +@pytest.mark.parametrize("type_", [DCDetails, PyDetails, PyDCDetails]) +def test_model_schema_pydantic(type_: Type[Union[DCDetails, PyDetails, PyDCDetails]]) -> None: + assert model_schema(type_, preference="pydantic") == { + "properties": { + "name": {"title": "Name", "type": "string"}, + "age": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "title": "Age", + }, + }, + "required": ["name"], + "title": type_.__name__, + "type": "object", + } + + +@define +class AHeaders: + x_info: str + + +class MHeaders(Struct): + x_info: str + + +@dataclass +class DCHeaders: + x_info: str + + +class PyHeaders(BaseModel): + x_info: str + + +@pydantic_dataclass +class PyDCHeaders: + x_info: str + + +@pytest.mark.parametrize("type_", [AHeaders, DCHeaders, MHeaders, PyHeaders, PyDCHeaders]) +def test_convert_headers( + type_: Type[Union[AHeaders, DCHeaders, MHeaders, PyHeaders, PyDCHeaders]], +) -> None: + convert_headers( + { + "X-Info": "ABC", + "Other": "2", + }, + type_, + exception_class=ValidationError, + ) == type_(x_info="ABC") diff --git a/tests/test_openapi.py b/tests/test_openapi.py index c640ec5..b654a84 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -1,5 +1,6 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type +import pytest from pydantic import Field from pydantic.dataclasses import dataclass from quart import Quart @@ -14,6 +15,8 @@ validate_request, validate_response, ) +from quart_schema.typing import Model +from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails @dataclass @@ -21,12 +24,6 @@ class QueryItem: count_le: Optional[int] = Field(description="count_le description") -@dataclass -class Details: - name: str - age: Optional[int] = None - - @dataclass class Result: """Result""" @@ -41,7 +38,17 @@ class Headers: ) -async def test_openapi() -> None: +@pytest.mark.parametrize( + "type_, titles", + [ + (ADetails, False), + (DCDetails, True), + (MDetails, False), + (PyDetails, True), + (PyDCDetails, True), + ], +) +async def test_openapi(type_: Type[Model], titles: bool) -> None: app = Quart(__name__) QuartSchema(app) @@ -62,7 +69,7 @@ async def read_item() -> Tuple[Result, int, Headers]: return Result(name="bob"), 200, Headers(x_name="jeff") @app.post("/") - @validate_request(Details) + @validate_request(type_) @validate_response(Result, 201, Headers) @operation_id("make_item") @deprecate() @@ -76,9 +83,11 @@ async def ws() -> None: test_client = app.test_client() response = await test_client.get("/openapi.json") - assert (await response.get_json()) == { + result = await response.get_json() + + expected = { "components": {"schemas": {}}, - "info": {"title": "test_openapi", "version": "0.1.0"}, + "info": {"title": "tests.test_openapi", "version": "0.1.0"}, "openapi": "3.1.0", "paths": { "/": { @@ -156,7 +165,7 @@ async def ws() -> None: "name": {"title": "Name", "type": "string"}, }, "required": ["name"], - "title": "Details", + "title": type_.__name__, "type": "object", } } @@ -192,6 +201,14 @@ async def ws() -> None: } }, } + if not titles: + del expected["paths"]["/"]["post"]["requestBody"]["content"][ # type: ignore + "application/json" + ]["schema"]["properties"]["name"]["title"] + del expected["paths"]["/"]["post"]["requestBody"]["content"][ # type: ignore + "application/json" + ]["schema"]["properties"]["age"]["title"] + assert result == expected async def test_security_schemes() -> None: diff --git a/tests/test_testing.py b/tests/test_testing.py index 050d3f5..5a3fd4c 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,34 +1,23 @@ -from typing import Optional +from typing import Type, Union import pytest from hypothesis import given, strategies as st -from pydantic import BaseModel -from pydantic.dataclasses import dataclass from quart import Quart -from quart_schema import DataSource, QuartSchema, ResponseReturnValue, validate_request -from quart_schema.typing import PydanticModel +from quart_schema import DataSource, QuartSchema, validate_request +from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails +Models = Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails] -@dataclass -class DCDetails: - name: str - age: Optional[int] = None - -class Details(BaseModel): - name: str - age: Optional[int] - - -@pytest.mark.parametrize("type_", [DCDetails, Details]) -async def test_send_json(type_: PydanticModel) -> None: +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +async def test_send_json(type_: Type[Models]) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/", methods=["POST"]) @validate_request(type_) - async def index(data: PydanticModel) -> ResponseReturnValue: + async def index(data: Models) -> Models: return data test_client = app.test_client() @@ -36,14 +25,14 @@ async def index(data: PydanticModel) -> ResponseReturnValue: assert (await response.get_json()) == {"name": "bob", "age": 2} -@pytest.mark.parametrize("type_", [DCDetails, Details]) -async def test_send_form(type_: PydanticModel) -> None: +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) +async def test_send_form(type_: Type[Models]) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/", methods=["POST"]) @validate_request(type_, source=DataSource.FORM) - async def index(data: PydanticModel) -> PydanticModel: + async def index(data: Models) -> Models: return data test_client = app.test_client() diff --git a/tests/test_validation.py b/tests/test_validation.py index 22d5e9e..e7aa1c7 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from io import BytesIO -from typing import Any, List, Optional, Tuple, TypeVar, Union +from typing import Any, List, Optional, Tuple, Type, TypeVar, Union import pytest +from attrs import define +from msgspec import Struct from pydantic import BaseModel from pydantic.dataclasses import dataclass as pydantic_dataclass from pydantic.functional_validators import BeforeValidator @@ -21,6 +23,7 @@ validate_response, ) from quart_schema.pydantic import File +from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails try: from typing import Annotated @@ -28,10 +31,15 @@ from typing_extensions import Annotated # type: ignore -@dataclass -class DCDetails: - name: str - age: Optional[int] = None +@define +class AItem: + count: int + details: ADetails + + +class MItem(Struct): + count: int + details: MDetails @dataclass @@ -40,6 +48,17 @@ class DCItem: details: DCDetails +class PyItem(BaseModel): + count: int + details: PyDetails + + +@pydantic_dataclass +class PyDCItem: + count: int + details: PyDCDetails + + class FileInfo(BaseModel): upload: File @@ -60,39 +79,11 @@ class QueryItem(BaseModel): keys: Annotated[Optional[List[int]], BeforeValidator(_to_list)] = None -class Details(BaseModel): - name: str - age: Optional[int] = None - - -class Item(BaseModel): - count: int - details: Details - - -@pydantic_dataclass -class PyDCDetails: - name: str - age: Optional[int] = None - - -@pydantic_dataclass -class PyDCItem: - count: int - details: PyDCDetails - - VALID_DICT = {"count": 2, "details": {"name": "bob"}} INVALID_DICT = {"count": 2, "name": "bob"} -VALID = Item(count=2, details=Details(name="bob")) -INVALID = Details(name="bob") -VALID_DC = DCItem(count=2, details=DCDetails(name="bob")) -INVALID_DC = DCDetails(name="bob") -VALID_PyDC = PyDCItem(count=2, details=PyDCDetails(name="bob")) -INVALID_PyDC = PyDCDetails(name="bob") -@pytest.mark.parametrize("path", ["/", "/dc", "/pydc"]) +@pytest.mark.parametrize("type_", [AItem, DCItem, MItem, PyItem, PyDCItem]) @pytest.mark.parametrize( "json, status", [ @@ -100,30 +91,25 @@ class PyDCItem: (INVALID_DICT, 400), ], ) -async def test_request_validation(path: str, json: dict, status: int) -> None: +async def test_request_validation( + type_: Type[Union[AItem, DCItem, MItem, PyItem, PyDCItem]], + json: dict, + status: int, +) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/", methods=["POST"]) - @validate_request(Item) - async def item(data: Item) -> ResponseReturnValue: - return "" - - @app.route("/dc", methods=["POST"]) - @validate_request(DCItem) - async def dcitem(data: DCItem) -> ResponseReturnValue: - return "" - - @app.route("/pydc", methods=["POST"]) - @validate_request(PyDCItem) - async def pydcitem(data: PyDCItem) -> ResponseReturnValue: + @validate_request(type_) + async def item(data: Any) -> ResponseReturnValue: return "" test_client = app.test_client() - response = await test_client.post(path, json=json) + response = await test_client.post("/", json=json) assert response.status_code == status +@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]) @pytest.mark.parametrize( "data, status", [ @@ -131,13 +117,17 @@ async def pydcitem(data: PyDCItem) -> ResponseReturnValue: ({"age": 2}, 400), ], ) -async def test_request_form_validation(data: dict, status: int) -> None: +async def test_request_form_validation( + type_: Type[Union[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails]], + data: dict, + status: int, +) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/", methods=["POST"]) - @validate_request(Details, source=DataSource.FORM) - async def item(data: Details) -> ResponseReturnValue: + @validate_request(type_, source=DataSource.FORM) + async def item(data: Any) -> ResponseReturnValue: return "" test_client = app.test_client() @@ -162,29 +152,24 @@ async def item(data: FileInfo) -> ResponseReturnValue: assert (await response.get_data()) == b"ABC" # type: ignore +@pytest.mark.parametrize("type_", [AItem, DCItem, MItem, PyItem, PyDCItem]) @pytest.mark.parametrize( - "model, return_value, status", + "return_value, status", [ - (Item, VALID_DICT, 200), - (Item, INVALID_DICT, 500), - (Item, VALID, 200), - (Item, INVALID, 500), - (DCItem, VALID_DICT, 200), - (DCItem, INVALID_DICT, 500), - (DCItem, VALID_DC, 200), - (DCItem, INVALID_DC, 500), - (PyDCItem, VALID_DICT, 200), - (PyDCItem, INVALID_DICT, 500), - (PyDCItem, VALID_PyDC, 200), - (PyDCItem, INVALID_PyDC, 500), + (VALID_DICT, 200), + (INVALID_DICT, 500), ], ) -async def test_response_validation(model: Any, return_value: Any, status: int) -> None: +async def test_response_validation( + type_: Type[Union[AItem, DCItem, MItem, PyItem, PyDCItem]], + return_value: Any, + status: int, +) -> None: app = Quart(__name__) QuartSchema(app) @app.route("/") - @validate_response(model) + @validate_response(type_) async def item() -> ResponseReturnValue: return return_value @@ -198,7 +183,7 @@ async def test_redirect_validation() -> None: QuartSchema(app) @app.route("/") - @validate_response(Item) + @validate_response(PyItem) async def item() -> ResponseReturnValue: return redirect("/b") @@ -216,7 +201,7 @@ async def item() -> ResponseReturnValue: ) async def test_view_response_validation(return_value: Any, status: int) -> None: class ValidatedView(View): - decorators = [validate_response(Item)] + decorators = [validate_response(PyItem)] methods = ["GET"] def dispatch_request(self, **kwargs: Any) -> ResponseReturnValue: # type: ignore @@ -232,18 +217,21 @@ def dispatch_request(self, **kwargs: Any) -> ResponseReturnValue: # type: ignor assert response.status_code == status -async def test_websocket_validation() -> None: +@pytest.mark.parametrize("type_", [AItem, DCItem, MItem, PyItem, PyDCItem]) +async def test_websocket_validation( + type_: Type[Union[AItem, DCItem, MItem, PyItem, PyDCItem]], +) -> None: app = Quart(__name__) QuartSchema(app) @app.websocket("/ws") async def ws() -> None: - await websocket.receive_as(Item) # type: ignore + await websocket.receive_as(type_) # type: ignore with pytest.raises(SchemaValidationError): - await websocket.receive_as(Item) # type: ignore - await websocket.send_as(VALID_DICT, Item) # type: ignore + await websocket.receive_as(type_) # type: ignore + await websocket.send_as(VALID_DICT, type_) # type: ignore with pytest.raises(SchemaValidationError): - await websocket.send_as(VALID_DICT, Details) # type: ignore + await websocket.send_as(INVALID_DICT, type_) # type: ignore test_client = app.test_client() async with test_client.websocket("/ws") as test_websocket: diff --git a/tox.ini b/tox.ini index d67feb4..9303709 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,8 @@ isolated_build = True [testenv] deps = hypothesis + msgspec + pydantic pytest pytest-asyncio pytest-cov @@ -16,6 +18,7 @@ basepython = python3.12 deps = pydata-sphinx-theme sphinx + sphinx-tabs commands = sphinx-apidoc -e -f -o docs/reference/source/ src/quart_schema/ sphinx-build -b html -d {envtmpdir}/doctrees docs/ docs/_build/html/ @@ -41,7 +44,9 @@ commands = flake8 src/quart_schema/ tests/ basepython = python3.12 deps = hypothesis + msgspec mypy + pydantic pytest commands = mypy src/quart_schema/ tests/