Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dev-confidential
Browse files Browse the repository at this point in the history
  • Loading branch information
olethanh committed Jul 4, 2024
2 parents f30e431 + 4a9eabd commit c74e59d
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 81 deletions.
5 changes: 4 additions & 1 deletion src/aleph/vm/controllers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from aleph.vm.controllers.firecracker.snapshots import CompressedDiskVolumeSnapshot
from aleph.vm.network.interfaces import TapInterface
from aleph.vm.utils.logs import make_logs_queue
from aleph.vm.utils.logs import get_past_vm_logs, make_logs_queue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,3 +118,6 @@ def _journal_stdout_name(self) -> str:
@property
def _journal_stderr_name(self) -> str:
return f"vm-{self.vm_hash}-stderr"

def past_logs(self):
yield from get_past_vm_logs(self._journal_stdout_name, self._journal_stderr_name)
4 changes: 0 additions & 4 deletions src/aleph/vm/controllers/qemu/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,3 @@ async def teardown(self):
if self.tap_interface:
await self.tap_interface.delete()
await self.stop_guest_api()

def print_logs(self) -> None:
"""Print logs to our output for debugging"""
queue = self.get_log_queue()
4 changes: 3 additions & 1 deletion src/aleph/vm/orchestrator/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
operate_confidential_measurement,
operate_erase,
operate_expire,
operate_logs,
operate_reboot,
operate_stop,
stream_logs,
Expand Down Expand Up @@ -103,7 +104,8 @@ def setup_webapp():
web.get("/about/config", about_config),
# /control APIs are used to control the VMs and access their logs
web.post("/control/allocation/notify", notify_allocation),
web.get("/control/machine/{ref}/logs", stream_logs),
web.get("/control/machine/{ref}/stream_logs", stream_logs),
web.get("/control/machine/{ref}/logs", operate_logs),
web.post("/control/machine/{ref}/expire", operate_expire),
web.post("/control/machine/{ref}/stop", operate_stop),
web.post("/control/machine/{ref}/erase", operate_erase),
Expand Down
30 changes: 29 additions & 1 deletion src/aleph/vm/orchestrator/views/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def stream_logs(request: web.Request) -> web.StreamResponse:
queue = None
try:
ws = web.WebSocketResponse()
logger.info(f"starting websocket: {request.path}")
await ws.prepare(request)
try:
await authenticate_websocket_for_vm_or_403(execution, vm_hash, ws)
Expand All @@ -80,6 +81,7 @@ async def stream_logs(request: web.Request) -> web.StreamResponse:
while True:
log_type, message = await queue.get()
assert log_type in ("stdout", "stderr")
logger.debug(message)

await ws.send_json({"type": log_type, "message": message})

Expand All @@ -92,15 +94,41 @@ async def stream_logs(request: web.Request) -> web.StreamResponse:
execution.vm.unregister_queue(queue)


@cors_allow_all
@require_jwk_authentication
async def operate_logs(request: web.Request, authenticated_sender: str) -> web.StreamResponse:
"""Logs of a VM (not streaming)"""
vm_hash = get_itemhash_or_400(request.match_info)
pool: VmPool = request.app["vm_pool"]
execution = get_execution_or_404(vm_hash, pool=pool)
if not is_sender_authorized(authenticated_sender, execution.message):
return web.Response(status=403, body="Unauthorized sender")

response = web.StreamResponse()
response.headers["Content-Type"] = "text/plain"
await response.prepare(request)

for entry in execution.vm.past_logs():
msg = f'{entry["__REALTIME_TIMESTAMP"].isoformat()}> {entry["MESSAGE"]}'
await response.write(msg.encode())
await response.write_eof()
return response


async def authenticate_websocket_for_vm_or_403(execution: VmExecution, vm_hash: ItemHash, ws: web.WebSocketResponse):
"""Authenticate a websocket connection.
Web browsers do not allow setting headers in WebSocket requests, so the authentication
relies on the first message sent by the client.
"""
first_message = await ws.receive_json()
try:
first_message = await ws.receive_json()
except TypeError as error:
logging.exception(error)
raise web.HTTPForbidden(body="Invalid auth package")
credentials = first_message["auth"]
authenticated_sender = await authenticate_websocket_message(credentials)

if is_sender_authorized(authenticated_sender, execution.message):
logger.debug(f"Accepted request to access logs by {authenticated_sender} on {vm_hash}")
return True
Expand Down
26 changes: 25 additions & 1 deletion src/aleph/vm/utils/logs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from typing import Callable, TypedDict
from datetime import datetime
from typing import Callable, Generator, TypedDict

from systemd import journal

Expand All @@ -10,6 +11,7 @@
class EntryDict(TypedDict):
SYSLOG_IDENTIFIER: str
MESSAGE: str
__REALTIME_TIMESTAMP: datetime


def make_logs_queue(stdout_identifier, stderr_identifier, skip_past=False) -> tuple[asyncio.Queue, Callable[[], None]]:
Expand Down Expand Up @@ -56,3 +58,25 @@ def do_cancel():
r.close()

return queue, do_cancel


def get_past_vm_logs(stdout_identifier, stderr_identifier) -> Generator[EntryDict, None, None]:
"""Get existing log for the VM identifiers.
@param stdout_identifier: journald identifier for process stdout
@param stderr_identifier: journald identifier for process stderr
@return: an iterator of log entry
Works by creating a journald reader, and using `add_reader` to call a callback when
data is available for reading.
For more information refer to the sd-journal(3) manpage
and systemd.journal module documentation.
"""
r = journal.Reader()
r.add_match(SYSLOG_IDENTIFIER=stdout_identifier)
r.add_match(SYSLOG_IDENTIFIER=stderr_identifier)

r.seek_head()
for entry in r:
yield entry
86 changes: 86 additions & 0 deletions src/aleph/vm/utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import datetime
import json

import eth_account.messages
import pytest
from eth_account.datastructures import SignedMessage
from eth_account.signers.local import LocalAccount
from jwcrypto import jwk
from jwcrypto.jwa import JWA


@pytest.fixture
def patch_datetime_now(monkeypatch):
"""Fixture for patching the datetime.now() and datetime.utcnow() methods
to return a fixed datetime object.
This fixture creates a subclass of `datetime.datetime` called `mydatetime`,
which overrides the `now()` and `utcnow()` class methods to return a fixed
datetime object specified by `FAKE_TIME`.
"""

class MockDateTime(datetime.datetime):
FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55)

@classmethod
def now(cls, tz=None, *args, **kwargs):
return cls.FAKE_TIME.replace(tzinfo=tz)

@classmethod
def utcnow(cls, *args, **kwargs):
return cls.FAKE_TIME

monkeypatch.setattr(datetime, "datetime", MockDateTime)
return MockDateTime


async def generate_signer_and_signed_headers_for_operation(
patch_datetime_now, operation_payload: dict
) -> tuple[LocalAccount, dict]:
"""Generate a temporary eth_account for testing and sign the operation with it"""
account = eth_account.Account()
signer_account = account.create()
key = jwk.JWK.generate(
kty="EC",
crv="P-256",
# key_ops=["verify"],
)
pubkey = {
"pubkey": json.loads(key.export_public()),
"alg": "ECDSA",
"domain": "localhost",
"address": signer_account.address,
"expires": (patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)).isoformat() + "Z",
}
pubkey_payload = json.dumps(pubkey).encode("utf-8").hex()
signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload)
signed_message: SignedMessage = signer_account.sign_message(signable_message)
pubkey_signature = to_0x_hex(signed_message.signature)
pubkey_signature_header = json.dumps(
{
"payload": pubkey_payload,
"signature": pubkey_signature,
}
)
payload_as_bytes = json.dumps(operation_payload).encode("utf-8")

payload_signature = JWA.signing_alg("ES256").sign(key, payload_as_bytes)
headers = {
"X-SignedPubKey": pubkey_signature_header,
"X-SignedOperation": json.dumps(
{
"payload": payload_as_bytes.hex(),
"signature": payload_signature.hex(),
}
),
}
return signer_account, headers


def to_0x_hex(b: bytes) -> str:
"""
Convert the bytes to a 0x-prefixed hex string
"""

# force this for compat between different hexbytes versions which behave differenty
# and conflict with other package don't allow us to have the version we want
return "0x" + bytes.hex(b)
80 changes: 7 additions & 73 deletions tests/supervisor/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import json
from typing import Any

Expand All @@ -8,22 +7,16 @@
from eth_account.datastructures import SignedMessage
from jwcrypto import jwk, jws
from jwcrypto.common import base64url_decode
from jwcrypto.jwa import JWA

from aleph.vm.orchestrator.views.authentication import (
authenticate_jwk,
require_jwk_authentication,
)


def to_0x_hex(b: bytes) -> str:
"""
Convert the bytes to a 0x-prefixed hex string
"""

# force this for compat between different hexbytes versions which behave differenty
# and conflict with other package don't allow us to have the version we want
return "0x" + bytes.hex(b)
from aleph.vm.utils.test_helpers import (
generate_signer_and_signed_headers_for_operation,
patch_datetime_now,
to_0x_hex,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -67,30 +60,6 @@ async def view(request, authenticated_sender):
assert {"error": "Invalid X-SignedPubKey format"} == r


@pytest.fixture
def patch_datetime_now(monkeypatch):
"""Fixture for patching the datetime.now() and datetime.utcnow() methods
to return a fixed datetime object.
This fixture creates a subclass of `datetime.datetime` called `mydatetime`,
which overrides the `now()` and `utcnow()` class methods to return a fixed
datetime object specified by `FAKE_TIME`.
"""

class MockDateTime(datetime.datetime):
FAKE_TIME = datetime.datetime(2010, 12, 25, 17, 5, 55)

@classmethod
def now(cls, tz=None, *args, **kwargs):
return cls.FAKE_TIME.replace(tzinfo=tz)

@classmethod
def utcnow(cls, *args, **kwargs):
return cls.FAKE_TIME

monkeypatch.setattr(datetime, "datetime", MockDateTime)
return MockDateTime


@pytest.mark.asyncio
async def test_require_jwk_authentication_expired(aiohttp_client):
app = web.Application()
Expand Down Expand Up @@ -251,31 +220,8 @@ async def test_require_jwk_authentication_good_key(aiohttp_client, patch_datetim
"""An HTTP request to a view decorated by `@require_jwk_authentication`
auth correctly a temporary key signed by a wallet and an operation signed by that key"""
app = web.Application()

account = eth_account.Account()
signer_account = account.create()
key = jwk.JWK.generate(
kty="EC",
crv="P-256",
# key_ops=["verify"],
)

pubkey = {
"pubkey": json.loads(key.export_public()),
"alg": "ECDSA",
"address": signer_account.address,
"expires": (patch_datetime_now.FAKE_TIME + datetime.timedelta(days=1)).isoformat() + "Z",
}
pubkey_payload = json.dumps(pubkey).encode("utf-8").hex()
signable_message = eth_account.messages.encode_defunct(hexstr=pubkey_payload)
signed_message: SignedMessage = signer_account.sign_message(signable_message)
pubkey_signature = to_0x_hex(signed_message.signature)
pubkey_signature_header = json.dumps(
{
"payload": pubkey_payload,
"signature": pubkey_signature,
}
)
payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/", "domain": "localhost"}
signer_account, headers = await generate_signer_and_signed_headers_for_operation(patch_datetime_now, payload)

@require_jwk_authentication
async def view(request, authenticated_sender):
Expand All @@ -285,18 +231,6 @@ async def view(request, authenticated_sender):
app.router.add_get("", view)
client = await aiohttp_client(app)

payload = {"time": "2010-12-25T17:05:55Z", "method": "GET", "path": "/", "domain": "localhost"}

payload_as_bytes = json.dumps(payload).encode("utf-8")
headers = {"X-SignedPubKey": pubkey_signature_header}
payload_signature = JWA.signing_alg("ES256").sign(key, payload_as_bytes)
headers["X-SignedOperation"] = json.dumps(
{
"payload": payload_as_bytes.hex(),
"signature": payload_signature.hex(),
}
)

resp = await client.get("/", headers=headers)
assert resp.status == 200, await resp.text()

Expand Down
Loading

0 comments on commit c74e59d

Please sign in to comment.