Skip to content

Commit

Permalink
fix: add missing fields to RequestContext for local lambda invoke (aw…
Browse files Browse the repository at this point in the history
…s#2060)

Co-authored-by: Tarun <c2tarun@users.noreply.github.com>
  • Loading branch information
mndeveci and c2tarun authored Aug 10, 2020
1 parent 7c56e4e commit e8cae9d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
10 changes: 9 additions & 1 deletion samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def _construct_event(flask_request, port, binary_types, stage_name=None, stage_v

endpoint = PathConverter.convert_path_to_api_gateway(flask_request.endpoint)
method = flask_request.method
protocol = flask_request.environ.get("SERVER_PROTOCOL", "HTTP/1.1")
host = flask_request.host

request_data = flask_request.get_data()

Expand All @@ -409,7 +411,13 @@ def _construct_event(flask_request, port, binary_types, stage_name=None, stage_v
request_data = request_data.decode("utf-8")

context = RequestContext(
resource_path=endpoint, http_method=method, stage=stage_name, identity=identity, path=endpoint
resource_path=endpoint,
http_method=method,
stage=stage_name,
identity=identity,
path=endpoint,
protocol=protocol,
domain_name=host,
)

headers_dict, multi_value_headers_dict = LocalApigwService._event_headers(flask_request, port)
Expand Down
14 changes: 14 additions & 0 deletions samcli/local/events/api_event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Holds Classes for API Gateway to Lambda Events"""
from time import time
from datetime import datetime


class ContextIdentity:
Expand Down Expand Up @@ -75,6 +77,10 @@ def __init__(
identity=None,
extended_request_id=None,
path=None,
protocol=None,
domain_name=None,
request_time_epoch=int(time()),
request_time=datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000"),
):
"""
Constructs a RequestContext
Expand All @@ -101,6 +107,10 @@ def __init__(
self.identity = identity
self.extended_request_id = extended_request_id
self.path = path
self.protocol = protocol
self.domain_name = domain_name
self.request_time_epoch = request_time_epoch
self.request_time = request_time

def to_dict(self):
"""
Expand All @@ -123,6 +133,10 @@ def to_dict(self):
"identity": identity_dict,
"extendedRequestId": self.extended_request_id,
"path": self.path,
"protocol": self.protocol,
"domainName": self.domain_name,
"requestTimeEpoch": self.request_time_epoch,
"requestTime": self.request_time,
}

return json_dict
Expand Down
33 changes: 29 additions & 4 deletions tests/unit/local/apigw/test_local_apigw_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import copy
import json
from datetime import datetime
from unittest import TestCase

from unittest.mock import Mock, patch, ANY, MagicMock
Expand Down Expand Up @@ -513,6 +514,7 @@ def setUp(self):
self.request_mock.path = "path"
self.request_mock.method = "GET"
self.request_mock.remote_addr = "190.0.0.0"
self.request_mock.host = "190.0.0.1"
self.request_mock.get_data.return_value = b"DATA!!!!"
query_param_args_mock = Mock()
query_param_args_mock.lists.return_value = {"query": ["params"]}.items()
Expand All @@ -524,6 +526,8 @@ def setUp(self):
self.request_mock.headers = headers_mock
self.request_mock.view_args = {"path": "params"}
self.request_mock.scheme = "http"
environ_dict = {"SERVER_PROTOCOL": "HTTP/1.1"}
self.request_mock.environ = environ_dict

expected = (
'{"body": "DATA!!!!", "httpMethod": "GET", '
Expand All @@ -535,7 +539,8 @@ def setUp(self):
'"identity": {"accountId": null, "apiKey": null, "userArn": null, '
'"cognitoAuthenticationProvider": null, "cognitoIdentityPoolId": null, "userAgent": '
'"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": '
'"190.0.0.0", "user": null}, "accountId": "123456789012"}, "headers": {"Content-Type": '
'"190.0.0.0", "user": null}, "accountId": "123456789012", "domainName": "190.0.0.1", '
'"protocol": "HTTP/1.1"}, "headers": {"Content-Type": '
'"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, '
'"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], '
'"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, '
Expand All @@ -545,16 +550,33 @@ def setUp(self):

self.expected_dict = json.loads(expected)

def validate_request_context_and_remove_request_time_data(self, event_json):
request_time = event_json["requestContext"].pop("requestTime", None)
request_time_epoch = event_json["requestContext"].pop("requestTimeEpoch", None)

self.assertIsInstance(request_time, str)
parsed_request_time = datetime.strptime(request_time, "%d/%b/%Y:%H:%M:%S +0000")
self.assertIsInstance(parsed_request_time, datetime)

self.assertIsInstance(request_time_epoch, int)

def test_construct_event_with_data(self):
actual_event_str = LocalApigwService._construct_event(self.request_mock, 3000, binary_types=[])
self.assertEqual(json.loads(actual_event_str), self.expected_dict)

actual_event_json = json.loads(actual_event_str)
self.validate_request_context_and_remove_request_time_data(actual_event_json)

self.assertEqual(actual_event_json, self.expected_dict)

def test_construct_event_no_data(self):
self.request_mock.get_data.return_value = None
self.expected_dict["body"] = None

actual_event_str = LocalApigwService._construct_event(self.request_mock, 3000, binary_types=[])
self.assertEqual(json.loads(actual_event_str), self.expected_dict)
actual_event_json = json.loads(actual_event_str)
self.validate_request_context_and_remove_request_time_data(actual_event_json)

self.assertEqual(actual_event_json, self.expected_dict)

@patch("samcli.local.apigw.local_apigw_service.LocalApigwService._should_base64_encode")
def test_construct_event_with_binary_data(self, should_base64_encode_patch):
Expand All @@ -569,7 +591,10 @@ def test_construct_event_with_binary_data(self, should_base64_encode_patch):
self.maxDiff = None

actual_event_str = LocalApigwService._construct_event(self.request_mock, 3000, binary_types=[])
self.assertEqual(json.loads(actual_event_str), self.expected_dict)
actual_event_json = json.loads(actual_event_str)
self.validate_request_context_and_remove_request_time_data(actual_event_json)

self.assertEqual(actual_event_json, self.expected_dict)

def test_event_headers_with_empty_list(self):
request_mock = Mock()
Expand Down
22 changes: 21 additions & 1 deletion tests/unit/local/events/test_api_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def test_class_initialized(self):
identity_mock,
"extended_request_id",
"path",
"protocol",
"domain_name",
"request_time_epoch",
"request_time",
)

self.assertEqual(request_context.resource_id, "resource_id")
Expand All @@ -105,6 +109,10 @@ def test_class_initialized(self):
self.assertEqual(request_context.identity, identity_mock)
self.assertEqual(request_context.extended_request_id, "extended_request_id")
self.assertEqual(request_context.path, "path")
self.assertEqual(request_context.protocol, "protocol")
self.assertEqual(request_context.domain_name, "domain_name")
self.assertEqual(request_context.request_time_epoch, "request_time_epoch")
self.assertEqual(request_context.request_time, "request_time")

def test_to_dict(self):
identity_mock = Mock()
Expand All @@ -121,6 +129,10 @@ def test_to_dict(self):
identity_mock,
"extended_request_id",
"path",
"protocol",
"domain_name",
"request_time_epoch",
"request_time",
)

expected = {
Expand All @@ -134,12 +146,16 @@ def test_to_dict(self):
"identity": {"identity": "the identity"},
"extendedRequestId": "extended_request_id",
"path": "path",
"protocol": "protocol",
"domainName": "domain_name",
"requestTimeEpoch": "request_time_epoch",
"requestTime": "request_time",
}

self.assertEqual(request_context.to_dict(), expected)

def test_to_dict_with_defaults(self):
request_context = RequestContext()
request_context = RequestContext(request_time="request_time", request_time_epoch="request_time_epoch")

expected = {
"resourceId": "123456",
Expand All @@ -152,6 +168,10 @@ def test_to_dict_with_defaults(self):
"identity": {},
"extendedRequestId": None,
"path": None,
"protocol": None,
"domainName": None,
"requestTimeEpoch": "request_time_epoch",
"requestTime": "request_time",
}

self.assertEqual(request_context.to_dict(), expected)
Expand Down

0 comments on commit e8cae9d

Please sign in to comment.