Skip to content

Commit

Permalink
fix(parser): fallback to validate_python when using type[Model] a…
Browse files Browse the repository at this point in the history
…nd nested models (#5313)

* Fix Pydantic limitation

* Add e2e tests

* Reverting change in e2e layer
  • Loading branch information
leandrodamascena authored Oct 7, 2024
1 parent b271c17 commit fe6b335
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 16 deletions.
11 changes: 5 additions & 6 deletions aws_lambda_powertools/utilities/parser/envelopes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, TypeVar

from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache
from aws_lambda_powertools.utilities.parser.functions import (
_parse_and_validate_event,
_retrieve_or_set_model_from_cache,
)

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.parser.types import T
Expand Down Expand Up @@ -38,11 +41,7 @@ def _parse(data: dict[str, Any] | Any | None, model: type[T]) -> T | None:
adapter = _retrieve_or_set_model_from_cache(model=model)

logger.debug("parsing event against model")
if isinstance(data, str):
logger.debug("parsing event as string")
return adapter.validate_json(data)

return adapter.validate_python(data)
return _parse_and_validate_event(data=data, adapter=adapter)

@abstractmethod
def parse(self, data: dict[str, Any] | Any | None, model: type[T]):
Expand Down
41 changes: 40 additions & 1 deletion aws_lambda_powertools/utilities/parser/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import json
import logging
from typing import TYPE_CHECKING, Any

from pydantic import TypeAdapter

Expand All @@ -11,6 +13,8 @@

CACHE_TYPE_ADAPTER = LRUDict(max_items=1024)

logger = logging.getLogger(__name__)


def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter:
"""
Expand Down Expand Up @@ -38,3 +42,38 @@ def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter:

CACHE_TYPE_ADAPTER[id_model] = TypeAdapter(model)
return CACHE_TYPE_ADAPTER[id_model]


def _parse_and_validate_event(data: dict[str, Any] | Any, adapter: TypeAdapter):
"""
Parse and validate the event data using the provided adapter.
Params
------
data: dict | Any
The event data to be parsed and validated.
adapter: TypeAdapter
The adapter object used for validation.
Returns:
dict: The validated event data.
Raises:
ValidationError: If the data is invalid or cannot be parsed.
"""
logger.debug("Parsing event against model")

if isinstance(data, str):
logger.debug("Parsing event as string")
try:
return adapter.validate_json(data)
except NotImplementedError:
# See: https://github.com/aws-powertools/powertools-lambda-python/issues/5303
# See: https://github.com/pydantic/pydantic/issues/8890
logger.debug(
"Falling back to Python validation due to Pydantic implementation."
"See issue: https://github.com/aws-powertools/powertools-lambda-python/issues/5303",
)
data = json.loads(data)

return adapter.validate_python(data)
10 changes: 5 additions & 5 deletions aws_lambda_powertools/utilities/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
from aws_lambda_powertools.utilities.parser.exceptions import InvalidEnvelopeError, InvalidModelTypeError
from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache
from aws_lambda_powertools.utilities.parser.functions import (
_parse_and_validate_event,
_retrieve_or_set_model_from_cache,
)

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.parser.envelopes.base import Envelope
Expand Down Expand Up @@ -189,10 +192,7 @@ def handler(event: Order, context: LambdaContext):
adapter = _retrieve_or_set_model_from_cache(model=model)

logger.debug("Parsing and validating event model; no envelope used")
if isinstance(event, str):
return adapter.validate_json(event)

return adapter.validate_python(event)
return _parse_and_validate_event(data=event, adapter=adapter)

# Pydantic raises PydanticSchemaGenerationError when the model is not a Pydantic model
# This is seen in the tests where we pass a non-Pydantic model type to the parser or
Expand Down
23 changes: 23 additions & 0 deletions tests/e2e/parser/handlers/handler_with_model_type_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json
from typing import Any, Dict, Type, Union

from pydantic import BaseModel

from aws_lambda_powertools.utilities.parser import parse
from aws_lambda_powertools.utilities.typing import LambdaContext

AnyInheritedModel = Union[Type[BaseModel], BaseModel]
RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel]


class ModelWithUnionType(BaseModel):
name: str
profile: RawDictOrModel


def lambda_handler(event: ModelWithUnionType, context: LambdaContext):
event = json.dumps(event)

event_parsed = parse(event=event, model=ModelWithUnionType)

return {"name": event_parsed.name}
21 changes: 21 additions & 0 deletions tests/e2e/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def handler_with_dataclass_arn(infrastructure: dict) -> str:
return infrastructure.get("HandlerWithDataclass", "")


@pytest.fixture
def handler_with_type_model_class(infrastructure: dict) -> str:
return infrastructure.get("HandlerWithModelTypeClass", "")


@pytest.mark.xdist_group(name="parser")
def test_parser_with_basic_model(handler_with_basic_model_arn):
# GIVEN
Expand Down Expand Up @@ -66,3 +71,19 @@ def test_parser_with_dataclass(handler_with_dataclass_arn):
ret = parser_execution["Payload"].read().decode("utf-8")

assert "powertools" in ret


@pytest.mark.xdist_group(name="parser")
def test_parser_with_type_model(handler_with_type_model_class):
# GIVEN
payload = json.dumps({"name": "powertools", "profile": {"description": "python", "size": "XXL"}})

# WHEN
parser_execution, _ = data_fetcher.get_lambda_response(
lambda_arn=handler_with_type_model_class,
payload=payload,
)

ret = parser_execution["Payload"].read().decode("utf-8")

assert "powertools" in ret
48 changes: 44 additions & 4 deletions tests/functional/parser/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json
from datetime import datetime
from typing import Any, Dict, Literal, Union

import pydantic
import pytest
from pydantic import ValidationError
from typing_extensions import Annotated

from aws_lambda_powertools.utilities.parser import (
event_parser,
exceptions,
)
from aws_lambda_powertools.utilities.parser import event_parser, exceptions, parse
from aws_lambda_powertools.utilities.parser.envelopes.sqs import SqsEnvelope
from aws_lambda_powertools.utilities.parser.models import SqsModel
from aws_lambda_powertools.utilities.parser.models.event_bridge import EventBridgeModel
from aws_lambda_powertools.utilities.typing import LambdaContext


Expand Down Expand Up @@ -161,3 +162,42 @@ def handler(event: test_input, _: Any) -> str:

ret = handler(test_input, None)
assert ret == expected


def test_parser_with_model_type_model_and_envelope():
event = {
"Records": [
{
"messageId": "19dd0b57-b21e-4ac1-bd88-01bbb068cb78",
"receiptHandle": "MessageReceiptHandle",
"body": EventBridgeModel(
version="version",
id="id",
source="source",
account="account",
time=datetime.now(),
region="region",
resources=[],
detail={"key": "value"},
).model_dump_json(),
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1523232000000",
"SenderId": "123456789012",
"ApproximateFirstReceiveTimestamp": "1523232000001",
},
"messageAttributes": {},
"md5OfBody": "{{{md5_of_body}}}",
"eventSource": "aws:sqs",
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:MyQueue",
"awsRegion": "us-east-1",
},
],
}

def handler(event: SqsModel, _: LambdaContext):
parsed_event: EventBridgeModel = parse(event, model=EventBridgeModel, envelope=SqsEnvelope)
print(parsed_event)
assert parsed_event[0].version == "version"

handler(event, LambdaContext())

0 comments on commit fe6b335

Please sign in to comment.