Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(apigateway): ignore trailing slashes in routes (APIGatewayRestResolver) #1609

Merged
merged 12 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# API GW/ALB decode non-safe URI chars; we must support them too
_UNSAFE_URI = "%<> \[\]{}|^" # noqa: W605
_NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
_ROUTE_REGEX = "^{}$"


class ProxyEventType(Enum):
Expand Down Expand Up @@ -562,7 +563,7 @@ def _has_debug(debug: Optional[bool] = None) -> bool:
return powertools_dev_is_set()

@staticmethod
def _compile_regex(rule: str):
def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
"""Precompile regex pattern

Logic
Expand Down Expand Up @@ -592,7 +593,7 @@ def _compile_regex(rule: str):
NOTE: See #520 for context
"""
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
return re.compile("^{}$".format(rule_regex))
return re.compile(base_regex.format(rule_regex))

def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
"""Convert the event dict to the corresponding data class"""
Expand Down Expand Up @@ -819,6 +820,24 @@ def __init__(
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes)

# override route to ignore trailing "/" in routes for REST API
def route(
self,
rule: str,
method: Union[str, Union[List[str], Tuple[str]]],
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
# NOTE: see #1552 for more context.
return super().route(rule.rstrip("/"), method, cors, compress, cache_control)

# Override _compile_regex to exclude trailing slashes for route resolution
@staticmethod
def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):

return super(APIGatewayRestResolver, APIGatewayRestResolver)._compile_regex(rule, "^{}/*$")


class APIGatewayHttpResolver(ApiGatewayResolver):
current_event: APIGatewayProxyEventV2
Expand Down
10 changes: 5 additions & 5 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ Before you decorate your functions to handle a given path and HTTP method(s), yo

A resolver will handle request resolution, including [one or more routers](#split-routes-with-router), and give you access to the current event via typed properties.

For resolvers, we provide: `APIGatewayRestResolver`, `APIGatewayHttpResolver`, `ALBResolver`, and `LambdaFunctionUrlResolver` .
For resolvers, we provide: `APIGatewayRestResolver`, `APIGatewayHttpResolver`, `ALBResolver`, and `LambdaFunctionUrlResolver`. From here on, we will default to `APIGatewayRestResolver` across examples.

???+ info
We will use `APIGatewayRestResolver` as the default across examples.
???+ info "Auto-serialization"
We serialize `Dict` responses as JSON, trim whitespace for compact responses, and set content-type to `application/json`.

#### API Gateway REST API

When using Amazon API Gateway REST API to front your Lambda functions, you can use `APIGatewayRestResolver`.

Here's an example on how we can handle the `/todos` path.

???+ info
We automatically serialize `Dict` responses as JSON, trim whitespace for compact responses, and set content-type to `application/json`.
???+ info "Trailing slash in routes"
For `APIGatewayRestResolver`, we seamless handle routes with a trailing slash (`/todos/`).

=== "getting_started_rest_api_resolver.py"

Expand Down
5 changes: 4 additions & 1 deletion tests/e2e/event_handler/handlers/alb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

app = ALBResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def hello():
def todos():
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

app = APIGatewayHttpResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

app = APIGatewayRestResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

app = LambdaFunctionUrlResolver()

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand Down
99 changes: 99 additions & 0 deletions tests/e2e/event_handler/test_paths_ending_with_slash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
from requests import HTTPError, Request

from tests.e2e.utils import data_fetcher


@pytest.fixture
def alb_basic_listener_endpoint(infrastructure: dict) -> str:
dns_name = infrastructure.get("ALBDnsName")
port = infrastructure.get("ALBBasicListenerPort", "")
return f"http://{dns_name}:{port}"


@pytest.fixture
def alb_multi_value_header_listener_endpoint(infrastructure: dict) -> str:
dns_name = infrastructure.get("ALBDnsName")
port = infrastructure.get("ALBMultiValueHeaderListenerPort", "")
return f"http://{dns_name}:{port}"


@pytest.fixture
def apigw_rest_endpoint(infrastructure: dict) -> str:
return infrastructure.get("APIGatewayRestUrl", "")


@pytest.fixture
def apigw_http_endpoint(infrastructure: dict) -> str:
return infrastructure.get("APIGatewayHTTPUrl", "")


@pytest.fixture
def lambda_function_url_endpoint(infrastructure: dict) -> str:
return infrastructure.get("LambdaFunctionUrl", "")


def test_api_gateway_rest_trailing_slash(apigw_rest_endpoint):
# GIVEN API URL ends in a trailing slash
url = f"{apigw_rest_endpoint}todos/"
body = "Hello World"

# WHEN
response = data_fetcher.get_http_response(
Request(
method="POST",
url=url,
json={"body": body},
)
)

# THEN expect a HTTP 200 response
assert response.status_code == 200


def test_api_gateway_http_trailing_slash(apigw_http_endpoint):
# GIVEN the URL for the API ends in a trailing slash API gateway should return a 404
url = f"{apigw_http_endpoint}todos/"
body = "Hello World"

# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="POST",
url=url,
json={"body": body},
)
)


def test_lambda_function_url_trailing_slash(lambda_function_url_endpoint):
# GIVEN the URL for the API ends in a trailing slash it should behave as if there was not one
url = f"{lambda_function_url_endpoint}todos/" # the function url endpoint already has the trailing /
body = "Hello World"

# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="POST",
url=url,
json={"body": body},
)
)


def test_alb_url_trailing_slash(alb_multi_value_header_listener_endpoint):
# GIVEN url has a trailing slash - it should behave as if there was not one
url = f"{alb_multi_value_header_listener_endpoint}/todos/"
body = "Hello World"

# WHEN calling an invalid URL (with trailing slash) expect HTTPError exception from data_fetcher
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="POST",
url=url,
json={"body": body},
)
)
28 changes: 28 additions & 0 deletions tests/events/albEventPathTrailingSlash.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a"
}
},
"httpMethod": "GET",
"path": "/lambda/",
"queryStringParameters": {
"query": "1234ABCD"
},
"headers": {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
"accept-encoding": "gzip",
"accept-language": "en-US,en;q=0.9",
"connection": "keep-alive",
"host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com",
"upgrade-insecure-requests": "1",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36",
"x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476",
"x-forwarded-for": "72.12.164.125",
"x-forwarded-port": "80",
"x-forwarded-proto": "http",
"x-imforwards": "20"
},
"body": "Test",
"isBase64Encoded": false
}
80 changes: 80 additions & 0 deletions tests/events/apiGatewayProxyEventPathTrailingSlash.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"version": "1.0",
"resource": "/my/path",
"path": "/my/path/",
"httpMethod": "GET",
"headers": {
"Header1": "value1",
"Header2": "value2"
},
"multiValueHeaders": {
"Header1": [
"value1"
],
"Header2": [
"value1",
"value2"
]
},
"queryStringParameters": {
"parameter1": "value1",
"parameter2": "value"
},
"multiValueQueryStringParameters": {
"parameter1": [
"value1",
"value2"
],
"parameter2": [
"value"
]
},
"requestContext": {
"accountId": "123456789012",
"apiId": "id",
"authorizer": {
"claims": null,
"scopes": null
},
"domainName": "id.execute-api.us-east-1.amazonaws.com",
"domainPrefix": "id",
"extendedRequestId": "request-id",
"httpMethod": "GET",
"identity": {
"accessKey": null,
"accountId": null,
"caller": null,
"cognitoAuthenticationProvider": null,
"cognitoAuthenticationType": null,
"cognitoIdentityId": null,
"cognitoIdentityPoolId": null,
"principalOrgId": null,
"sourceIp": "192.168.0.1/32",
"user": null,
"userAgent": "user-agent",
"userArn": null,
"clientCert": {
"clientCertPem": "CERT_CONTENT",
"subjectDN": "www.example.com",
"issuerDN": "Example issuer",
"serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1",
"validity": {
"notBefore": "May 28 12:30:02 2019 GMT",
"notAfter": "Aug 5 09:36:04 2021 GMT"
}
}
},
"path": "/my/path",
"protocol": "HTTP/1.1",
"requestId": "id=",
"requestTime": "04/Mar/2020:19:15:17 +0000",
"requestTimeEpoch": 1583349317135,
"resourceId": null,
"resourcePath": "/my/path",
"stage": "$default"
},
"pathParameters": null,
"stageVariables": null,
"body": "Hello from Lambda!",
"isBase64Encoded": true
}
Loading