diff --git a/polaris/tests/test_utils.py b/polaris/tests/test_utils.py index 2af62284..92d829dc 100644 --- a/polaris/tests/test_utils.py +++ b/polaris/tests/test_utils.py @@ -1,3 +1,5 @@ +import base64 +import json import pytest from unittest.mock import patch, Mock from secrets import token_bytes @@ -99,33 +101,56 @@ def test_memo_str_hash_memo(): Transaction.MEMO_TYPES.hash, ) +def test_compute_callback_signature(): + callback_url = "https://testanchor.stellar.org/sep24/callbacks" + callback_body = json.dumps({"a": "b"}) + signature_header = utils.compute_callback_signature(callback_url, callback_body) + t, s = signature_header.split(", ") + assert t + assert s + t_key, t_val = t.split("=") + assert t_key == "t" + timestamp = int(t_val) + s_key, s_val = s.split("=", 1) + assert s_key == "s" + signature_payload = f"{timestamp}.testanchor.stellar.org.{callback_body}" + Keypair.from_public_key(settings.SIGNING_KEY).verify(signature_payload.encode(), base64.b64decode(s_val)) @patch(f"{test_module}.post") @patch(f"{test_module}.TransactionSerializer") +@patch(f"{test_module}.json", Mock(dumps=Mock(return_value="{}"))) +@patch(f"{test_module}.compute_callback_signature", Mock(return_value="test")) def test_make_on_change_callback_success(mock_serializer, mock_post): mock_transaction = Mock(on_change_callback="test") utils.make_on_change_callback(mock_transaction) mock_serializer.assert_called_once_with(mock_transaction) mock_post.assert_called_once_with( url=mock_transaction.on_change_callback, - json=mock_serializer(mock_transaction).data, + json={"transaction": mock_serializer(mock_transaction).data}, timeout=settings.CALLBACK_REQUEST_TIMEOUT, + headers={ + "Signature": "test" + } ) @patch(f"{test_module}.post") @patch(f"{test_module}.TransactionSerializer") +@patch(f"{test_module}.json", Mock(dumps=Mock(return_value="{}"))) +@patch(f"{test_module}.compute_callback_signature", Mock(return_value="test")) def test_make_on_change_callback_success_with_timeout(mock_serializer, mock_post): mock_transaction = Mock(on_change_callback="test") utils.make_on_change_callback(mock_transaction, timeout=5) mock_serializer.assert_called_once_with(mock_transaction) mock_post.assert_called_once_with( url=mock_transaction.on_change_callback, - json=mock_serializer(mock_transaction).data, + json={"transaction": mock_serializer(mock_transaction).data}, timeout=5, + headers={ + "Signature": "test" + } ) - @patch(f"{test_module}.post") @patch(f"{test_module}.TransactionSerializer") def test_make_on_change_callback_raises_valueerror_for_postmessage( diff --git a/polaris/utils.py b/polaris/utils.py index 6d33347c..0a1472ab 100644 --- a/polaris/utils.py +++ b/polaris/utils.py @@ -1,6 +1,9 @@ """This module defines helpers for various endpoints.""" +import base64 import json import codecs +import time +from urllib.parse import urlparse import uuid from datetime import datetime, timezone from logging import getLogger @@ -270,10 +273,15 @@ def extract_sep9_fields(args): sep9_args[field] = args.get(field) return sep9_args +def compute_callback_signature(callback_url: str, callback_body: str) -> str: + callback_time = int(time.time()) + sig_payload = f"{callback_time}.{urlparse(callback_url).netloc}.{callback_body}" + signature = base64.b64encode(Keypair.from_secret(settings.SIGNING_SEED).sign(sig_payload.encode())).decode() + return f"t={callback_time}, s={signature}" def make_on_change_callback( transaction: Transaction, timeout: Optional[int] = None -) -> RequestsResponse: +) -> Optional[RequestsResponse]: """ Makes a POST request to `transaction.on_change_callback`, a URL provided by the client. The request will time out in @@ -293,13 +301,20 @@ def make_on_change_callback( raise ValueError("invalid or missing on_change_callback") if not timeout: timeout = settings.CALLBACK_REQUEST_TIMEOUT + callback_body = {"transaction": TransactionSerializer(transaction).data} + try: + signature_header_value = compute_callback_signature(transaction.on_change_callback, callback_body) + except ValueError: # + logger.error(f"unable to parse host of transaction.on_change_callback for transaction {transaction.id}") + return None + headers = {"Signature": signature_header_value} return post( url=transaction.on_change_callback, - json=TransactionSerializer(transaction).data, + json=callback_body, timeout=timeout, + headers=headers ) - def maybe_make_callback(transaction: Transaction, timeout: Optional[int] = None): """ Makes the on_change_callback request if present on the transaciton and @@ -315,13 +330,13 @@ def maybe_make_callback(transaction: Transaction, timeout: Optional[int] = None) except RequestException as e: logger.error(f"Callback request raised {e.__class__.__name__}: {str(e)}") else: - if not callback_resp.ok: + if callback_resp and not callback_resp.ok: logger.error(f"Callback request returned {callback_resp.status_code}") async def make_on_change_callback_async( transaction: Transaction, timeout: Optional[int] = None -) -> ClientResponse: +) -> Optional[ClientResponse]: if ( not transaction.on_change_callback or transaction.on_change_callback.lower() == "postmessage" @@ -330,11 +345,19 @@ async def make_on_change_callback_async( if not timeout: timeout = settings.CALLBACK_REQUEST_TIMEOUT timeout_obj = aiohttp.ClientTimeout(total=timeout) + callback_body = {"transaction": TransactionSerializer(transaction).data} + try: + signature_header_value = compute_callback_signature(transaction.on_change_callback, callback_body) + except ValueError: # + logger.error(f"unable to parse host of transaction.on_change_callback for transaction {transaction.id}") + return None + headers = {"Signature": signature_header_value} async with aiohttp.ClientSession(timeout=timeout_obj) as session: return await session.post( url=transaction.on_change_callback, - json=TransactionSerializer(transaction).data, + json=callback_body, timeout=timeout, + headers=headers ) @@ -357,7 +380,7 @@ async def maybe_make_callback_async( except RequestException as e: logger.error(f"Callback request raised {e.__class__.__name__}: {str(e)}") else: - if not callback_resp.ok: + if callback_resp and not callback_resp.ok: logger.error(f"Callback request returned {callback_resp.status}")