Skip to content

Commit

Permalink
ESAPI: flush handles on close or when handle no longer is referenced
Browse files Browse the repository at this point in the history
This commit adds the following functionality:
* Flush transient/session handles when the ESAPI context is closed
  if flush_on_close is True (default is True)
* Flush transient/session handles when the destructor is called on
  a ESYS_TR if flush_on_close is True

Signed-off-by: Erik Larsson <who+github@cnackers.org>
  • Loading branch information
whooo authored and William Roberts committed Nov 22, 2022
1 parent 01375a1 commit 6806d4c
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 17 deletions.
108 changes: 92 additions & 16 deletions src/tpm2_pytss/ESAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .TCTILdr import TCTILdr

from typing import List, Optional, Tuple, Union
import weakref

# Work around this FAPI dependency if FAPI is not present with the constant value
_fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0")
Expand Down Expand Up @@ -84,6 +85,8 @@ class ESAPI:
tcti (Union[TCTI, str]): The TCTI context used to connect to the TPM (may be None).
This is established using TCTILdr or a tpm2-tools style --tcti string in the format
of <tcti-name>:<tcti-conf> where :<tcti-conf> is optional. Defaults to None.
flush_on_close (bool): Flush transient/session handles on close or when the handle is no longer referenced.
Defaults to True.
Returns:
An instance of the ESAPI class.
Expand All @@ -108,7 +111,9 @@ class ESAPI:
C Function: Esys_Initialize
"""

def __init__(self, tcti: Union[TCTI, str, None] = None):
def __init__(
self, tcti: Union[TCTI, str, None] = None, flush_on_close: bool = True
):

if not isinstance(tcti, (TCTI, type(None), str)):
raise TypeError(
Expand All @@ -129,6 +134,9 @@ def __init__(self, tcti: Union[TCTI, str, None] = None):
_chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL))
self._ctx = self._ctx_pp[0]

self._flush_on_close = flush_on_close
self._ref_handles = dict()

def __enter__(self):
return self

Expand All @@ -149,6 +157,12 @@ def close(self) -> None:
C Function: Esys_Finalize
"""
if self._flush_on_close:
for handle, v in [(x, y) for x, y in self._ref_handles.items()]:
tracked, _ = v
if not tracked:
continue
self.flush_context(handle)
if self._ctx_pp:
lib.Esys_Finalize(self._ctx_pp)
self._ctx = ffi.NULL
Expand Down Expand Up @@ -223,6 +237,7 @@ def tr_from_tpmpublic(
self._ctx, handle, session1, session2, session3, obj,
)
)

return ESYS_TR(obj[0])

def tr_close(self, esys_handle: ESYS_TR) -> None:
Expand Down Expand Up @@ -565,7 +580,7 @@ def start_auth_session(
)
)

return ESYS_TR(session_handle[0])
return self._handle_with_ref(ESYS_TR(session_handle[0]))

def trsess_get_attributes(self, session: ESYS_TR) -> TPMA_SESSION:
"""Get session attributes.
Expand Down Expand Up @@ -624,6 +639,14 @@ def trsess_set_attributes(

_chkrc(lib.Esys_TRSess_SetAttributes(self._ctx, session, attributes, mask))

if attributes & mask & TPMA_SESSION.CONTINUESESSION:
self._set_tracking_ref(session, True)
elif (
mask & TPMA_SESSION.CONTINUESESSION
and not attributes & TPMA_SESSION.CONTINUESESSION
):
self._set_tracking_ref(session, False)

def trsess_get_nonce_tpm(self, session: ESYS_TR) -> TPM2B_NONCE:
"""Retrieve the TPM nonce of an Esys_TR session object.
Expand Down Expand Up @@ -869,7 +892,7 @@ def load(
)
)

return ESYS_TR(object_handle[0])
return self._handle_with_ref(ESYS_TR(object_handle[0]))

def load_external(
self,
Expand Down Expand Up @@ -935,7 +958,7 @@ def load_external(
)
)

return ESYS_TR(object_handle[0])
return self._handle_with_ref(ESYS_TR(object_handle[0]))

def read_public(
self,
Expand Down Expand Up @@ -1310,7 +1333,7 @@ def create_loaded(
)

return (
ESYS_TR(object_handle[0]),
self._handle_with_ref(ESYS_TR(object_handle[0])),
TPM2B_PRIVATE(_get_dptr(out_private, lib.Esys_Free)),
TPM2B_PUBLIC(_get_dptr(out_public, lib.Esys_Free)),
)
Expand Down Expand Up @@ -2354,7 +2377,7 @@ def hmac_start(
)
)

return ESYS_TR(sequence_handle[0])
return self._handle_with_ref(ESYS_TR(sequence_handle[0]))

def hash_sequence_start(
self,
Expand Down Expand Up @@ -2414,7 +2437,7 @@ def hash_sequence_start(
)
)

return ESYS_TR(sequence_handle[0])
return self._handle_with_ref(ESYS_TR(sequence_handle[0]))

def sequence_update(
self,
Expand Down Expand Up @@ -2524,6 +2547,7 @@ def sequence_complete(
)
)

self._ref_handles.pop(sequence_handle, None)
return (
TPM2B_DIGEST(_get_dptr(result, lib.Esys_Free)),
TPMT_TK_HASHCHECK(_get_dptr(validation, lib.Esys_Free)),
Expand Down Expand Up @@ -2587,6 +2611,8 @@ def event_sequence_complete(
results,
)
)

self._ref_handles.pop(sequence_handle, None)
return TPML_DIGEST_VALUES(_get_dptr(results, lib.Esys_Free))

def certify(
Expand Down Expand Up @@ -4944,7 +4970,7 @@ def create_primary(
)

return (
ESYS_TR(object_handle[0]),
self._handle_with_ref(ESYS_TR(object_handle[0])),
TPM2B_PUBLIC(_cdata=_get_dptr(out_public, lib.Esys_Free)),
TPM2B_CREATION_DATA(_cdata=_get_dptr(creation_data, lib.Esys_Free)),
TPM2B_DIGEST(_cdata=_get_dptr(creation_hash, lib.Esys_Free)),
Expand Down Expand Up @@ -5675,6 +5701,7 @@ def context_save(self, save_handle: ESYS_TR) -> TPMS_CONTEXT:
_check_handle_type(save_handle, "save_handle")
context = ffi.new("TPMS_CONTEXT **")
_chkrc(lib.Esys_ContextSave(self._ctx, save_handle, context))
self._ref_handles.pop(save_handle, None)
return TPMS_CONTEXT(_get_dptr(context, lib.Esys_Free))

def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR:
Expand Down Expand Up @@ -5704,7 +5731,15 @@ def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR:
loaded_handle = ffi.new("ESYS_TR *")
_chkrc(lib.Esys_ContextLoad(self._ctx, context_cdata, loaded_handle))

return ESYS_TR(loaded_handle[0])
handle = ESYS_TR(loaded_handle[0])
if context.savedHandle.type == TPM2_HT.TRANSIENT:
handle = self._handle_with_ref(handle)
elif (
context.savedHandle.type in (TPM2_HT.HMAC_SESSION, TPM2_HT.POLICY_SESSION)
and self.trsess_get_attributes(handle) & TPMA_SESSION.CONTINUESESSION
):
handle = self._handle_with_ref(handle)
return handle

def flush_context(self, flush_handle: ESYS_TR) -> None:
"""Invoke the TPM2_FlushContext command.
Expand All @@ -5725,6 +5760,7 @@ def flush_context(self, flush_handle: ESYS_TR) -> None:
TPM Command: TPM2_FlushContext
"""

self._ref_handles.pop(flush_handle, None)
_check_handle_type(flush_handle, "flush_handle")
_chkrc(lib.Esys_FlushContext(self._ctx, flush_handle))

Expand Down Expand Up @@ -6807,20 +6843,17 @@ def load_blob(
Returns:
ESYS_TR: The ESAPI handle to the loaded object.
"""
esys_handle = ffi.new("ESYS_TR *")
if type_ == FAPI_ESYSBLOB.CONTEXTLOAD:
offs = ffi.new("size_t *", 0)
key_ctx = ffi.new("TPMS_CONTEXT *")
_chkrc(lib.Tss2_MU_TPMS_CONTEXT_Unmarshal(data, len(data), offs, key_ctx))
_chkrc(lib.Esys_ContextLoad(self._ctx, key_ctx, esys_handle))
context, _ = TPMS_CONTEXT.unmarshal(data)
esys_handle = self.context_load(context)
elif type_ == FAPI_ESYSBLOB.DESERIALIZE:
_chkrc(lib.Esys_TR_Deserialize(self._ctx, data, len(data), esys_handle))
esys_handle = self.tr_deserialize(data)
else:
raise ValueError(
f"Expected type_ to be FAPI_ESYSBLOB.CONTEXTLOAD or FAPI_ESYSBLOB.DESERIALIZE, got {type_}"
)

return ESYS_TR(esys_handle[0])
return esys_handle

def tr_serialize(self, esys_handle: ESYS_TR) -> bytes:
"""Serialization of an ESYS_TR into a byte buffer.
Expand Down Expand Up @@ -6892,6 +6925,49 @@ def tr_deserialize(self, buffer: bytes) -> ESYS_TR:

return ESYS_TR(esys_handle[0])

def _incr_handle_ref(self, handle):
if handle not in self._ref_handles:
return
tracked, count = self._ref_handles[handle]
count += 1
self._ref_handles[handle] = (tracked, count)

def _decr_handle_ref(self, handle):
if handle not in self._ref_handles:
return
tracked, count = self._ref_handles[handle]
count -= 1
if count <= 0 and tracked and self._flush_on_close:
self.flush_context(handle)
elif count <= 0 and not tracked:
# If a session is not tracked we don't know if it's flushed by the TPM or not.
# So just drop it from _ref_handles is the counter goes down to zero.
self._ref_handles.pop(handle, None)
else:
self._ref_handles[handle] = (tracked, count)

def _set_tracking_ref(self, handle, tracked):
if handle not in self._ref_handles:
return
_, count = self._ref_handles[handle]
self._ref_handles[handle] = (tracked, count)

def _handle_with_ref(self, handle: ESYS_TR) -> ESYS_TR:
"""Set up weak references between ESAPI context and handle.
If ectx._flush_on_close is True create weak references so we can flush on close/unref.
"""
if not self._flush_on_close:
return handle
handle._ectx_ref = weakref.ref(self)
# Each handle has a bool which is True if it's tracked or not and a reference counter.
# The reason for having a tracked/untracked bool is due to the TPM flushing sessions
# after use unless TPMA_SESSION.CONTINUESESSION is set and as it can be toggled on or off
# during runtime we need to keep the reference counter alive for the handle.
self._ref_handles[handle] = (True, 0)
handle._incr_ref()
return handle

@staticmethod
def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]:
"""Fixup ESYS_TR values to TPM2_RH constants to work around tpm2-tss API change in 3.0.0.
Expand Down
35 changes: 35 additions & 0 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ class ESYS_TR(TPM_FRIENDLY_INT):
RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM
RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV

_ectx_ref = None

def __init__(self, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], self.__class__):
self._ectx_ref = args[0]._ectx_ref
self._incr_ref()

def serialize(self, ectx: "ESAPI") -> bytes:
"""Same as see tpm2_pytss.ESAPI.tr_serialize
Expand Down Expand Up @@ -436,6 +443,34 @@ def close(self, ectx: "ESAPI"):
"""
return ectx.tr_close(self)

def __del__(self):
"""Flush handle on deletion.
If the ESYS handle has reference to an ESAPI instance flush the handle
when the there is no more references to the handle.
"""
self._decr_ref()

def _get_ectx(self):
if self._ectx_ref is None:
return None
ectx = self._ectx_ref()
if ectx is None or not ectx._ctx:
return None
return ectx

def _incr_ref(self):
ectx = self._get_ectx()
if ectx is None:
return
ectx._incr_handle_ref(self)

def _decr_ref(self):
ectx = self._get_ectx()
if ectx is None:
return
ectx._decr_handle_ref(self)


@TPM_FRIENDLY_INT._fix_const_type
class TPM2_RH(TPM_FRIENDLY_INT):
Expand Down
Loading

0 comments on commit 6806d4c

Please sign in to comment.