Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(apigateway): ignore trailing slashes in routes (APIGatewayRestResolver) #1609

Merged
merged 12 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# API GW/ALB decode non-safe URI chars; we must support them too
_UNSAFE_URI = "%<> \[\]{}|^" # noqa: W605
_NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
_ROUTE_REGEX = "^{}$"


class ProxyEventType(Enum):
Expand Down Expand Up @@ -562,7 +563,7 @@ def _has_debug(debug: Optional[bool] = None) -> bool:
return powertools_dev_is_set()

@staticmethod
def _compile_regex(rule: str):
def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
"""Precompile regex pattern

Logic
Expand Down Expand Up @@ -592,7 +593,7 @@ def _compile_regex(rule: str):
NOTE: See #520 for context
"""
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
return re.compile("^{}$".format(rule_regex))
return re.compile(base_regex.format(rule_regex))

def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
"""Convert the event dict to the corresponding data class"""
Expand Down Expand Up @@ -819,6 +820,24 @@ def __init__(
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes)

# override route to ignore trailing "/" in routes for REST API
def route(
self,
rule: str,
method: Union[str, Union[List[str], Tuple[str]]],
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
# remove trailing "/" character from route rule for correct routing behaviour
walmsles marked this conversation as resolved.
Show resolved Hide resolved
return super().route(rule.rstrip("/"), method, cors, compress, cache_control)

# Override _compile_regex to exclude trailing slashes for route resolution
@staticmethod
def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):

return super(APIGatewayRestResolver, APIGatewayRestResolver)._compile_regex(rule, "^{}/*$")


class APIGatewayHttpResolver(ApiGatewayResolver):
current_event: APIGatewayProxyEventV2
Expand Down
7 changes: 6 additions & 1 deletion tests/e2e/event_handler/handlers/alb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@app.post("/todos")
def hello():
def todos():
heitorlessa marked this conversation as resolved.
Show resolved Hide resolved
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand All @@ -22,5 +22,10 @@ def hello():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand All @@ -26,5 +26,10 @@ def hello():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")
walmsles marked this conversation as resolved.
Show resolved Hide resolved


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand All @@ -26,5 +26,10 @@ def hello():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@app.post("/todos")
def hello():
def todos():
payload = app.current_event.json_body

body = payload.get("body", "Hello World")
Expand All @@ -26,5 +26,10 @@ def hello():
)


@app.get("/hello")
def hello():
return Response(status_code=200, content_type=content_types.TEXT_PLAIN, body="Hello World")


def lambda_handler(event, context):
return app.resolve(event, context)
8 changes: 8 additions & 0 deletions tests/e2e/event_handler/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def _create_api_gateway_http(self, function: Function):
integration=apigwv2integrations.HttpLambdaIntegration("TodosIntegration", function),
)

apigw.add_routes(
path="/hello",
methods=[apigwv2.HttpMethod.GET],
integration=apigwv2integrations.HttpLambdaIntegration("HelloIntegration", function),
)
CfnOutput(self.stack, "APIGatewayHTTPUrl", value=(apigw.url or ""))

def _create_api_gateway_rest(self, function: Function):
Expand All @@ -72,6 +77,9 @@ def _create_api_gateway_rest(self, function: Function):
todos = apigw.root.add_resource("todos")
todos.add_method("POST", apigwv1.LambdaIntegration(function, proxy=True))

hello = apigw.root.add_resource("hello")
hello.add_method("GET", apigwv1.LambdaIntegration(function, proxy=True))

CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url)

def _create_lambda_function_url(self, function: Function):
Expand Down
95 changes: 95 additions & 0 deletions tests/e2e/event_handler/test_paths_ending_with_slash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
from requests import HTTPError, Request

from tests.e2e.utils import data_fetcher


@pytest.fixture
def alb_basic_listener_endpoint(infrastructure: dict) -> str:
dns_name = infrastructure.get("ALBDnsName")
port = infrastructure.get("ALBBasicListenerPort", "")
return f"http://{dns_name}:{port}"


@pytest.fixture
def alb_multi_value_header_listener_endpoint(infrastructure: dict) -> str:
dns_name = infrastructure.get("ALBDnsName")
port = infrastructure.get("ALBMultiValueHeaderListenerPort", "")
return f"http://{dns_name}:{port}"


@pytest.fixture
def apigw_rest_endpoint(infrastructure: dict) -> str:
return infrastructure.get("APIGatewayRestUrl", "")


@pytest.fixture
def apigw_http_endpoint(infrastructure: dict) -> str:
return infrastructure.get("APIGatewayHTTPUrl", "")


@pytest.fixture
def lambda_function_url_endpoint(infrastructure: dict) -> str:
return infrastructure.get("LambdaFunctionUrl", "")


def test_api_gateway_rest_trailing_slash(apigw_rest_endpoint):
# GIVEN
url = f"{apigw_rest_endpoint}hello/"
body = "Hello World"
status_code = 200

# WHEN
response = data_fetcher.get_http_response(
Request(
method="GET",
url=url,
)
)

# THEN
assert response.status_code == status_code
# response.content is a binary string, needs to be decoded to compare with the real string
assert response.content.decode("ascii") == body
walmsles marked this conversation as resolved.
Show resolved Hide resolved
walmsles marked this conversation as resolved.
Show resolved Hide resolved


def test_api_gateway_http_trailing_slash(apigw_http_endpoint):
# GIVEN the URL for the API ends in a trailing slash API gateway should return a 404
url = f"{apigw_http_endpoint}hello/"

# WHEN
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
url=url,
)
)


def test_lambda_function_url_trailing_slash(lambda_function_url_endpoint):
# GIVEN the URL for the API ends in a trailing slash it should behave as if there was not one
url = f"{lambda_function_url_endpoint}hello/" # the function url endpoint already has the trailing /

# WHEN
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
url=url,
)
)


def test_alb_url_trailing_slash(alb_multi_value_header_listener_endpoint):
# GIVEN url has a trailing slash - it should behave as if there was not one
url = f"{alb_multi_value_header_listener_endpoint}/hello/"

# WHEN
with pytest.raises(HTTPError):
data_fetcher.get_http_response(
Request(
method="GET",
url=url,
)
)
28 changes: 28 additions & 0 deletions tests/events/albEventPathTrailingSlash.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a"
}
},
"httpMethod": "GET",
"path": "/lambda/",
"queryStringParameters": {
"query": "1234ABCD"
},
"headers": {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
"accept-encoding": "gzip",
"accept-language": "en-US,en;q=0.9",
"connection": "keep-alive",
"host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com",
"upgrade-insecure-requests": "1",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36",
"x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476",
"x-forwarded-for": "72.12.164.125",
"x-forwarded-port": "80",
"x-forwarded-proto": "http",
"x-imforwards": "20"
},
"body": "Test",
"isBase64Encoded": false
}
80 changes: 80 additions & 0 deletions tests/events/apiGatewayProxyEventPathTrailingSlash.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"version": "1.0",
"resource": "/my/path",
"path": "/my/path/",
"httpMethod": "GET",
"headers": {
"Header1": "value1",
"Header2": "value2"
},
"multiValueHeaders": {
"Header1": [
"value1"
],
"Header2": [
"value1",
"value2"
]
},
"queryStringParameters": {
"parameter1": "value1",
"parameter2": "value"
},
"multiValueQueryStringParameters": {
"parameter1": [
"value1",
"value2"
],
"parameter2": [
"value"
]
},
"requestContext": {
"accountId": "123456789012",
"apiId": "id",
"authorizer": {
"claims": null,
"scopes": null
},
"domainName": "id.execute-api.us-east-1.amazonaws.com",
"domainPrefix": "id",
"extendedRequestId": "request-id",
"httpMethod": "GET",
"identity": {
"accessKey": null,
"accountId": null,
"caller": null,
"cognitoAuthenticationProvider": null,
"cognitoAuthenticationType": null,
"cognitoIdentityId": null,
"cognitoIdentityPoolId": null,
"principalOrgId": null,
"sourceIp": "192.168.0.1/32",
"user": null,
"userAgent": "user-agent",
"userArn": null,
"clientCert": {
"clientCertPem": "CERT_CONTENT",
"subjectDN": "www.example.com",
"issuerDN": "Example issuer",
"serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1",
"validity": {
"notBefore": "May 28 12:30:02 2019 GMT",
"notAfter": "Aug 5 09:36:04 2021 GMT"
}
}
},
"path": "/my/path",
"protocol": "HTTP/1.1",
"requestId": "id=",
"requestTime": "04/Mar/2020:19:15:17 +0000",
"requestTimeEpoch": 1583349317135,
"resourceId": null,
"resourcePath": "/my/path",
"stage": "$default"
},
"pathParameters": null,
"stageVariables": null,
"body": "Hello from Lambda!",
"isBase64Encoded": true
}
Loading