Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX
Browse files Browse the repository at this point in the history
This patch mainly adds some tests for TDX.
And it refactors some corresponding code accordingly.

Signed-off-by: zhongjie <zhongjie.shi@intel.com>
  • Loading branch information
intelzhongjie committed Jan 16, 2024
1 parent 904333a commit 3615022
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 79 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
# Set PYTHONDONTWRITEBYTECODE and --no-cacheprovider to prevent
# generated some intermediate files by root. Othwerwise, these
# files will fail the action/checkout in the next round of running
# due to the permission issue.
sudo su -c "source setupenv.sh && \
pushd vmsdk/python/tests && \
export PYTHONDONTWRITEBYTECODE=1 && \
python3 -m pytest -p no:cacheprovider -v test_sdk.py && \
popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
77 changes: 77 additions & 0 deletions vmsdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Local conftest.py containing directory-specific hook implementations."""

import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.cvm import ConfidentialVM
from cctrusted_vm.sdk import CCTrustedVmSdk
import tdx_check

cnf_default_alg = {
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384
}
"""Configurations of default algorithm.
The configurations could be different for different confidential VMs.
e.g. TDX use sha384 as the default.
"""

cnf_measurement_cnt = {
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT
}
"""Configurations of measurement count.
The configurations could be different for different confidential VMs.
"""

cnf_measurement_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs
}
"""Configurations of measurement check functions.
The configurations could be different for different confidential VMs.
"""

cnf_quote_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs
}
"""Configurations of quote check functions.
The configurations could be different for different confidential VMs.
"""

@pytest.fixture(scope="module")
def vm_sdk():
"""Get VMSDK instance."""
return CCTrustedVmSdk.inst()

@pytest.fixture(scope="module")
def cc_type():
"""Get the type of current confidential VM."""
return ConfidentialVM.detect_cc_type()

@pytest.fixture(scope="module")
def default_alg_id(cc_type):
"""Get default algorithm."""
return cnf_default_alg[cc_type]

@pytest.fixture(scope="module")
def measurement_count(cc_type):
"""Get measurement count."""
return cnf_measurement_cnt[cc_type]

def default_check():
"""Default check."""
assert True

@pytest.fixture(scope="module")
def check_measurement(cc_type):
"""Get measurement check."""
checker = cnf_measurement_check[cc_type]
if checker is not None:
return checker
return default_check

@pytest.fixture(scope="module")
def check_quote(cc_type):
"""Get measurement check."""
checker = cnf_quote_check[cc_type]
if checker is not None:
return checker
return default_check
77 changes: 77 additions & 0 deletions vmsdk/python/tests/tdx_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""TDX specific test."""

from hashlib import sha384
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

def tdx_check_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

def tdx_check_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"
Loading

0 comments on commit 3615022

Please sign in to comment.