Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ELB Support / Fix: APIGW v2 cookies response #155

Merged
merged 9 commits into from
Feb 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,44 @@ def __call__(self, event: dict, context: dict) -> dict:
stack.enter_context(lifespan_cycle)

request_context = event["requestContext"]
if "http" in request_context:

if event.get("multiValueHeaders"):
headers = {k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
source_ip = request_context.get("identity", {}).get("sourceIp")
multi_value_query_string_params = event[
"multiValueQueryStringParameters"
]
query_string = (
urllib.parse.urlencode(
multi_value_query_string_params, doseq=True
).encode()
if multi_value_query_string_params
else b""
)
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

headers = (
{k.lower(): v for k, v in event.get("headers", {}).items()}
if event.get("headers")
else {}
)

if "cookies" in event:
headers["cookie"] = "; ".join(event.get("cookies", []))
if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
Expand Down
67 changes: 56 additions & 11 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@
from mangum.exceptions import UnexpectedMessage


def all_casings(input_string):
"""
Permute all casings of a given string.
A pretty algoritm, via @Amber
http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
"""
if not input_string:
yield ""
else:
first = input_string[:1]
if first.lower() == first.upper():
for sub_casing in all_casings(input_string[1:]):
yield first + sub_casing
else:
for sub_casing in all_casings(input_string[1:]):
yield first.lower() + sub_casing
yield first.upper() + sub_casing


class HTTPCycleState(enum.Enum):
"""
The state of the ASGI `http` connection.
Expand Down Expand Up @@ -116,21 +135,47 @@ async def send(self, message: Message) -> None:
self.response["statusCode"] = message["status"]
headers: typing.Dict[str, str] = {}
multi_value_headers: typing.Dict[str, typing.List[str]] = {}
for key, value in message.get("headers", []):
lower_key = key.decode().lower()
if lower_key in multi_value_headers:
multi_value_headers[lower_key].append(value.decode())
elif lower_key in headers:
multi_value_headers[lower_key] = [
headers.pop(lower_key),
value.decode(),
]
else:
headers[lower_key] = value.decode()
cookies: typing.List[str] = []
event = self.scope["aws.event"]
# ELB
if "elb" in event["requestContext"]:
for key, value in message.get("headers", []):
lower_key = key.decode().lower()
if lower_key in multi_value_headers:
multi_value_headers[lower_key].append(value.decode())
else:
multi_value_headers[lower_key] = [value.decode()]
if "multiValueHeaders" not in event:
# If there are multiple occurrences of headers, create case-mutated variations
# see: https://github.com/logandk/serverless-wsgi/issues/11
for key, values in multi_value_headers.items():
if len(values) > 1:
for value, cased_key in zip(values, all_casings(key)):
headers[cased_key] = value
elif len(values) == 1:
headers[key] = values[0]
multi_value_headers = {}
# API Gateway
else:
for key, value in message.get("headers", []):
lower_key = key.decode().lower()
if event.get("version") == "2.0" and lower_key == "set-cookie":
cookies.append(value.decode())
elif lower_key in multi_value_headers:
multi_value_headers[lower_key].append(value.decode())
elif lower_key in headers:
multi_value_headers[lower_key] = [
headers.pop(lower_key),
value.decode(),
]
else:
headers[lower_key] = value.decode()

self.response["headers"] = headers
if multi_value_headers:
self.response["multiValueHeaders"] = multi_value_headers
if len(cookies):
self.response["cookies"] = cookies
self.state = HTTPCycleState.RESPONSE

elif (
Expand Down
62 changes: 62 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,65 @@ def mock_http_api_event(request):
}

return event


@pytest.fixture
def mock_http_elb_singlevalue_event(request):
method = request.param[0]
body = request.param[1]
multi_value_query_parameters = request.param[2]
event = {
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0"
}
},
"httpMethod": method,
"path": "/my/path",
"queryStringParameters": {
k: v[-1] for k, v in multi_value_query_parameters.items()
}
if multi_value_query_parameters
else None,
"headers": {
"accept-encoding": "gzip, deflate",
"cookie": "cookie1; cookie2",
"host": "test.execute-api.us-west-2.amazonaws.com",
"x-forwarded-for": "192.168.100.3, 192.168.100.2, 192.168.100.1",
"x-forwarded-port": "443",
"x-forwarded-proto": "https",
},
"body": body,
"isBase64Encoded": False
}

return event


@pytest.fixture
def mock_http_elb_multivalue_event(request):
method = request.param[0]
body = request.param[1]
multi_value_query_parameters = request.param[2]
event = {
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:0:targetgroup/test/0"
}
},
"httpMethod": method,
"path": "/my/path",
"multiValueQueryStringParameters": multi_value_query_parameters or None,
"multiValueHeaders": {
"accept-encoding": ["gzip, deflate"],
"cookie": ["cookie1; cookie2"],
"host": ["test.execute-api.us-west-2.amazonaws.com"],
"x-forwarded-for": ["192.168.100.3, 192.168.100.2, 192.168.100.1"],
"x-forwarded-port": ["443"],
"x-forwarded-proto": ["https"],
},
"body": body,
"isBase64Encoded": False
}

return event
Loading