Skip to content

Commit

Permalink
Use the globals of the function when evaluating the return type for `…
Browse files Browse the repository at this point in the history
…PlainSerializer` and `WrapSerializer` functions (#11008)
  • Loading branch information
Viicos authored and sydney-runkle committed Dec 3, 2024
1 parent cb962c1 commit b2c4548
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
13 changes: 8 additions & 5 deletions pydantic/functional_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
The Pydantic core schema.
"""
schema = handler(source_type)
globalns, localns = handler._get_types_namespace()
try:
# Do not pass in globals as the function could be defined in a different module.
# Instead, let `get_function_return_type` infer the globals to use, but still pass
# in locals that may contain a parent/rebuild namespace:
return_type = _decorators.get_function_return_type(
self.func,
self.return_type,
globalns=globalns,
localns=localns,
localns=handler._get_types_namespace().locals,
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
Expand Down Expand Up @@ -166,11 +167,13 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
schema = handler(source_type)
globalns, localns = handler._get_types_namespace()
try:
# Do not pass in globals as the function could be defined in a different module.
# Instead, let `get_function_return_type` infer the globals to use, but still pass
# in locals that may contain a parent/rebuild namespace:
return_type = _decorators.get_function_return_type(
self.func,
self.return_type,
globalns=globalns,
localns=localns,
localns=handler._get_types_namespace().locals,
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
Expand Down
12 changes: 11 additions & 1 deletion tests/test_forward_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,12 +1314,22 @@ def test_uses_the_correct_globals_to_resolve_forward_refs_on_serializers(create_
# we use the globals of the underlying func to resolve the return type.
@create_module
def module_1():
from pydantic import BaseModel, field_serializer # or model_serializer, computed_field
from typing_extensions import Annotated

from pydantic import (
BaseModel,
PlainSerializer, # or WrapSerializer
field_serializer, # or model_serializer, computed_field
)

MyStr = str

def ser_func(value) -> 'MyStr':
return str(value)

class Model(BaseModel):
a: int
b: Annotated[int, PlainSerializer(ser_func)]

@field_serializer('a')
def ser(self, value) -> 'MyStr':
Expand Down

0 comments on commit b2c4548

Please sign in to comment.