Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(OpenAPI): Correctly handle type keyword #3715

Merged
merged 4 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ jobs:
with:
name: coverage-data
path: .coverage.pydantic_v1
include-hidden-files: true

upload-test-coverage:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ jobs:
with:
name: coverage-data
path: .coverage.${{ inputs.python-version }}
include-hidden-files: true
12 changes: 12 additions & 0 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

if field_definition.is_new_type:
result = self.for_new_type(field_definition)
elif field_definition.is_type_alias_type:
result = self.for_type_alias_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
Expand Down Expand Up @@ -366,6 +368,16 @@ def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
)
)

def for_type_alias_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=field_definition.annotation.__value__,
name=field_definition.name,
default=field_definition.default,
kwarg_definition=field_definition.kwarg_definition,
)
)

@staticmethod
def for_upload_file(field_definition: FieldDefinition) -> Schema:
"""Create schema for UploadFile.
Expand Down
17 changes: 16 additions & 1 deletion litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast

from msgspec import UnsetType
from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
from typing_extensions import (
NewType,
NotRequired,
Required,
Self,
TypeAliasType,
get_args,
get_origin,
get_type_hints,
is_typeddict,
)

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.openapi.spec import Example
Expand Down Expand Up @@ -363,6 +373,11 @@ def is_tuple(self) -> bool:
def is_new_type(self) -> bool:
return isinstance(self.annotation, NewType)

@property
def is_type_alias_type(self) -> bool:
"""Whether the annotation is a ``TypeAliasType``"""
return isinstance(self.annotation, TypeAliasType)

@property
def is_type_var(self) -> bool:
"""Whether the annotation is a TypeVar or not."""
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test_openapi/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import msgspec
import pytest
from msgspec import Struct
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Annotated, TypeAlias, TypeAliasType

from litestar import Controller, MediaType, get, post
from litestar._openapi.schema_generation.plugins import openapi_schema_plugins
Expand Down Expand Up @@ -615,3 +615,32 @@ async def handler(dep: str) -> None:
assert param.name == f"param{i}"
assert param.required is True
assert param.param_in is ParamType.PATH


def test_type_alias_type() -> None:
@get("/")
def handler(query_param: Annotated[TypeAliasType("IntAlias", int), Parameter(description="foo")]) -> None: # type: ignore[valid-type]
pass

app = Litestar([handler])
param = app.openapi_schema.paths["/"].get.parameters[0] # type: ignore[index, union-attr]
assert param.schema.type is OpenAPIType.INTEGER # type: ignore[union-attr]
# ensure other attributes than the plain type are carried over correctly
assert param.description == "foo"


@pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12")
def test_type_alias_type_keyword() -> None:
ctx: Dict[str, Any] = {}
exec("type IntAlias = int", ctx, None)
annotation = ctx["IntAlias"]

@get("/")
def handler(query_param: Annotated[annotation, Parameter(description="foo")]) -> None: # type: ignore[valid-type]
pass

app = Litestar([handler])
param = app.openapi_schema.paths["/"].get.parameters[0] # type: ignore[union-attr, index]
assert param.schema.type is OpenAPIType.INTEGER # type: ignore[union-attr]
# ensure other attributes than the plain type are carried over correctly
assert param.description == "foo"
16 changes: 15 additions & 1 deletion tests/unit/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import annotated_types
import msgspec
import pytest
from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_type_hints
from typing_extensions import Annotated, NotRequired, Required, TypeAliasType, TypedDict, get_type_hints

from litestar import get
from litestar.exceptions import LitestarWarning
Expand Down Expand Up @@ -476,3 +476,17 @@ def handler(foo: Annotated[int, Parameter(default=1)]) -> None:
(record,) = warnings
assert record.category == DeprecationWarning
assert "Deprecated default value specification" in str(record.message)


def test_is_type_alias_type() -> None:
field_definition = FieldDefinition.from_annotation(TypeAliasType("IntAlias", int)) # pyright: ignore
assert field_definition.is_type_alias_type


@pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12")
def test_unwrap_type_alias_type_keyword() -> None:
ctx: dict[str, Any] = {}
exec("type IntAlias = int", ctx, None)
annotation = ctx["IntAlias"]
field_definition = FieldDefinition.from_annotation(annotation)
assert field_definition.is_type_alias_type
Loading