Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto flush #398

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 117 additions & 16 deletions src/tpm2_pytss/ESAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .TCTILdr import TCTILdr

from typing import List, Optional, Tuple, Union
import weakref

# Work around this FAPI dependency if FAPI is not present with the constant value
_fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0")
Expand Down Expand Up @@ -84,6 +85,8 @@ class ESAPI:
tcti (Union[TCTI, str]): The TCTI context used to connect to the TPM (may be None).
This is established using TCTILdr or a tpm2-tools style --tcti string in the format
of <tcti-name>:<tcti-conf> where :<tcti-conf> is optional. Defaults to None.
flush_on_close (bool): Flush transient/session handles on close or when the handle is no longer referenced.
Defaults to True.

Returns:
An instance of the ESAPI class.
Expand All @@ -108,7 +111,9 @@ class ESAPI:
C Function: Esys_Initialize
"""

def __init__(self, tcti: Union[TCTI, str, None] = None):
def __init__(
self, tcti: Union[TCTI, str, None] = None, flush_on_close: bool = True
):

if not isinstance(tcti, (TCTI, type(None), str)):
raise TypeError(
Expand All @@ -129,6 +134,9 @@ def __init__(self, tcti: Union[TCTI, str, None] = None):
_chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL))
self._ctx = self._ctx_pp[0]

self._flush_on_close = flush_on_close
self._ref_handles = dict()

def __enter__(self):
return self

Expand All @@ -149,6 +157,12 @@ def close(self) -> None:

C Function: Esys_Finalize
"""
if self._flush_on_close:
for handle, v in [(x, y) for x, y in self._ref_handles.items()]:
tracked, _ = v
if not tracked:
continue
self.flush_context(handle)
if self._ctx_pp:
lib.Esys_Finalize(self._ctx_pp)
self._ctx = ffi.NULL
Expand Down Expand Up @@ -223,6 +237,7 @@ def tr_from_tpmpublic(
self._ctx, handle, session1, session2, session3, obj,
)
)

return ESYS_TR(obj[0])

def tr_close(self, esys_handle: ESYS_TR) -> None:
Expand Down Expand Up @@ -565,7 +580,32 @@ def start_auth_session(
)
)

return ESYS_TR(session_handle[0])
return self._handle_with_ref(ESYS_TR(session_handle[0]))

def trsess_get_attributes(self, session: ESYS_TR) -> TPMA_SESSION:
"""Get session attributes.

Get a sessions attributes.

Args:
session (ESYS_TR): The session handle.

Raises:
TypeError: If a parameter is not of an expected type.
TSS2_Exception: Any of the various TSS2_RC's the lower layers can return.

Returns:
The attributes as a TPMA_SESSION.

C_Function: Esys_TRSess_GetAttributes
"""
_check_handle_type(session, "session")

attributes = ffi.new("TPMA_SESSION *")

_chkrc(lib.Esys_TRSess_GetAttributes(self._ctx, session, attributes))

return TPMA_SESSION(attributes[0])

def trsess_set_attributes(
self, session: ESYS_TR, attributes: Union[TPMA_SESSION, int], mask: int = 0xFF
Expand Down Expand Up @@ -599,6 +639,14 @@ def trsess_set_attributes(

_chkrc(lib.Esys_TRSess_SetAttributes(self._ctx, session, attributes, mask))

if attributes & mask & TPMA_SESSION.CONTINUESESSION:
self._set_tracking_ref(session, True)
elif (
mask & TPMA_SESSION.CONTINUESESSION
and not attributes & TPMA_SESSION.CONTINUESESSION
):
self._set_tracking_ref(session, False)

def trsess_get_nonce_tpm(self, session: ESYS_TR) -> TPM2B_NONCE:
"""Retrieve the TPM nonce of an Esys_TR session object.

Expand Down Expand Up @@ -844,7 +892,7 @@ def load(
)
)

return ESYS_TR(object_handle[0])
return self._handle_with_ref(ESYS_TR(object_handle[0]))

def load_external(
self,
Expand Down Expand Up @@ -910,7 +958,7 @@ def load_external(
)
)

return ESYS_TR(object_handle[0])
return self._handle_with_ref(ESYS_TR(object_handle[0]))

def read_public(
self,
Expand Down Expand Up @@ -1285,7 +1333,7 @@ def create_loaded(
)

return (
ESYS_TR(object_handle[0]),
self._handle_with_ref(ESYS_TR(object_handle[0])),
TPM2B_PRIVATE(_get_dptr(out_private, lib.Esys_Free)),
TPM2B_PUBLIC(_get_dptr(out_public, lib.Esys_Free)),
)
Expand Down Expand Up @@ -2329,7 +2377,7 @@ def hmac_start(
)
)

return ESYS_TR(sequence_handle[0])
return self._handle_with_ref(ESYS_TR(sequence_handle[0]))

def hash_sequence_start(
self,
Expand Down Expand Up @@ -2389,7 +2437,7 @@ def hash_sequence_start(
)
)

return ESYS_TR(sequence_handle[0])
return self._handle_with_ref(ESYS_TR(sequence_handle[0]))

def sequence_update(
self,
Expand Down Expand Up @@ -2499,6 +2547,7 @@ def sequence_complete(
)
)

self._ref_handles.pop(sequence_handle, None)
return (
TPM2B_DIGEST(_get_dptr(result, lib.Esys_Free)),
TPMT_TK_HASHCHECK(_get_dptr(validation, lib.Esys_Free)),
Expand Down Expand Up @@ -2562,6 +2611,8 @@ def event_sequence_complete(
results,
)
)

self._ref_handles.pop(sequence_handle, None)
return TPML_DIGEST_VALUES(_get_dptr(results, lib.Esys_Free))

def certify(
Expand Down Expand Up @@ -4919,7 +4970,7 @@ def create_primary(
)

return (
ESYS_TR(object_handle[0]),
self._handle_with_ref(ESYS_TR(object_handle[0])),
TPM2B_PUBLIC(_cdata=_get_dptr(out_public, lib.Esys_Free)),
TPM2B_CREATION_DATA(_cdata=_get_dptr(creation_data, lib.Esys_Free)),
TPM2B_DIGEST(_cdata=_get_dptr(creation_hash, lib.Esys_Free)),
Expand Down Expand Up @@ -5650,6 +5701,7 @@ def context_save(self, save_handle: ESYS_TR) -> TPMS_CONTEXT:
_check_handle_type(save_handle, "save_handle")
context = ffi.new("TPMS_CONTEXT **")
_chkrc(lib.Esys_ContextSave(self._ctx, save_handle, context))
self._ref_handles.pop(save_handle, None)
return TPMS_CONTEXT(_get_dptr(context, lib.Esys_Free))

def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR:
Expand Down Expand Up @@ -5679,7 +5731,15 @@ def context_load(self, context: TPMS_CONTEXT) -> ESYS_TR:
loaded_handle = ffi.new("ESYS_TR *")
_chkrc(lib.Esys_ContextLoad(self._ctx, context_cdata, loaded_handle))

return ESYS_TR(loaded_handle[0])
handle = ESYS_TR(loaded_handle[0])
if context.savedHandle.type == TPM2_HT.TRANSIENT:
handle = self._handle_with_ref(handle)
elif (
context.savedHandle.type in (TPM2_HT.HMAC_SESSION, TPM2_HT.POLICY_SESSION)
and self.trsess_get_attributes(handle) & TPMA_SESSION.CONTINUESESSION
):
handle = self._handle_with_ref(handle)
return handle

def flush_context(self, flush_handle: ESYS_TR) -> None:
"""Invoke the TPM2_FlushContext command.
Expand All @@ -5700,6 +5760,7 @@ def flush_context(self, flush_handle: ESYS_TR) -> None:
TPM Command: TPM2_FlushContext
"""

self._ref_handles.pop(flush_handle, None)
_check_handle_type(flush_handle, "flush_handle")
_chkrc(lib.Esys_FlushContext(self._ctx, flush_handle))

Expand Down Expand Up @@ -6782,20 +6843,17 @@ def load_blob(
Returns:
ESYS_TR: The ESAPI handle to the loaded object.
"""
esys_handle = ffi.new("ESYS_TR *")
if type_ == FAPI_ESYSBLOB.CONTEXTLOAD:
offs = ffi.new("size_t *", 0)
key_ctx = ffi.new("TPMS_CONTEXT *")
_chkrc(lib.Tss2_MU_TPMS_CONTEXT_Unmarshal(data, len(data), offs, key_ctx))
_chkrc(lib.Esys_ContextLoad(self._ctx, key_ctx, esys_handle))
context, _ = TPMS_CONTEXT.unmarshal(data)
esys_handle = self.context_load(context)
elif type_ == FAPI_ESYSBLOB.DESERIALIZE:
_chkrc(lib.Esys_TR_Deserialize(self._ctx, data, len(data), esys_handle))
esys_handle = self.tr_deserialize(data)
else:
raise ValueError(
f"Expected type_ to be FAPI_ESYSBLOB.CONTEXTLOAD or FAPI_ESYSBLOB.DESERIALIZE, got {type_}"
)

return ESYS_TR(esys_handle[0])
return esys_handle

def tr_serialize(self, esys_handle: ESYS_TR) -> bytes:
"""Serialization of an ESYS_TR into a byte buffer.
Expand Down Expand Up @@ -6867,6 +6925,49 @@ def tr_deserialize(self, buffer: bytes) -> ESYS_TR:

return ESYS_TR(esys_handle[0])

def _incr_handle_ref(self, handle):
if handle not in self._ref_handles:
return
tracked, count = self._ref_handles[handle]
count += 1
self._ref_handles[handle] = (tracked, count)

def _decr_handle_ref(self, handle):
if handle not in self._ref_handles:
return
tracked, count = self._ref_handles[handle]
count -= 1
if count <= 0 and tracked and self._flush_on_close:
self.flush_context(handle)
elif count <= 0 and not tracked:
# If a session is not tracked we don't know if it's flushed by the TPM or not.
# So just drop it from _ref_handles is the counter goes down to zero.
self._ref_handles.pop(handle, None)
else:
self._ref_handles[handle] = (tracked, count)

def _set_tracking_ref(self, handle, tracked):
if handle not in self._ref_handles:
return
_, count = self._ref_handles[handle]
self._ref_handles[handle] = (tracked, count)

def _handle_with_ref(self, handle: ESYS_TR) -> ESYS_TR:
"""Set up weak references between ESAPI context and handle.

If ectx._flush_on_close is True create weak references so we can flush on close/unref.
"""
if not self._flush_on_close:
return handle
handle._ectx_ref = weakref.ref(self)
# Each handle has a bool which is True if it's tracked or not and a reference counter.
# The reason for having a tracked/untracked bool is due to the TPM flushing sessions
# after use unless TPMA_SESSION.CONTINUESESSION is set and as it can be toggled on or off
# during runtime we need to keep the reference counter alive for the handle.
self._ref_handles[handle] = (True, 0)
williamcroberts marked this conversation as resolved.
Show resolved Hide resolved
handle._incr_ref()
return handle

@staticmethod
def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]:
"""Fixup ESYS_TR values to TPM2_RH constants to work around tpm2-tss API change in 3.0.0.
Expand Down
35 changes: 35 additions & 0 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ class ESYS_TR(TPM_FRIENDLY_INT):
RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM
RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV

_ectx_ref = None

def __init__(self, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], self.__class__):
self._ectx_ref = args[0]._ectx_ref
self._incr_ref()

def serialize(self, ectx: "ESAPI") -> bytes:
"""Same as see tpm2_pytss.ESAPI.tr_serialize

Expand Down Expand Up @@ -436,6 +443,34 @@ def close(self, ectx: "ESAPI"):
"""
return ectx.tr_close(self)

def __del__(self):
"""Flush handle on deletion.

If the ESYS handle has reference to an ESAPI instance flush the handle
when the there is no more references to the handle.
"""
self._decr_ref()

def _get_ectx(self):
if self._ectx_ref is None:
return None
ectx = self._ectx_ref()
if ectx is None or not ectx._ctx:
return None
return ectx

def _incr_ref(self):
ectx = self._get_ectx()
if ectx is None:
return
ectx._incr_handle_ref(self)

def _decr_ref(self):
ectx = self._get_ectx()
if ectx is None:
return
ectx._decr_handle_ref(self)


@TPM_FRIENDLY_INT._fix_const_type
class TPM2_RH(TPM_FRIENDLY_INT):
Expand Down
6 changes: 5 additions & 1 deletion src/tpm2_pytss/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
TPM2_ECC_CURVE,
TPM2_SE,
TPM2_HR,
TPM2_HT,
)
from typing import Union, Tuple, Optional
import sys
Expand All @@ -65,7 +66,10 @@ class ParserAttributeError(Exception):
class TPM2_HANDLE(int):
""""A handle to a TPM address"""

pass
@property
def type(self) -> TPM2_HT:
"""TPM2_HT: The handle type"""
return TPM2_HT((self >> 24) & 0xFF)


class TPM_OBJECT(object):
Expand Down
Loading