diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4dbf753d22..38aaaf096d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -46,6 +46,7 @@ # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = "%<> \[\]{}|^" # noqa: W605 _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" +_ROUTE_REGEX = "^{}$" class ProxyEventType(Enum): @@ -562,7 +563,7 @@ def _has_debug(debug: Optional[bool] = None) -> bool: return powertools_dev_is_set() @staticmethod - def _compile_regex(rule: str): + def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): """Precompile regex pattern Logic @@ -592,7 +593,7 @@ def _compile_regex(rule: str): NOTE: See #520 for context """ rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule) - return re.compile("^{}$".format(rule_regex)) + return re.compile(base_regex.format(rule_regex)) def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: """Convert the event dict to the corresponding data class""" @@ -819,6 +820,24 @@ def __init__( """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes) + # override route to ignore trailing "/" in routes for REST API + def route( + self, + rule: str, + method: Union[str, Union[List[str], Tuple[str]]], + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + ): + # NOTE: see #1552 for more context. + return super().route(rule.rstrip("/"), method, cors, compress, cache_control) + + # Override _compile_regex to exclude trailing slashes for route resolution + @staticmethod + def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): + + return super(APIGatewayRestResolver, APIGatewayRestResolver)._compile_regex(rule, "^{}/*$") + class APIGatewayHttpResolver(ApiGatewayResolver): current_event: APIGatewayProxyEventV2 diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 8ee07890b4..ec6116403e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -42,10 +42,10 @@ Before you decorate your functions to handle a given path and HTTP method(s), yo A resolver will handle request resolution, including [one or more routers](#split-routes-with-router), and give you access to the current event via typed properties. -For resolvers, we provide: `APIGatewayRestResolver`, `APIGatewayHttpResolver`, `ALBResolver`, and `LambdaFunctionUrlResolver` . +For resolvers, we provide: `APIGatewayRestResolver`, `APIGatewayHttpResolver`, `ALBResolver`, and `LambdaFunctionUrlResolver`. From here on, we will default to `APIGatewayRestResolver` across examples. -???+ info - We will use `APIGatewayRestResolver` as the default across examples. +???+ info "Auto-serialization" + We serialize `Dict` responses as JSON, trim whitespace for compact responses, and set content-type to `application/json`. #### API Gateway REST API @@ -53,8 +53,8 @@ When using Amazon API Gateway REST API to front your Lambda functions, you can u Here's an example on how we can handle the `/todos` path. -???+ info - We automatically serialize `Dict` responses as JSON, trim whitespace for compact responses, and set content-type to `application/json`. +???+ info "Trailing slash in routes" + For `APIGatewayRestResolver`, we seamless handle routes with a trailing slash (`/todos/`). === "getting_started_rest_api_resolver.py" diff --git a/tests/e2e/event_handler/handlers/alb_handler.py b/tests/e2e/event_handler/handlers/alb_handler.py index 0e386c82c5..26746284ae 100644 --- a/tests/e2e/event_handler/handlers/alb_handler.py +++ b/tests/e2e/event_handler/handlers/alb_handler.py @@ -2,9 +2,12 @@ app = ALBResolver() +# The reason we use post is that whoever is writing tests can easily assert on the +# content being sent (body, headers, cookies, content-type) to reduce cognitive load. + @app.post("/todos") -def hello(): +def todos(): payload = app.current_event.json_body body = payload.get("body", "Hello World") diff --git a/tests/e2e/event_handler/handlers/api_gateway_http_handler.py b/tests/e2e/event_handler/handlers/api_gateway_http_handler.py index 9edacc1c80..1012af7b3f 100644 --- a/tests/e2e/event_handler/handlers/api_gateway_http_handler.py +++ b/tests/e2e/event_handler/handlers/api_gateway_http_handler.py @@ -6,9 +6,12 @@ app = APIGatewayHttpResolver() +# The reason we use post is that whoever is writing tests can easily assert on the +# content being sent (body, headers, cookies, content-type) to reduce cognitive load. + @app.post("/todos") -def hello(): +def todos(): payload = app.current_event.json_body body = payload.get("body", "Hello World") diff --git a/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py b/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py index 3127634e99..d52e2728ca 100644 --- a/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py +++ b/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py @@ -6,9 +6,12 @@ app = APIGatewayRestResolver() +# The reason we use post is that whoever is writing tests can easily assert on the +# content being sent (body, headers, cookies, content-type) to reduce cognitive load. + @app.post("/todos") -def hello(): +def todos(): payload = app.current_event.json_body body = payload.get("body", "Hello World") diff --git a/tests/e2e/event_handler/handlers/lambda_function_url_handler.py b/tests/e2e/event_handler/handlers/lambda_function_url_handler.py index 884704893a..f90037afc7 100644 --- a/tests/e2e/event_handler/handlers/lambda_function_url_handler.py +++ b/tests/e2e/event_handler/handlers/lambda_function_url_handler.py @@ -6,9 +6,12 @@ app = LambdaFunctionUrlResolver() +# The reason we use post is that whoever is writing tests can easily assert on the +# content being sent (body, headers, cookies, content-type) to reduce cognitive load. + @app.post("/todos") -def hello(): +def todos(): payload = app.current_event.json_body body = payload.get("body", "Hello World") diff --git a/tests/e2e/event_handler/test_paths_ending_with_slash.py b/tests/e2e/event_handler/test_paths_ending_with_slash.py new file mode 100644 index 0000000000..4c1461d6fc --- /dev/null +++ b/tests/e2e/event_handler/test_paths_ending_with_slash.py @@ -0,0 +1,99 @@ +import pytest +from requests import HTTPError, Request + +from tests.e2e.utils import data_fetcher + + +@pytest.fixture +def alb_basic_listener_endpoint(infrastructure: dict) -> str: + dns_name = infrastructure.get("ALBDnsName") + port = infrastructure.get("ALBBasicListenerPort", "") + return f"http://{dns_name}:{port}" + + +@pytest.fixture +def alb_multi_value_header_listener_endpoint(infrastructure: dict) -> str: + dns_name = infrastructure.get("ALBDnsName") + port = infrastructure.get("ALBMultiValueHeaderListenerPort", "") + return f"http://{dns_name}:{port}" + + +@pytest.fixture +def apigw_rest_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayRestUrl", "") + + +@pytest.fixture +def apigw_http_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayHTTPUrl", "") + + +@pytest.fixture +def lambda_function_url_endpoint(infrastructure: dict) -> str: + return infrastructure.get("LambdaFunctionUrl", "") + + +def test_api_gateway_rest_trailing_slash(apigw_rest_endpoint): + # GIVEN API URL ends in a trailing slash + url = f"{apigw_rest_endpoint}todos/" + body = "Hello World" + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + json={"body": body}, + ) + ) + + # THEN expect a HTTP 200 response + assert response.status_code == 200 + + +def test_api_gateway_http_trailing_slash(apigw_http_endpoint): + # GIVEN the URL for the API ends in a trailing slash API gateway should return a 404 + url = f"{apigw_http_endpoint}todos/" + body = "Hello World" + + # WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher + with pytest.raises(HTTPError): + data_fetcher.get_http_response( + Request( + method="POST", + url=url, + json={"body": body}, + ) + ) + + +def test_lambda_function_url_trailing_slash(lambda_function_url_endpoint): + # GIVEN the URL for the API ends in a trailing slash it should behave as if there was not one + url = f"{lambda_function_url_endpoint}todos/" # the function url endpoint already has the trailing / + body = "Hello World" + + # WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher + with pytest.raises(HTTPError): + data_fetcher.get_http_response( + Request( + method="POST", + url=url, + json={"body": body}, + ) + ) + + +def test_alb_url_trailing_slash(alb_multi_value_header_listener_endpoint): + # GIVEN url has a trailing slash - it should behave as if there was not one + url = f"{alb_multi_value_header_listener_endpoint}/todos/" + body = "Hello World" + + # WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher + with pytest.raises(HTTPError): + data_fetcher.get_http_response( + Request( + method="POST", + url=url, + json={"body": body}, + ) + ) diff --git a/tests/events/albEventPathTrailingSlash.json b/tests/events/albEventPathTrailingSlash.json new file mode 100644 index 0000000000..c517a3f6b0 --- /dev/null +++ b/tests/events/albEventPathTrailingSlash.json @@ -0,0 +1,28 @@ +{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" + } + }, + "httpMethod": "GET", + "path": "/lambda/", + "queryStringParameters": { + "query": "1234ABCD" + }, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8", + "accept-encoding": "gzip", + "accept-language": "en-US,en;q=0.9", + "connection": "keep-alive", + "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", + "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", + "x-forwarded-for": "72.12.164.125", + "x-forwarded-port": "80", + "x-forwarded-proto": "http", + "x-imforwards": "20" + }, + "body": "Test", + "isBase64Encoded": false + } \ No newline at end of file diff --git a/tests/events/apiGatewayProxyEventPathTrailingSlash.json b/tests/events/apiGatewayProxyEventPathTrailingSlash.json new file mode 100644 index 0000000000..8a321d96c8 --- /dev/null +++ b/tests/events/apiGatewayProxyEventPathTrailingSlash.json @@ -0,0 +1,80 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path/", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": true + } \ No newline at end of file diff --git a/tests/events/apiGatewayProxyV2EventPathTrailingSlash.json b/tests/events/apiGatewayProxyV2EventPathTrailingSlash.json new file mode 100644 index 0000000000..dfb0d98f2e --- /dev/null +++ b/tests/events/apiGatewayProxyV2EventPathTrailingSlash.json @@ -0,0 +1,69 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path/", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "{\"message\": \"hello world\", \"username\": \"tom\"}", + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } + } \ No newline at end of file diff --git a/tests/events/lambdaFunctionUrlEventPathTrailingSlash.json b/tests/events/lambdaFunctionUrlEventPathTrailingSlash.json new file mode 100644 index 0000000000..b1f8226518 --- /dev/null +++ b/tests/events/lambdaFunctionUrlEventPathTrailingSlash.json @@ -0,0 +1,52 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path/", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "header1": "value1", + "header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "", + "authentication": null, + "authorizer": { + "iam": { + "accessKey": "AKIA...", + "accountId": "111122223333", + "callerId": "AIDA...", + "cognitoIdentity": null, + "principalOrgId": null, + "userArn": "arn:aws:iam::111122223333:user/example-user", + "userId": "AIDA..." + } + }, + "domainName": ".lambda-url.us-west-2.on.aws", + "domainPrefix": "", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "123.123.123.123", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "Hello from client!", + "pathParameters": null, + "isBase64Encoded": false, + "stageVariables": null + } \ No newline at end of file diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 8491754b65..a78d3747d2 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -53,6 +53,7 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +LOAD_GW_EVENT_TRAILING_SLASH = load_event("apiGatewayProxyEventPathTrailingSlash.json") def test_alb_event(): @@ -76,6 +77,27 @@ def foo(): assert result["body"] == "foo" +def test_alb_event_path_trailing_slash(json_dump): + # GIVEN an Application Load Balancer proxy type event + app = ALBResolver() + + @app.get("/lambda") + def foo(): + assert isinstance(app.current_event, ALBEvent) + assert app.lambda_context == {} + assert app.current_event.request_context.elb_target_group_arn is not None + return Response(200, content_types.TEXT_HTML, "foo") + + # WHEN calling the event handler using path with trailing "/" + result = app(load_event("albEventPathTrailingSlash.json"), {}) + + # THEN + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 404, "message": "Not found"} + assert result["body"] == json_dump(expected) + + def test_api_gateway_v1(): # GIVEN a Http API V1 proxy type event app = APIGatewayRestResolver() @@ -96,6 +118,23 @@ def get_lambda() -> Response: assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] +def test_api_gateway_v1_path_trailing_slash(): + # GIVEN a Http API V1 proxy type event + app = APIGatewayRestResolver() + + @app.get("/my/path") + def get_lambda() -> Response: + return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT_TRAILING_SLASH, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + + def test_api_gateway_v1_cookies(): # GIVEN a Http API V1 proxy type event app = APIGatewayRestResolver() @@ -134,6 +173,24 @@ def get_lambda() -> Response: assert result["body"] == "foo" +def test_api_gateway_event_path_trailing_slash(json_dump): + # GIVEN a Rest API Gateway proxy type event + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + @app.get("/my/path") + def get_lambda() -> Response: + assert isinstance(app.current_event, APIGatewayProxyEvent) + return Response(200, content_types.TEXT_HTML, "foo") + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT_TRAILING_SLASH, {}) + # THEN + assert result["statusCode"] == 404 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + expected = {"statusCode": 404, "message": "Not found"} + assert result["body"] == json_dump(expected) + + def test_api_gateway_v2(): # GIVEN a Http API V2 proxy type event app = APIGatewayHttpResolver() @@ -156,6 +213,25 @@ def my_path() -> Response: assert result["body"] == "tom" +def test_api_gateway_v2_http_path_trailing_slash(json_dump): + # GIVEN a Http API V2 proxy type event + app = APIGatewayHttpResolver() + + @app.post("/my/path") + def my_path() -> Response: + post_data = app.current_event.json_body + return Response(200, content_types.TEXT_PLAIN, post_data["username"]) + + # WHEN calling the event handler + result = app(load_event("apiGatewayProxyV2EventPathTrailingSlash.json"), {}) + + # THEN expect a 404 response + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 404, "message": "Not found"} + assert result["body"] == json_dump(expected) + + def test_api_gateway_v2_cookies(): # GIVEN a Http API V2 proxy type event app = APIGatewayHttpResolver() diff --git a/tests/functional/event_handler/test_lambda_function_url.py b/tests/functional/event_handler/test_lambda_function_url.py index aacbc94129..41baed68a7 100644 --- a/tests/functional/event_handler/test_lambda_function_url.py +++ b/tests/functional/event_handler/test_lambda_function_url.py @@ -30,6 +30,22 @@ def foo(): assert result["body"] == "foo" +def test_lambda_function_url_event_path_trailing_slash(): + # GIVEN a Lambda Function Url type event + app = LambdaFunctionUrlResolver() + + @app.post("/my/path") + def foo(): + return Response(200, content_types.TEXT_HTML, "foo") + + # WHEN calling the event handler with an event with a trailing slash + result = app(load_event("lambdaFunctionUrlEventPathTrailingSlash.json"), {}) + + # THEN return a 404 error + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + def test_lambda_function_url_event_with_cookies(): # GIVEN a Lambda Function Url type event app = LambdaFunctionUrlResolver()