From 097f8728c4525405e58164602cbe43a9fa93399d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= <25355197+provinzkraut@users.noreply.github.com> Date: Mon, 17 Jun 2024 18:49:05 +0200 Subject: [PATCH] Unwrap NewType for OpenAPI schema --- litestar/_openapi/schema_generation/schema.py | 13 ++++++++- litestar/typing.py | 27 +++++++++---------- tests/unit/test_openapi/test_parameters.py | 26 +++++++++++++++++- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 49cfa9acf7..add2319f2b 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -325,7 +325,9 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re result: Schema | Reference - if plugin_for_annotation := self.get_plugin_for(field_definition): + if field_definition.is_new_type: + result = self.for_new_type(field_definition) + elif plugin_for_annotation := self.get_plugin_for(field_definition): result = self.for_plugin(field_definition, plugin_for_annotation) elif _should_create_enum_schema(field_definition): annotation = _type_or_first_not_none_inner_type(field_definition) @@ -354,6 +356,15 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re return self.process_schema_result(field_definition, result) if isinstance(result, Schema) else result + def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference: + return self.for_field_definition( + FieldDefinition.from_kwarg( + annotation=field_definition.raw.__supertype__, + name=field_definition.name, + default=field_definition.default, + ) + ) + @staticmethod def for_upload_file(field_definition: FieldDefinition) -> Schema: """Create schema for UploadFile. diff --git a/litestar/typing.py b/litestar/typing.py index fadfbfd564..9646bb0aec 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -5,22 +5,10 @@ from copy import deepcopy from dataclasses import dataclass, is_dataclass, replace from inspect import Parameter, Signature -from typing import ( - Any, - AnyStr, - Callable, - Collection, - ForwardRef, - Literal, - Mapping, - Protocol, - Sequence, - TypeVar, - cast, -) +from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast from msgspec import UnsetType -from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict +from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning from litestar.openapi.spec import Example @@ -314,7 +302,12 @@ def is_generic(self) -> bool: def is_simple_type(self) -> bool: """Check if the field type is a singleton value (e.g. int, str etc.).""" return not ( - self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable + self.is_generic + or self.is_optional + or self.is_union + or self.is_mapping + or self.is_non_string_iterable + or self.is_new_type ) @property @@ -366,6 +359,10 @@ def is_tuple(self) -> bool: """Whether the annotation is a ``tuple`` or not.""" return self.is_subclass_of(tuple) + @property + def is_new_type(self) -> bool: + return isinstance(self.annotation, NewType) + @property def is_type_var(self) -> bool: """Whether the annotation is a TypeVar or not.""" diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index 48cea53df0..07cbad566c 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -2,7 +2,7 @@ from uuid import UUID import pytest -from typing_extensions import Annotated +from typing_extensions import Annotated, NewType from litestar import Controller, Litestar, Router, get from litestar._openapi.datastructures import OpenAPIContext @@ -380,3 +380,27 @@ async def uuid_path(id: Annotated[UUID, Parameter(description="UUID ID")]) -> UU response = client.get("/schema/openapi.json") assert response.json()["paths"]["/str/{id}"]["get"]["parameters"][0]["description"] == "String ID" assert response.json()["paths"]["/uuid/{id}"]["get"]["parameters"][0]["description"] == "UUID ID" + + +def test_unwrap_new_type() -> None: + FancyString = NewType("FancyString", str) + + @get("/{path_param:str}") + async def handler( + param: FancyString, + optional_param: Optional[FancyString], + path_param: FancyString, + ) -> FancyString: + return FancyString("") + + app = Litestar([handler]) + assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr] + assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr] + Schema(type=OpenAPIType.NULL), + Schema(type=OpenAPIType.STRING), + ] + assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr] + assert ( + app.openapi_schema.paths["/{path_param}"].get.responses["200"].content["application/json"].schema.type # type: ignore[index, union-attr] + == OpenAPIType.STRING + )