diff --git a/aws_lambda_powertools/utilities/parser/envelopes/base.py b/aws_lambda_powertools/utilities/parser/envelopes/base.py index 14b5c0f0a3..dbd76eafe7 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/base.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/base.py @@ -4,7 +4,10 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, TypeVar -from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache +from aws_lambda_powertools.utilities.parser.functions import ( + _parse_and_validate_event, + _retrieve_or_set_model_from_cache, +) if TYPE_CHECKING: from aws_lambda_powertools.utilities.parser.types import T @@ -38,11 +41,7 @@ def _parse(data: dict[str, Any] | Any | None, model: type[T]) -> T | None: adapter = _retrieve_or_set_model_from_cache(model=model) logger.debug("parsing event against model") - if isinstance(data, str): - logger.debug("parsing event as string") - return adapter.validate_json(data) - - return adapter.validate_python(data) + return _parse_and_validate_event(data=data, adapter=adapter) @abstractmethod def parse(self, data: dict[str, Any] | Any | None, model: type[T]): diff --git a/aws_lambda_powertools/utilities/parser/functions.py b/aws_lambda_powertools/utilities/parser/functions.py index 4cf3f13139..b9a35176a1 100644 --- a/aws_lambda_powertools/utilities/parser/functions.py +++ b/aws_lambda_powertools/utilities/parser/functions.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import json +import logging +from typing import TYPE_CHECKING, Any from pydantic import TypeAdapter @@ -11,6 +13,8 @@ CACHE_TYPE_ADAPTER = LRUDict(max_items=1024) +logger = logging.getLogger(__name__) + def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter: """ @@ -38,3 +42,38 @@ def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter: CACHE_TYPE_ADAPTER[id_model] = TypeAdapter(model) return CACHE_TYPE_ADAPTER[id_model] + + +def _parse_and_validate_event(data: dict[str, Any] | Any, adapter: TypeAdapter): + """ + Parse and validate the event data using the provided adapter. + + Params + ------ + data: dict | Any + The event data to be parsed and validated. + adapter: TypeAdapter + The adapter object used for validation. + + Returns: + dict: The validated event data. + + Raises: + ValidationError: If the data is invalid or cannot be parsed. + """ + logger.debug("Parsing event against model") + + if isinstance(data, str): + logger.debug("Parsing event as string") + try: + return adapter.validate_json(data) + except NotImplementedError: + # See: https://github.com/aws-powertools/powertools-lambda-python/issues/5303 + # See: https://github.com/pydantic/pydantic/issues/8890 + logger.debug( + "Falling back to Python validation due to Pydantic implementation." + "See issue: https://github.com/aws-powertools/powertools-lambda-python/issues/5303", + ) + data = json.loads(data) + + return adapter.validate_python(data) diff --git a/aws_lambda_powertools/utilities/parser/parser.py b/aws_lambda_powertools/utilities/parser/parser.py index fd0b298bd7..42ffbbd22c 100644 --- a/aws_lambda_powertools/utilities/parser/parser.py +++ b/aws_lambda_powertools/utilities/parser/parser.py @@ -8,7 +8,10 @@ from aws_lambda_powertools.middleware_factory import lambda_handler_decorator from aws_lambda_powertools.utilities.parser.exceptions import InvalidEnvelopeError, InvalidModelTypeError -from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache +from aws_lambda_powertools.utilities.parser.functions import ( + _parse_and_validate_event, + _retrieve_or_set_model_from_cache, +) if TYPE_CHECKING: from aws_lambda_powertools.utilities.parser.envelopes.base import Envelope @@ -189,10 +192,7 @@ def handler(event: Order, context: LambdaContext): adapter = _retrieve_or_set_model_from_cache(model=model) logger.debug("Parsing and validating event model; no envelope used") - if isinstance(event, str): - return adapter.validate_json(event) - - return adapter.validate_python(event) + return _parse_and_validate_event(data=event, adapter=adapter) # Pydantic raises PydanticSchemaGenerationError when the model is not a Pydantic model # This is seen in the tests where we pass a non-Pydantic model type to the parser or diff --git a/tests/e2e/parser/handlers/handler_with_model_type_class.py b/tests/e2e/parser/handlers/handler_with_model_type_class.py new file mode 100644 index 0000000000..7e635dee13 --- /dev/null +++ b/tests/e2e/parser/handlers/handler_with_model_type_class.py @@ -0,0 +1,23 @@ +import json +from typing import Any, Dict, Type, Union + +from pydantic import BaseModel + +from aws_lambda_powertools.utilities.parser import parse +from aws_lambda_powertools.utilities.typing import LambdaContext + +AnyInheritedModel = Union[Type[BaseModel], BaseModel] +RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel] + + +class ModelWithUnionType(BaseModel): + name: str + profile: RawDictOrModel + + +def lambda_handler(event: ModelWithUnionType, context: LambdaContext): + event = json.dumps(event) + + event_parsed = parse(event=event, model=ModelWithUnionType) + + return {"name": event_parsed.name} diff --git a/tests/e2e/parser/test_parser.py b/tests/e2e/parser/test_parser.py index ae0b75b344..aa52889aea 100644 --- a/tests/e2e/parser/test_parser.py +++ b/tests/e2e/parser/test_parser.py @@ -20,6 +20,11 @@ def handler_with_dataclass_arn(infrastructure: dict) -> str: return infrastructure.get("HandlerWithDataclass", "") +@pytest.fixture +def handler_with_type_model_class(infrastructure: dict) -> str: + return infrastructure.get("HandlerWithModelTypeClass", "") + + @pytest.mark.xdist_group(name="parser") def test_parser_with_basic_model(handler_with_basic_model_arn): # GIVEN @@ -66,3 +71,19 @@ def test_parser_with_dataclass(handler_with_dataclass_arn): ret = parser_execution["Payload"].read().decode("utf-8") assert "powertools" in ret + + +@pytest.mark.xdist_group(name="parser") +def test_parser_with_type_model(handler_with_type_model_class): + # GIVEN + payload = json.dumps({"name": "powertools", "profile": {"description": "python", "size": "XXL"}}) + + # WHEN + parser_execution, _ = data_fetcher.get_lambda_response( + lambda_arn=handler_with_type_model_class, + payload=payload, + ) + + ret = parser_execution["Payload"].read().decode("utf-8") + + assert "powertools" in ret diff --git a/tests/functional/parser/test_parser.py b/tests/functional/parser/test_parser.py index d4208c203a..c7c90b7026 100644 --- a/tests/functional/parser/test_parser.py +++ b/tests/functional/parser/test_parser.py @@ -1,4 +1,5 @@ import json +from datetime import datetime from typing import Any, Dict, Literal, Union import pydantic @@ -6,10 +7,10 @@ from pydantic import ValidationError from typing_extensions import Annotated -from aws_lambda_powertools.utilities.parser import ( - event_parser, - exceptions, -) +from aws_lambda_powertools.utilities.parser import event_parser, exceptions, parse +from aws_lambda_powertools.utilities.parser.envelopes.sqs import SqsEnvelope +from aws_lambda_powertools.utilities.parser.models import SqsModel +from aws_lambda_powertools.utilities.parser.models.event_bridge import EventBridgeModel from aws_lambda_powertools.utilities.typing import LambdaContext @@ -161,3 +162,42 @@ def handler(event: test_input, _: Any) -> str: ret = handler(test_input, None) assert ret == expected + + +def test_parser_with_model_type_model_and_envelope(): + event = { + "Records": [ + { + "messageId": "19dd0b57-b21e-4ac1-bd88-01bbb068cb78", + "receiptHandle": "MessageReceiptHandle", + "body": EventBridgeModel( + version="version", + id="id", + source="source", + account="account", + time=datetime.now(), + region="region", + resources=[], + detail={"key": "value"}, + ).model_dump_json(), + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1523232000000", + "SenderId": "123456789012", + "ApproximateFirstReceiveTimestamp": "1523232000001", + }, + "messageAttributes": {}, + "md5OfBody": "{{{md5_of_body}}}", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:MyQueue", + "awsRegion": "us-east-1", + }, + ], + } + + def handler(event: SqsModel, _: LambdaContext): + parsed_event: EventBridgeModel = parse(event, model=EventBridgeModel, envelope=SqsEnvelope) + print(parsed_event) + assert parsed_event[0].version == "version" + + handler(event, LambdaContext())