diff --git a/CHANGELOG.md b/CHANGELOG.md index ce43446..5ee9dd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -355,11 +355,12 @@ Released 2024-05-24 ### Version 0.7.0 -Release ETA 2024-02-31 ;) +Release ETA 2024-12-31 ;) +- Adds support for API GW's Web Socket Secure (WSS) protocol - Moves REST (API Gateway) related modules into one package - Moves exceptions to related packages -- Redesigns authorization from ground up -- Adds helpers and reduces jwt-related operations. - Removes deprecated pre and post requests hooks replaced by pre/post_handle in Resource +- Redesigns authorization from ground up - Fixes vulnerability `GHSA-wj6h-64fc-37mp` +- Adds helpers and reduces jwt-related operations. diff --git a/examples/rest/auth_example.py b/examples/rest/auth_example.py index 28b4bda..1e7b76d 100755 --- a/examples/rest/auth_example.py +++ b/examples/rest/auth_example.py @@ -1,4 +1,5 @@ """Simple Lambda Handler with authorization""" + from lbz.authz.decorators import authorization from lbz.dev.server import MyDevServer from lbz.exceptions import ServerError diff --git a/lbz/_request.py b/lbz/_request.py new file mode 100644 index 0000000..11be30f --- /dev/null +++ b/lbz/_request.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import base64 +import json +from typing import Any + +from lbz.authentication import User +from lbz.exceptions import BadRequestError + + +class Request: + def __init__( + self, + body: str | bytes | dict, + is_base64_encoded: bool, + context: dict, + user: User | None = None, + ) -> None: + + self.user = user + self.context = context + self._is_base64_encoded = is_base64_encoded + self._body = body + self._json_body: dict | None = None + self._raw_body: bytes | dict | None = None + + @property + def raw_body(self) -> bytes | dict | None: + if self._raw_body is None and self._body is not None: + if self._is_base64_encoded and isinstance(self._body, (bytes, str)): + self._raw_body = self._decode_base64(self._body) + elif isinstance(self._body, str): + self._raw_body = self._body.encode("utf-8") + else: + self._raw_body = self._body + return self._raw_body + + @staticmethod + def _decode_base64(encoded: str | bytes) -> bytes: + if not isinstance(encoded, bytes): + encoded = encoded.encode("ascii") + return base64.b64decode(encoded) + + @staticmethod + def _safe_json_loads(payload: str | bytes) -> Any: + try: + return json.loads(payload) + except ValueError as error: + raise BadRequestError(f"Invalid payload.\nPayload body:\n {repr(payload)}") from error diff --git a/lbz/authentication.py b/lbz/authentication.py index 3de9a78..fed6af8 100644 --- a/lbz/authentication.py +++ b/lbz/authentication.py @@ -1,4 +1,5 @@ """JWT based Authentication module.""" + from lbz._cfg import AUTH_REMOVE_PREFIXES from lbz.jwt_utils import decode_jwt diff --git a/lbz/exceptions.py b/lbz/exceptions.py index ec1b191..d0f5aa7 100644 --- a/lbz/exceptions.py +++ b/lbz/exceptions.py @@ -340,7 +340,7 @@ class NetworkAuthenticationRequired(LambdaFWServerException): def all_lbz_errors( cls: type[LambdaFWException] = LambdaFWException, -) -> Generator[type[LambdaFWException], None, None]: +) -> Generator[type[LambdaFWException]]: for subcls in cls.__subclasses__(): if subcls not in [LambdaFWClientException, LambdaFWServerException]: yield subcls diff --git a/lbz/misc.py b/lbz/misc.py index 4e3d73f..65c7d57 100644 --- a/lbz/misc.py +++ b/lbz/misc.py @@ -1,4 +1,5 @@ """Misc Helpers of Lambda Framework.""" + from __future__ import annotations import copy diff --git a/lbz/resource.py b/lbz/resource.py index b1ea57d..44e7232 100644 --- a/lbz/resource.py +++ b/lbz/resource.py @@ -19,9 +19,8 @@ UnsupportedMethod, ) from lbz.misc import get_logger, is_in_debug_mode -from lbz.request import Request from lbz.response import Response -from lbz.rest import ContentType +from lbz.rest import ContentType, HTTPRequest from lbz.router import Router ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin" @@ -44,7 +43,7 @@ def __init__(self, event: dict): self.path_params = event.get("pathParameters") or {} # DO NOT refactor self.method = event["requestContext"]["httpMethod"] headers = CIMultiDict(event.get("headers", {})) - self.request = Request( + self.request = HTTPRequest( headers=headers, uri_params=self.path_params, method=self.method, diff --git a/lbz/rest/__init__.py b/lbz/rest/__init__.py index 3ada703..9918038 100644 --- a/lbz/rest/__init__.py +++ b/lbz/rest/__init__.py @@ -1,2 +1,3 @@ from lbz.rest.api_gateway_event import APIGatewayEvent from lbz.rest.enums import ContentType +from lbz.rest.request import HTTPRequest diff --git a/lbz/request.py b/lbz/rest/request.py similarity index 60% rename from lbz/request.py rename to lbz/rest/request.py index 25ff961..373823d 100644 --- a/lbz/request.py +++ b/lbz/rest/request.py @@ -1,24 +1,22 @@ from __future__ import annotations -import base64 -import json -from typing import Any - from multidict import CIMultiDict +from lbz._request import Request from lbz.authentication import User from lbz.exceptions import BadRequestError from lbz.misc import MultiDict, get_logger -from lbz.rest import ContentType +from lbz.rest.enums import ContentType logger = get_logger(__name__) -class Request: - """Represents request from API gateway.""" +class HTTPRequest(Request): + """Represents request from HTTP API Gateway.""" def __init__( self, + *, headers: CIMultiDict, uri_params: dict, method: str, @@ -29,45 +27,22 @@ def __init__( query_params: dict | None = None, user: User | None = None, ): - self.query_params = MultiDict(query_params or {}) + super().__init__( + body=body, + is_base64_encoded=is_base64_encoded, + context=context, + user=user, + ) self.headers = headers + self.query_params = MultiDict(query_params or {}) self.uri_params = uri_params self.method = method self.context = context self.stage_vars = stage_vars - self.user = user - self._is_base64_encoded = is_base64_encoded - self._body = body - self._json_body: dict | None = None - self._raw_body: bytes | dict | None = None def __repr__(self) -> str: return f"" - @staticmethod - def _decode_base64(encoded: str | bytes) -> bytes: - if not isinstance(encoded, bytes): - encoded = encoded.encode("ascii") - return base64.b64decode(encoded) - - @property - def raw_body(self) -> bytes | dict | None: - if self._raw_body is None and self._body is not None: - if self._is_base64_encoded and isinstance(self._body, (bytes, str)): - self._raw_body = self._decode_base64(self._body) - elif isinstance(self._body, str): - self._raw_body = self._body.encode("utf-8") - else: - self._raw_body = self._body - return self._raw_body - - @staticmethod - def _safe_json_loads(payload: str | bytes) -> Any: - try: - return json.loads(payload) - except ValueError as error: - raise BadRequestError(f"Invalid payload.\nPayload body:\n {repr(payload)}") from error - @property def json_body(self) -> dict | None: if self._json_body is None: diff --git a/lbz/type_defs.py b/lbz/type_defs.py index 460d5b1..7f15b38 100644 --- a/lbz/type_defs.py +++ b/lbz/type_defs.py @@ -2,6 +2,7 @@ https://docs.aws.amazon.com/lambda/latest/dg/python-context.html """ + from typing import Any diff --git a/lbz/websocket/__init__.py b/lbz/websocket/__init__.py new file mode 100644 index 0000000..8028435 --- /dev/null +++ b/lbz/websocket/__init__.py @@ -0,0 +1,2 @@ +from lbz.websocket.enums import ActionType +from lbz.websocket.request import WebSocketRequest diff --git a/lbz/websocket/enums.py b/lbz/websocket/enums.py new file mode 100644 index 0000000..ba8bd70 --- /dev/null +++ b/lbz/websocket/enums.py @@ -0,0 +1,4 @@ +class ActionType: + CONNECT = "CONNECT" + DISCONNECT = "DISCONNECT" + MESSAGE = "MESSAGE" diff --git a/lbz/websocket/request.py b/lbz/websocket/request.py new file mode 100644 index 0000000..8d920bf --- /dev/null +++ b/lbz/websocket/request.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from multidict import CIMultiDict + +from lbz._request import Request +from lbz.authentication import User +from lbz.websocket.enums import ActionType + + +class WebSocketRequest(Request): + """Represents request from Web Socket Secure API Gateway. + + https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-mapping-template-reference.html + """ + + def __init__( + self, + *, + body: str | bytes | dict, + request_details: dict, + context: dict, + is_base64_encoded: bool, + user: User | None = None, + headers: CIMultiDict | None = None, + ) -> None: + super().__init__( + body=body, + is_base64_encoded=is_base64_encoded, + context=context, + user=user, + ) + self.headers = headers + self.action = request_details.pop("routeKey") + self.action_type = request_details.pop("eventType") + self.connection_id = request_details.pop("connectionId") + self.domain = request_details.pop("domainName") + self.stage = request_details.pop("stage") + self.details = request_details + + def __repr__(self) -> str: + return f"" + + @property + def json_body(self) -> dict | None: + if self._json_body is None: + if isinstance(self.raw_body, dict) or self.raw_body is None: + self._json_body = self.raw_body + else: + self._json_body = self._safe_json_loads(self.raw_body) + return self._json_body + + def is_connection_request(self) -> bool: + return self.action_type is ActionType.CONNECT + + def is_disconnection_request(self) -> bool: + return self.action_type is ActionType.DISCONNECT diff --git a/requirements-dev.txt b/requirements-dev.txt index 7790c50..363fb12 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,57 +4,53 @@ # # pip-compile requirements-dev.in # -astroid==3.0.2 +astroid==3.3.5 # via pylint -bandit==1.7.6 +bandit==1.7.10 # via -r requirements-dev.in -black==23.12.1 +black==24.10.0 # via -r requirements-dev.in boolean-py==4.0 # via license-expression -boto3-stubs[cognito-idp,dynamodb,events,lambda,s3,sns,sqs,ssm]==1.34.11 +boto3-stubs[cognito-idp,dynamodb,events,lambda,s3,sns,sqs,ssm]==1.34.158 # via -r requirements-dev.in -botocore-stubs==1.34.11 +botocore-stubs==1.34.158 # via boto3-stubs -build==1.0.3 +build==1.2.2.post1 # via pip-tools -cachecontrol[filecache]==0.13.1 +cachecontrol[filecache]==0.14.1 # via # cachecontrol # pip-audit -certifi==2023.11.17 +certifi==2024.8.30 # via requests -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 # via # black # pip-tools -coverage[toml]==7.4.0 +coverage[toml]==7.6.5 # via # -r requirements-dev.in # pytest-cov -cyclonedx-python-lib==5.2.0 +cyclonedx-python-lib==7.6.2 # via pip-audit defusedxml==0.7.1 # via py-serializable -dill==0.3.7 +dill==0.3.9 # via pylint -exceptiongroup==1.2.0 +exceptiongroup==1.2.2 # via pytest -filelock==3.13.1 +filelock==3.16.1 # via cachecontrol -flake8==6.1.0 +flake8==7.1.1 # via -r requirements-dev.in -gitdb==4.0.11 - # via gitpython -gitpython==3.1.40 - # via bandit html5lib==1.1 # via pip-audit -idna==3.6 +idna==3.10 # via requests -importlib-metadata==7.0.1 +importlib-metadata==8.5.0 # via build iniconfig==2.0.0 # via pytest @@ -62,7 +58,7 @@ isort==5.13.2 # via # -r requirements-dev.in # pylint -license-expression==30.2.0 +license-expression==30.4.0 # via cyclonedx-python-lib markdown-it-py==3.0.0 # via rich @@ -72,33 +68,33 @@ mccabe==0.7.0 # pylint mdurl==0.1.2 # via markdown-it-py -msgpack==1.0.7 +msgpack==1.1.0 # via cachecontrol -mypy==1.8.0 +mypy==1.13.0 # via -r requirements-dev.in -mypy-boto3-cognito-idp==1.34.3 +mypy-boto3-cognito-idp==1.34.158 # via boto3-stubs -mypy-boto3-dynamodb==1.34.0 +mypy-boto3-dynamodb==1.34.148 # via boto3-stubs -mypy-boto3-events==1.34.0 +mypy-boto3-events==1.34.151 # via boto3-stubs -mypy-boto3-lambda==1.34.0 +mypy-boto3-lambda==1.34.77 # via boto3-stubs -mypy-boto3-s3==1.34.0 +mypy-boto3-s3==1.34.162 # via boto3-stubs -mypy-boto3-sns==1.34.0 +mypy-boto3-sns==1.34.121 # via boto3-stubs -mypy-boto3-sqs==1.34.0 +mypy-boto3-sqs==1.34.121 # via boto3-stubs -mypy-boto3-ssm==1.34.0 +mypy-boto3-ssm==1.34.158 # via boto3-stubs mypy-extensions==1.0.0 # via # black # mypy -packageurl-python==0.13.1 +packageurl-python==0.16.0 # via cyclonedx-python-lib -packaging==23.2 +packaging==24.2 # via # black # build @@ -107,52 +103,54 @@ packaging==23.2 # pytest pathspec==0.12.1 # via black -pbr==6.0.0 +pbr==6.1.0 # via stevedore -pip-api==0.0.30 +pip-api==0.0.34 # via pip-audit -pip-audit==2.6.2 +pip-audit==2.7.3 # via -r requirements-dev.in pip-requirements-parser==32.0.1 # via pip-audit -pip-tools==7.3.0 +pip-tools==7.4.1 # via -r requirements-dev.in -platformdirs==4.1.0 +platformdirs==4.3.6 # via # black # pylint -pluggy==1.3.0 +pluggy==1.5.0 # via pytest -py-serializable==0.15.0 +py-serializable==1.1.2 # via cyclonedx-python-lib -pycodestyle==2.11.1 +pycodestyle==2.12.1 # via flake8 -pyflakes==3.1.0 +pyflakes==3.2.0 # via flake8 -pygments==2.17.2 +pygments==2.18.0 # via rich -pylint==3.0.3 +pylint==3.3.1 # via -r requirements-dev.in -pyparsing==3.1.1 +pyparsing==3.2.0 # via pip-requirements-parser -pyproject-hooks==1.0.0 - # via build -pytest==7.4.4 +pyproject-hooks==1.2.0 + # via + # build + # pip-tools +pytest==8.3.3 # via # -r requirements-dev.in # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==6.0.0 # via -r requirements-dev.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements-dev.in -pyyaml==6.0.1 +pyyaml==6.0.2 # via bandit -requests==2.31.0 +requests==2.32.3 # via # cachecontrol # pip-audit -rich==13.7.0 +rich==13.9.4 # via # bandit # pip-audit @@ -160,15 +158,13 @@ six==1.16.0 # via # -c requirements.txt # html5lib -smmap==5.0.1 - # via gitdb sortedcontainers==2.4.0 # via cyclonedx-python-lib -stevedore==5.1.0 +stevedore==5.3.0 # via bandit toml==0.10.2 # via pip-audit -tomli==2.0.1 +tomli==2.1.0 # via # black # build @@ -176,16 +172,16 @@ tomli==2.0.1 # mypy # pip-tools # pylint - # pyproject-hooks # pytest -tomlkit==0.12.3 +tomlkit==0.13.2 # via pylint -types-awscrt==0.20.0 +types-awscrt==0.23.0 # via botocore-stubs -types-s3transfer==0.10.0 +types-s3transfer==0.10.3 # via boto3-stubs -typing-extensions==4.9.0 +typing-extensions==4.12.2 # via + # -c requirements.txt # astroid # black # boto3-stubs @@ -199,15 +195,16 @@ typing-extensions==4.9.0 # mypy-boto3-sqs # mypy-boto3-ssm # pylint -urllib3==1.26.18 + # rich +urllib3==1.26.20 # via # -c requirements.txt # requests webencodings==0.5.1 # via html5lib -wheel==0.42.0 +wheel==0.45.0 # via pip-tools -zipp==3.17.0 +zipp==3.21.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.txt b/requirements.txt index 37bdb98..7bb9d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,35 +4,37 @@ # # pip-compile # -boto3==1.34.11 +boto3==1.34.162 # via lbz (setup.py) -botocore==1.34.11 +botocore==1.34.162 # via # boto3 # s3transfer -ecdsa==0.18.0 +ecdsa==0.19.0 # via python-jose jmespath==1.0.1 # via # boto3 # botocore -multidict==6.0.4 +multidict==6.1.0 # via lbz (setup.py) -pyasn1==0.5.1 +pyasn1==0.6.1 # via # python-jose # rsa -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via botocore python-jose==3.3.0 # via lbz (setup.py) rsa==4.9 # via python-jose -s3transfer==0.10.0 +s3transfer==0.10.3 # via boto3 six==1.16.0 # via # ecdsa # python-dateutil -urllib3==1.26.18 +typing-extensions==4.12.2 + # via multidict +urllib3==1.26.20 # via botocore diff --git a/setup.cfg b/setup.cfg index 0dcc17c..6010062 100644 --- a/setup.cfg +++ b/setup.cfg @@ -74,6 +74,7 @@ disable = # library specific too-many-instance-attributes, too-many-arguments, + too-many-positional-arguments, jobs = 3 load-plugins = pylint.extensions.bad_builtin, diff --git a/setup.py b/setup.py index 720c875..e0ae2bb 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ long_description=pathlib.Path("README.md").read_text("utf-8"), install_requires=[ "boto3>=1.34.11,<1.35.0", - "multidict>=6.0.4,<6.1.0", + "multidict>=6.1.0,<6.2.0", "python-jose>=3.3.0,<3.4.0", ], classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py index df444c8..4668dcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,10 +24,9 @@ from lbz.authz.authorizer import Authorizer from lbz.authz.decorators import authorization from lbz.collector import authz_collector -from lbz.request import Request from lbz.resource import Resource from lbz.response import Response -from lbz.rest import APIGatewayEvent, ContentType +from lbz.rest import APIGatewayEvent, ContentType, HTTPRequest from lbz.router import Router, add_route from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY, SAMPLE_PUBLIC_KEY from tests.utils import encode_token @@ -76,9 +75,9 @@ def clear_router_collector() -> Iterator[None]: @pytest.fixture() -def sample_request() -> Request: +def sample_request() -> HTTPRequest: # TODO: change to simple factory / parametrise it - return Request( + return HTTPRequest( method="GET", body="", headers=CIMultiDict({"Content-Type": ContentType.JSON}), @@ -177,8 +176,8 @@ def user_fixture(user_token: str) -> User: @pytest.fixture() -def sample_request_with_user(user: User) -> Request: - return Request( +def sample_request_with_user(user: User) -> HTTPRequest: + return HTTPRequest( method="GET", body="", headers=CIMultiDict({"Content-Type": ContentType.JSON}), diff --git a/tests/test_resource.py b/tests/test_resource.py index f6ffac4..a8ef2dc 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -18,7 +18,6 @@ from lbz.events.api import EventAPI from lbz.exceptions import NotFound, ServerError from lbz.misc import MultiDict -from lbz.request import Request from lbz.resource import ( ALLOW_ORIGIN_HEADER, CORSResource, @@ -27,13 +26,13 @@ Resource, ) from lbz.response import Response -from lbz.rest import APIGatewayEvent, ContentType +from lbz.rest import APIGatewayEvent, ContentType, HTTPRequest from lbz.router import Router, add_route from tests.fixtures.rsa_pair import SAMPLE_PUBLIC_KEY # TODO: Use fixtures yielded from conftest.py -req = Request( +req = HTTPRequest( body="", headers=CIMultiDict({"Content-Type": ContentType.JSON}), uri_params={}, @@ -75,7 +74,7 @@ def test___init__(self) -> None: assert isinstance(self.res.method, str) assert self.res.method == "GET" assert self.res.path_params == {} - assert isinstance(self.res.request, Request) + assert isinstance(self.res.request, HTTPRequest) assert self.res._router is not None # pylint: disable=protected-access assert isinstance(self.res._router, Router) # pylint: disable=protected-access assert self.res.request.user is None diff --git a/tests/test_request.py b/tests/test_rest/test_request.py similarity index 76% rename from tests/test_request.py rename to tests/test_rest/test_request.py index 36c10e4..2de0aff 100644 --- a/tests/test_request.py +++ b/tests/test_rest/test_request.py @@ -1,17 +1,14 @@ -# coding=utf-8 - import pytest from multidict import CIMultiDict from lbz.exceptions import BadRequestError from lbz.misc import MultiDict -from lbz.request import Request -from lbz.rest import ContentType +from lbz.rest import ContentType, HTTPRequest -class TestRequestInit: +class TestHTTPRequest: def test__init__(self) -> None: - req = Request( + req = HTTPRequest( uri_params={}, method="", body="", @@ -33,92 +30,84 @@ def test__init__(self) -> None: assert isinstance(req._is_base64_encoded, bool) # pylint: disable=protected-access assert req.user is None - -class TestRequestGeneral: - def test___repr__(self, sample_request: Request) -> None: + def test___repr__(self, sample_request: HTTPRequest) -> None: assert str(sample_request) == f"" def test__decode_base64_bytes(self) -> None: encoded = b"asdasdsd" - output = Request._decode_base64(encoded) # pylint: disable=protected-access + output = HTTPRequest._decode_base64(encoded) # pylint: disable=protected-access assert output == b"j\xc7Z\xb1\xdb\x1d" def test__decode_base64_str(self) -> None: encoded = "asdasdsd" - output = Request._decode_base64(encoded) # pylint: disable=protected-access + output = HTTPRequest._decode_base64(encoded) # pylint: disable=protected-access assert output == b"j\xc7Z\xb1\xdb\x1d" - def test_accessing_user_attributes(self, sample_request_with_user: Request) -> None: - assert isinstance(sample_request_with_user.user.username, str) # type: ignore - assert isinstance(sample_request_with_user.user.email, str) # type: ignore - for custom_param in range(1, 5): - assert isinstance(getattr(sample_request_with_user.user, str(custom_param)), str) - - def test_headers_are_case_insensitive(self, sample_request_with_user: Request) -> None: - assert sample_request_with_user.headers["content-type"] == ContentType.JSON - assert sample_request_with_user.headers["CoNtEnT-TyPe"] == ContentType.JSON - - -class TestRequestRawBody: - def test_raw_body_base64_bytes(self, sample_request: Request) -> None: + def test_raw_body_base64_bytes(self, sample_request: HTTPRequest) -> None: sample_request._body = b"asdasdsd" # pylint: disable=protected-access sample_request._is_base64_encoded = True # pylint: disable=protected-access assert sample_request.raw_body == b"j\xc7Z\xb1\xdb\x1d" - def test_raw_body_base64_str(self, sample_request: Request) -> None: + def test_raw_body_base64_str(self, sample_request: HTTPRequest) -> None: sample_request._body = "asdasdsd" # pylint: disable=protected-access sample_request._is_base64_encoded = True # pylint: disable=protected-access assert sample_request.raw_body == b"j\xc7Z\xb1\xdb\x1d" - def test_raw_body_bytes(self, sample_request: Request) -> None: + def test_raw_body_bytes(self, sample_request: HTTPRequest) -> None: sample_request._body = b"asdasdsd" # pylint: disable=protected-access assert sample_request.raw_body == b"asdasdsd" - def test_raw_body_str(self, sample_request: Request) -> None: + def test_raw_body_str(self, sample_request: HTTPRequest) -> None: sample_request._body = "abcx" # pylint: disable=protected-access assert sample_request.raw_body == b"abcx" + def test_accessing_user_attributes(self, sample_request_with_user: HTTPRequest) -> None: + assert isinstance(sample_request_with_user.user.username, str) # type: ignore + assert isinstance(sample_request_with_user.user.email, str) # type: ignore + for custom_param in range(1, 5): + assert isinstance(getattr(sample_request_with_user.user, str(custom_param)), str) -class TestRequestJsonBody: - def test_json_body_dict(self, sample_request: Request) -> None: + def test_headers_are_case_insensitive(self, sample_request_with_user: HTTPRequest) -> None: + assert sample_request_with_user.headers["content-type"] == ContentType.JSON + assert sample_request_with_user.headers["CoNtEnT-TyPe"] == ContentType.JSON + + def test_json_body_dict(self, sample_request: HTTPRequest) -> None: sample_request._body = {"x": "t1"} # pylint: disable=protected-access assert sample_request.json_body == {"x": "t1"} - def test_json_body_json(self, sample_request: Request) -> None: + def test_json_body_json(self, sample_request: HTTPRequest) -> None: sample_request._json_body = None # pylint: disable=protected-access sample_request._body = '{"x": "t2"}' # pylint: disable=protected-access assert sample_request.json_body == {"x": "t2"} assert sample_request._json_body == {"x": "t2"} # pylint: disable=protected-access - def test_json_body_base64_json(self, sample_request: Request) -> None: + def test_json_body_base64_json(self, sample_request: HTTPRequest) -> None: sample_request._is_base64_encoded = True # pylint: disable=protected-access sample_request._body = b"eyJ4IjogImFiY3gifQ==" # pylint: disable=protected-access assert sample_request.json_body == {"x": "abcx"} assert sample_request._json_body == {"x": "abcx"} # pylint: disable=protected-access - def test_json_body_bad_json(self, sample_request: Request) -> None: + def test_json_body_bad_json(self, sample_request: HTTPRequest) -> None: sample_request._json_body = None # pylint: disable=protected-access sample_request._body = '{"x": t4}' # pylint: disable=protected-access with pytest.raises(BadRequestError): sample_request.json_body # pylint: disable=pointless-statement - def test_json_body_bad_content_type(self, sample_request: Request) -> None: + def test_json_body_bad_content_type(self, sample_request: HTTPRequest) -> None: sample_request.headers = CIMultiDict({"Content-Type": "application/dzejson"}) with pytest.raises(BadRequestError) as err: sample_request.json_body # pylint: disable=pointless-statement - assert err.value.message.startswith("Content-Type header is missing or wrong") + assert err.value.message.startswith("Content-Type header is missing or wrong") - def test_json_body_none_when_no_content_type(self, sample_request: Request) -> None: + def test_json_body_none_when_no_content_type(self, sample_request: HTTPRequest) -> None: sample_request.headers = CIMultiDict({}) assert sample_request.json_body is None - def test_json_body_none_as_body(self, sample_request: Request) -> None: + def test_json_body_none_as_body(self, sample_request: HTTPRequest) -> None: sample_request._body = None # type: ignore # pylint: disable=protected-access assert sample_request.json_body is None - -class TestRequestToDict: - def test_to_dict(self, sample_request_with_user: Request) -> None: + def test_to_dict(self, sample_request_with_user: HTTPRequest) -> None: assert sample_request_with_user.to_dict() == { "context": {}, "headers": {"Content-Type": ContentType.JSON}, diff --git a/tests/test_websocket/test_request.py b/tests/test_websocket/test_request.py new file mode 100644 index 0000000..7c5dc2d --- /dev/null +++ b/tests/test_websocket/test_request.py @@ -0,0 +1,119 @@ +import pytest + +from lbz.exceptions import BadRequestError +from lbz.websocket import ActionType, WebSocketRequest + + +class TestWebSocketRequest: + @pytest.mark.parametrize( + "action_type, expected_result", + [ + (ActionType.CONNECT, True), + (ActionType.DISCONNECT, False), + (ActionType.MESSAGE, False), + ], + ) + def test__is_connection_request__checks_if_request_was_a_connection_request( + self, action_type: str, expected_result: bool + ) -> None: + request = WebSocketRequest( + body="", + request_details={ + "routeKey": "$connect", + "eventType": action_type, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + + assert request.is_connection_request() == expected_result + + @pytest.mark.parametrize( + "action_type, expected_result", + [ + (ActionType.CONNECT, False), + (ActionType.DISCONNECT, True), + (ActionType.MESSAGE, False), + ], + ) + def test__is_disconnection_request__checks_if_request_was_a_disconnection_request( + self, action_type: str, expected_result: bool + ) -> None: + request = WebSocketRequest( + body="", + request_details={ + "routeKey": "$connect", + "eventType": action_type, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + + assert request.is_disconnection_request() == expected_result + + def test__json_body__returns_dict_when_body_is_a_dict(self) -> None: + request = WebSocketRequest( + body={"x": "t1"}, + request_details={ + "routeKey": "send_message", + "eventType": ActionType.MESSAGE, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + assert request.json_body == {"x": "t1"} + + def test__json_body___returns_dict_when_body_is_a_str(self) -> None: + request = WebSocketRequest( + body='{"x": "t1"}', + request_details={ + "routeKey": "send_message", + "eventType": ActionType.MESSAGE, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + assert request.json_body == {"x": "t1"} + + def test__json_body___handles_none_as_none(self) -> None: + request = WebSocketRequest( + body=None, # type: ignore + request_details={ + "routeKey": "send_message", + "eventType": ActionType.MESSAGE, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + assert request.json_body is None + + def test__json_body___malformated_json(self) -> None: + request = WebSocketRequest( + body="[]: []", + request_details={ + "routeKey": "send_message", + "eventType": ActionType.MESSAGE, + "connectionId": "aaa123", + "domainName": "xxx.com", + "stage": "prod", + }, + context={}, + is_base64_encoded=False, + ) + with pytest.raises(BadRequestError): + request.json_body # pylint: disable=pointless-statement diff --git a/version b/version index 05e8a45..faef31a 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.6.6 +0.7.0