Skip to content

Commit

Permalink
Unwrap NewType for OpenAPI schema
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Jun 17, 2024
1 parent a1f5b3b commit 097f872
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
13 changes: 12 additions & 1 deletion litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

result: Schema | Reference

if plugin_for_annotation := self.get_plugin_for(field_definition):
if field_definition.is_new_type:
result = self.for_new_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
Expand Down Expand Up @@ -354,6 +356,15 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

return self.process_schema_result(field_definition, result) if isinstance(result, Schema) else result

def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=field_definition.raw.__supertype__,
name=field_definition.name,
default=field_definition.default,
)
)

@staticmethod
def for_upload_file(field_definition: FieldDefinition) -> Schema:
"""Create schema for UploadFile.
Expand Down
27 changes: 12 additions & 15 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from inspect import Parameter, Signature
from typing import (
Any,
AnyStr,
Callable,
Collection,
ForwardRef,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
)
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast

from msgspec import UnsetType
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.openapi.spec import Example
Expand Down Expand Up @@ -314,7 +302,12 @@ def is_generic(self) -> bool:
def is_simple_type(self) -> bool:
"""Check if the field type is a singleton value (e.g. int, str etc.)."""
return not (
self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable
self.is_generic
or self.is_optional
or self.is_union
or self.is_mapping
or self.is_non_string_iterable
or self.is_new_type
)

@property
Expand Down Expand Up @@ -366,6 +359,10 @@ def is_tuple(self) -> bool:
"""Whether the annotation is a ``tuple`` or not."""
return self.is_subclass_of(tuple)

@property
def is_new_type(self) -> bool:
return isinstance(self.annotation, NewType)

@property
def is_type_var(self) -> bool:
"""Whether the annotation is a TypeVar or not."""
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

import pytest
from typing_extensions import Annotated
from typing_extensions import Annotated, NewType

from litestar import Controller, Litestar, Router, get
from litestar._openapi.datastructures import OpenAPIContext
Expand Down Expand Up @@ -380,3 +380,27 @@ async def uuid_path(id: Annotated[UUID, Parameter(description="UUID ID")]) -> UU
response = client.get("/schema/openapi.json")
assert response.json()["paths"]["/str/{id}"]["get"]["parameters"][0]["description"] == "String ID"
assert response.json()["paths"]["/uuid/{id}"]["get"]["parameters"][0]["description"] == "UUID ID"


def test_unwrap_new_type() -> None:
FancyString = NewType("FancyString", str)

@get("/{path_param:str}")
async def handler(
param: FancyString,
optional_param: Optional[FancyString],
path_param: FancyString,
) -> FancyString:
return FancyString("")

app = Litestar([handler])
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
Schema(type=OpenAPIType.NULL),
Schema(type=OpenAPIType.STRING),
]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert (
app.openapi_schema.paths["/{path_param}"].get.responses["200"].content["application/json"].schema.type # type: ignore[index, union-attr]
== OpenAPIType.STRING
)

0 comments on commit 097f872

Please sign in to comment.