Skip to content

Commit

Permalink
Support nested NewType
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Jun 18, 2024
1 parent 097f872 commit 2738675
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
3 changes: 2 additions & 1 deletion litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from litestar.utils.typing import (
get_origin_or_inner_type,
make_non_optional_union,
unwrap_new_type,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -359,7 +360,7 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=field_definition.raw.__supertype__,
annotation=unwrap_new_type(field_definition.raw),
name=field_definition.name,
default=field_definition.default,
)
Expand Down
10 changes: 9 additions & 1 deletion litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
cast,
)

from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin, get_type_hints
from typing_extensions import Annotated, NewType, NotRequired, Required, get_args, get_origin, get_type_hints

from litestar.types.builtin_types import NoneType, UnionTypes

Expand Down Expand Up @@ -174,6 +174,14 @@ def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]:
return annotation, tuple(metadata), wrappers


def unwrap_new_type(new_type: Any) -> Any:
"""Unwrap a (nested) ``typing.NewType``"""
inner = new_type
while isinstance(inner, NewType):
inner = inner.__supertype__
return inner


def get_origin_or_inner_type(annotation: Any) -> Any:
"""Get origin or unwrap it. Returns None for non-generic types.
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,17 @@ async def handler(
app.openapi_schema.paths["/{path_param}"].get.responses["200"].content["application/json"].schema.type # type: ignore[index, union-attr]
== OpenAPIType.STRING
)


def test_unwrap_nested_new_type() -> None:
FancyString = NewType("FancyString", str)
FancierString = NewType("FancierString", FancyString)

@get("/")
async def handler(
param: FancierString,
) -> None:
return None

app = Litestar([handler])
assert app.openapi_schema.paths["/"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]

0 comments on commit 2738675

Please sign in to comment.