From be9fab28344bc9c82cefd78d7e2e07fe695b099d Mon Sep 17 00:00:00 2001 From: Erik Larsson Date: Sun, 6 Nov 2022 23:55:51 +0100 Subject: [PATCH] ESAPI: flush handles on close or when handle no longer is referenced Signed-off-by: Erik Larsson --- src/tpm2_pytss/ESAPI.py | 103 +++++++++++++++++++++++++++++------ src/tpm2_pytss/constants.py | 38 +++++++++++++ test/test_esapi.py | 105 ++++++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+), 16 deletions(-) diff --git a/src/tpm2_pytss/ESAPI.py b/src/tpm2_pytss/ESAPI.py index 59d588a2..8d142595 100644 --- a/src/tpm2_pytss/ESAPI.py +++ b/src/tpm2_pytss/ESAPI.py @@ -12,6 +12,7 @@ from .TCTILdr import TCTILdr from typing import List, Optional, Tuple, Union +import weakref # Work around this FAPI dependency if FAPI is not present with the constant value _fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0") @@ -84,6 +85,8 @@ class ESAPI: tcti (Union[TCTI, str]): The TCTI context used to connect to the TPM (may be None). This is established using TCTILdr or a tpm2-tools style --tcti string in the format of : where : is optional. Defaults to None. + flush_on_close (bool): Flush transient/session handles on close or when the handle is no longer referenced. + Defaults to True. Returns: An instance of the ESAPI class. @@ -108,7 +111,9 @@ class ESAPI: C Function: Esys_Initialize """ - def __init__(self, tcti: Union[TCTI, str, None] = None): + def __init__( + self, tcti: Union[TCTI, str, None] = None, flush_on_close: bool = True + ): if not isinstance(tcti, (TCTI, type(None), str)): raise TypeError( @@ -129,6 +134,9 @@ def __init__(self, tcti: Union[TCTI, str, None] = None): _chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL)) self._ctx = self._ctx_pp[0] + self._flush_on_close = flush_on_close + self._ref_handles = dict() + def __enter__(self): return self @@ -149,6 +157,11 @@ def close(self) -> None: C Function: Esys_Finalize """ + for handle, v in [(x, y) for x, y in self._ref_handles.items()]: + tracked, _ = v + if not tracked: + continue + self.flush_context(handle) if self._ctx_pp: lib.Esys_Finalize(self._ctx_pp) self._ctx = ffi.NULL @@ -223,6 +236,7 @@ def tr_from_tpmpublic( self._ctx, handle, session1, session2, session3, obj, ) ) + return ESYS_TR(obj[0]) def tr_close(self, esys_handle: ESYS_TR) -> None: @@ -565,7 +579,7 @@ def start_auth_session( ) ) - return ESYS_TR(session_handle[0]) + return self._handle_with_ref(ESYS_TR(session_handle[0])) def trsess_get_attributes(self, session: ESYS_TR) -> TPMA_SESSION: """Get session attributes. @@ -624,6 +638,14 @@ def trsess_set_attributes( _chkrc(lib.Esys_TRSess_SetAttributes(self._ctx, session, attributes, mask)) + if attributes & mask & TPMA_SESSION.CONTINUESESSION: + self._set_tracking_ref(session, True) + elif ( + mask & TPMA_SESSION.CONTINUESESSION + and not attributes & TPMA_SESSION.CONTINUESESSION + ): + self._set_tracking_ref(session, False) + def trsess_get_nonce_tpm(self, session: ESYS_TR) -> TPM2B_NONCE: """Retrieve the TPM nonce of an Esys_TR session object. @@ -869,7 +891,7 @@ def load( ) ) - return ESYS_TR(object_handle[0]) + return self._handle_with_ref(ESYS_TR(object_handle[0])) def load_external( self, @@ -935,7 +957,7 @@ def load_external( ) ) - return ESYS_TR(object_handle[0]) + return self._handle_with_ref(ESYS_TR(object_handle[0])) def read_public( self, @@ -1310,7 +1332,7 @@ def create_loaded( ) return ( - ESYS_TR(object_handle[0]), + self._handle_with_ref(ESYS_TR(object_handle[0])), TPM2B_PRIVATE(_get_dptr(out_private, lib.Esys_Free)), TPM2B_PUBLIC(_get_dptr(out_public, lib.Esys_Free)), ) @@ -2354,7 +2376,7 @@ def hmac_start( ) ) - return ESYS_TR(sequence_handle[0]) + return self._handle_with_ref(ESYS_TR(sequence_handle[0])) def hash_sequence_start( self, @@ -2414,7 +2436,7 @@ def hash_sequence_start( ) ) - return ESYS_TR(sequence_handle[0]) + return self._handle_with_ref(ESYS_TR(sequence_handle[0])) def sequence_update( self, @@ -2524,6 +2546,7 @@ def sequence_complete( ) ) + self._ref_handles.pop(sequence_handle, None) return ( TPM2B_DIGEST(_get_dptr(result, lib.Esys_Free)), TPMT_TK_HASHCHECK(_get_dptr(validation, lib.Esys_Free)), @@ -2587,6 +2610,8 @@ def event_sequence_complete( results, ) ) + + self._ref_handles.pop(sequence_handle, None) return TPML_DIGEST_VALUES(_get_dptr(results, lib.Esys_Free)) def certify( @@ -4944,7 +4969,7 @@ def create_primary( ) return ( - ESYS_TR(object_handle[0]), + self._handle_with_ref(ESYS_TR(object_handle[0])), TPM2B_PUBLIC(_cdata=_get_dptr(out_public, lib.Esys_Free)), TPM2B_CREATION_DATA(_cdata=_get_dptr(creation_data, lib.Esys_Free)), TPM2B_DIGEST(_cdata=_get_dptr(creation_hash, lib.Esys_Free)), @@ -5675,6 +5700,7 @@ def context_save(self, save_handle: ESYS_TR) -> TPMS_CONTEXT: _check_handle_type(save_handle, "save_handle") context = ffi.new("TPMS_CONTEXT **") _chkrc(lib.Esys_ContextSave(self._ctx, save_handle, context)) + self._ref_handles.pop(save_handle, None) return TPMS_CONTEXT(_get_dptr(context, lib.Esys_Free)) def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR: @@ -5704,7 +5730,15 @@ def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR: loaded_handle = ffi.new("ESYS_TR *") _chkrc(lib.Esys_ContextLoad(self._ctx, context_cdata, loaded_handle)) - return ESYS_TR(loaded_handle[0]) + handle = ESYS_TR(loaded_handle[0]) + if context.savedHandle.type == TPM2_HT.TRANSIENT: + handle = self._handle_with_ref(handle) + elif ( + context.savedHandle.type in (TPM2_HT.HMAC_SESSION, TPM2_HT.POLICY_SESSION) + and self.trsess_get_attributes(handle) & TPMA_SESSION.CONTINUESESSION + ): + handle = self._handle_with_ref(handle) + return handle def flush_context(self, flush_handle: ESYS_TR) -> None: """Invoke the TPM2_FlushContext command. @@ -5725,6 +5759,7 @@ def flush_context(self, flush_handle: ESYS_TR) -> None: TPM Command: TPM2_FlushContext """ + self._ref_handles.pop(flush_handle, None) _check_handle_type(flush_handle, "flush_handle") _chkrc(lib.Esys_FlushContext(self._ctx, flush_handle)) @@ -6807,20 +6842,17 @@ def load_blob( Returns: ESYS_TR: The ESAPI handle to the loaded object. """ - esys_handle = ffi.new("ESYS_TR *") if type_ == FAPI_ESYSBLOB.CONTEXTLOAD: - offs = ffi.new("size_t *", 0) - key_ctx = ffi.new("TPMS_CONTEXT *") - _chkrc(lib.Tss2_MU_TPMS_CONTEXT_Unmarshal(data, len(data), offs, key_ctx)) - _chkrc(lib.Esys_ContextLoad(self._ctx, key_ctx, esys_handle)) + context, _ = TPMS_CONTEXT.unmarshal(data) + esys_handle = self.context_load(context) elif type_ == FAPI_ESYSBLOB.DESERIALIZE: - _chkrc(lib.Esys_TR_Deserialize(self._ctx, data, len(data), esys_handle)) + esys_handle = self.tr_deserialize(data) else: raise ValueError( f"Expected type_ to be FAPI_ESYSBLOB.CONTEXTLOAD or FAPI_ESYSBLOB.DESERIALIZE, got {type_}" ) - return ESYS_TR(esys_handle[0]) + return esys_handle def tr_serialize(self, esys_handle: ESYS_TR) -> bytes: """Serialization of an ESYS_TR into a byte buffer. @@ -6892,6 +6924,45 @@ def tr_deserialize(self, buffer: bytes) -> ESYS_TR: return ESYS_TR(esys_handle[0]) + def _incr_handle_ref(self, handle): + if handle not in self._ref_handles: + return + tracked, count = self._ref_handles[handle] + count += 1 + self._ref_handles[handle] = (tracked, count) + + def _decr_handle_ref(self, handle): + if handle not in self._ref_handles: + return + tracked, count = self._ref_handles[handle] + count -= 1 + if count <= 0 and tracked: + self.flush_context(handle) + else: + self._ref_handles[handle] = (tracked, count) + + def _set_tracking_ref(self, handle, tracked): + if handle not in self._ref_handles: + return + _, count = self._ref_handles[handle] + self._ref_handles[handle] = (tracked, count) + + def _handle_with_ref(self, handle: ESYS_TR) -> ESYS_TR: + """Set up weak references between ESAPI context and handle. + + If ectx._flush_on_close is True create weak references so we can flush on close/unref. + """ + if not self._flush_on_close: + return handle + handle._ectx_ref = weakref.ref(self) + # Each handle has a bool which is True if it's tracked or not and a reference counter. + # The reason for having a tracked/untracked bool is due to the TPM flushing sessions + # after use unless TPMA_SESSION.CONTINUESESSION is set and as it can be toggled on or off + # during runtime we need to keep the reference counter alive for the handle. + self._ref_handles[handle] = (True, 0) + handle._incr_ref() + return handle + @staticmethod def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]: """Fixup ESYS_TR values to TPM2_RH constants to work around tpm2-tss API change in 3.0.0. diff --git a/src/tpm2_pytss/constants.py b/src/tpm2_pytss/constants.py index 01a2b844..c2766460 100644 --- a/src/tpm2_pytss/constants.py +++ b/src/tpm2_pytss/constants.py @@ -392,6 +392,13 @@ class ESYS_TR(TPM_FRIENDLY_INT): RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV + _ectx_ref = None + + def __init__(self, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], self.__class__): + self._ectx_ref = args[0]._ectx_ref + self._incr_ref() + def serialize(self, ectx: "ESAPI") -> bytes: """Same as see tpm2_pytss.ESAPI.tr_serialize @@ -436,6 +443,37 @@ def close(self, ectx: "ESAPI"): """ return ectx.tr_close(self) + def __del__(self): + """Flush handle on deletion. + + If the ESYS handle has reference to an ESAPI instance flush the handle + when the there is no more references to the handle. + """ + self._decr_ref() + + def __enter__(self): + pass + + def __exit__(self, _type, value, traceback) -> None: + self._decr_ref() + + def _incr_ref(self): + if self._ectx_ref is None: + return + ectx = self._ectx_ref() + if ectx is None: + return + ectx._incr_handle_ref(self) + + def _decr_ref(self): + ectx_ref = getattr(self, "_ectx_ref", None) + if ectx_ref is None: + return + ectx = ectx_ref() + if ectx is None or not ectx._ctx: + return + ectx._decr_handle_ref(self) + @TPM_FRIENDLY_INT._fix_const_type class TPM2_RH(TPM_FRIENDLY_INT): diff --git a/test/test_esapi.py b/test/test_esapi.py index ea8cfc83..f7ac9892 100644 --- a/test/test_esapi.py +++ b/test/test_esapi.py @@ -4798,6 +4798,111 @@ def test_trsess_get_attributes(self): attrs = self.ectx.trsess_get_attributes(session) self.assertEqual(attrs, TPMA_SESSION.AUDIT | TPMA_SESSION.AUDITEXCLUSIVE) + def test_flush_on_close_exit(self): + sym = TPMT_SYM_DEF(algorithm=TPM2_ALG.NULL,) + + session = self.ectx.start_auth_session( + tpm_key=ESYS_TR.NONE, + bind=ESYS_TR.NONE, + session_type=TPM2_SE.HMAC, + symmetric=sym, + auth_hash=TPM2_ALG.SHA256, + ) + + handles = list() + more = True + while more: + more, data = self.ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 1) + + with session as s: + pass + + handles = list() + more = True + while more: + more, data = self.ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 0) + + def test_flush_on_close_exit_with_copy(self): + sym = TPMT_SYM_DEF(algorithm=TPM2_ALG.NULL,) + + session = self.ectx.start_auth_session( + tpm_key=ESYS_TR.NONE, + bind=ESYS_TR.NONE, + session_type=TPM2_SE.HMAC, + symmetric=sym, + auth_hash=TPM2_ALG.SHA256, + ) + + handles = list() + more = True + while more: + more, data = self.ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 1) + + session_copy = ESYS_TR(session) + with session as s: + pass + + handles = list() + more = True + while more: + more, data = self.ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 1) + + del session_copy + + def test_flush_on_close_off(self): + tcti = self.ectx.get_tcti() + + ectx = ESAPI(tcti, flush_on_close=False) + + sym = TPMT_SYM_DEF(algorithm=TPM2_ALG.NULL,) + + session = ectx.start_auth_session( + tpm_key=ESYS_TR.NONE, + bind=ESYS_TR.NONE, + session_type=TPM2_SE.HMAC, + symmetric=sym, + auth_hash=TPM2_ALG.SHA256, + ) + + handles = list() + more = True + while more: + more, data = ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 1) + + with session as s: + pass + + handles = list() + more = True + while more: + more, data = ectx.get_capability( + TPM2_CAP.HANDLES, TPM2_HC.HMAC_SESSION_FIRST + ) + handles += list(data.data.handles) + self.assertEqual(len(handles), 1) + + ectx.close() + if __name__ == "__main__": unittest.main()