From 8cf8b1feca7ca13a440467f00de6b0b3da6d41b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Fri, 28 Jun 2024 14:14:11 +0200 Subject: [PATCH] Add more type definitions for `SSL` module, check with mypy We want to ensure that from now on, any new public API comes with proper type definitions. --- pyproject.toml | 15 +--- src/OpenSSL/SSL.py | 183 ++++++++++++++++++++++++++++----------------- 2 files changed, 120 insertions(+), 78 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91fca19d..7afe704a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,25 +10,18 @@ source = [ ] [tool.coverage.report] +exclude_also = [ + "assert False", +] show_missing = true [tool.mypy] warn_unused_configs = true follow_imports = "skip" strict = true -exclude = ['SSL\.py$'] - -[[tool.mypy.overrides]] -module = "OpenSSL.crypto" -warn_return_any = false -disallow_any_expr = false - -[[tool.mypy.overrides]] -module = "OpenSSL.rand" -warn_return_any = false [[tool.mypy.overrides]] -module = "OpenSSL._util" +module = "OpenSSL.*" warn_return_any = false [[tool.mypy.overrides]] diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 5e42289a..6a33d9e4 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -272,7 +272,28 @@ SSL_CB_HANDSHAKE_DONE: int = _lib.SSL_CB_HANDSHAKE_DONE _T = TypeVar("_T") -_SetVerifyCallback = Callable[["Connection", X509, int, int, int], bool] + + +class _NoOverlappingProtocols: + pass + + +NO_OVERLAPPING_PROTOCOLS = _NoOverlappingProtocols() + +# Callback types. +_ALPNSelectCallback = Callable[ + [ + "Connection", + typing.Union[List[bytes], _NoOverlappingProtocols], + ], + None, +] +_CookieGenerateCallback = Callable[["Connection"], bytes] +_CookieVerifyCallback = Callable[["Connection", bytes], bool] +_OCSPClientCallback = Callable[["Connection", bytes, Optional[_T]], bool] +_OCSPServerCallback = Callable[["Connection", Optional[_T]], bytes] +_PassphraseCallback = Callable[[int, bool, Optional[_T]], bytes] +_VerifyCallback = Callable[["Connection", X509, int, int, int], bool] class X509VerificationCodes: @@ -464,11 +485,11 @@ class _VerifyHelper(_CallbackExceptionHelper): callback. """ - def __init__(self, callback): + def __init__(self, callback: _VerifyCallback) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ok, store_ctx): + def wrapper(ok, store_ctx): # type: ignore[no-untyped-def] x509 = _lib.X509_STORE_CTX_get_current_cert(store_ctx) _lib.X509_up_ref(x509) cert = X509._from_raw_x509_ptr(x509) @@ -498,19 +519,16 @@ def wrapper(ok, store_ctx): ) -NO_OVERLAPPING_PROTOCOLS = object() - - class _ALPNSelectHelper(_CallbackExceptionHelper): """ Wrap a callback such that it can be used as an ALPN selection callback. """ - def __init__(self, callback): + def __init__(self, callback: _ALPNSelectCallback) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ssl, out, outlen, in_, inlen, arg): + def wrapper(ssl, out, outlen, in_, inlen, arg): # type: ignore[no-untyped-def] try: conn = Connection._reverse_mapping[ssl] @@ -584,11 +602,11 @@ class _OCSPServerCallbackHelper(_CallbackExceptionHelper): This helper implements the server side. """ - def __init__(self, callback): + def __init__(self, callback: _OCSPServerCallback[Any]) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ssl, cdata): + def wrapper(ssl, cdata): # type: ignore[no-untyped-def] try: conn = Connection._reverse_mapping[ssl] @@ -651,11 +669,11 @@ class _OCSPClientCallbackHelper(_CallbackExceptionHelper): This helper implements the client side. """ - def __init__(self, callback): + def __init__(self, callback: _OCSPClientCallback[Any]) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ssl, cdata): + def wrapper(ssl, cdata): # type: ignore[no-untyped-def] try: conn = Connection._reverse_mapping[ssl] @@ -689,11 +707,11 @@ def wrapper(ssl, cdata): class _CookieGenerateCallbackHelper(_CallbackExceptionHelper): - def __init__(self, callback): + def __init__(self, callback: _CookieGenerateCallback) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ssl, out, outlen): + def wrapper(ssl, out, outlen): # type: ignore[no-untyped-def] try: conn = Connection._reverse_mapping[ssl] cookie = callback(conn) @@ -712,11 +730,11 @@ def wrapper(ssl, out, outlen): class _CookieVerifyCallbackHelper(_CallbackExceptionHelper): - def __init__(self, callback): + def __init__(self, callback: _CookieVerifyCallback) -> None: _CallbackExceptionHelper.__init__(self) @wraps(callback) - def wrapper(ssl, c_cookie, cookie_len): + def wrapper(ssl, c_cookie, cookie_len): # type: ignore[no-untyped-def] try: conn = Connection._reverse_mapping[ssl] return callback(conn, bytes(c_cookie[0:cookie_len])) @@ -730,7 +748,7 @@ def wrapper(ssl, c_cookie, cookie_len): ) -def _asFileDescriptor(obj): +def _asFileDescriptor(obj: Any) -> int: fd = None if not isinstance(obj, int): meth = getattr(obj, "fileno", None) @@ -773,11 +791,11 @@ def _make_requires(flag: int, error: str) -> Callable[[_T], _T]: :param error: The string to be used in the exception if the flag is false. """ - def _requires_decorator(func): + def _requires_decorator(func): # type: ignore[no-untyped-def] if not flag: @wraps(func) - def explode(*args, **kwargs): + def explode(*args, **kwargs): # type: ignore[no-untyped-def] raise NotImplementedError(error) return explode @@ -806,7 +824,7 @@ class Session: .. versionadded:: 0.14 """ - pass + _session: Any class Context: @@ -820,7 +838,9 @@ class Context: not be used. """ - _methods: typing.ClassVar[typing.Dict] = { + _methods: typing.ClassVar[ + typing.Dict[int, typing.Tuple[Callable[[], Any], Optional[int]]] + ] = { SSLv23_METHOD: (_lib.TLS_method, None), TLSv1_METHOD: (_lib.TLS_method, TLS1_VERSION), TLSv1_1_METHOD: (_lib.TLS_method, TLS1_1_VERSION), @@ -850,22 +870,30 @@ def __init__(self, method: int) -> None: context = _ffi.gc(context, _lib.SSL_CTX_free) self._context = context - self._passphrase_helper = None - self._passphrase_callback = None - self._passphrase_userdata = None - self._verify_helper = None - self._verify_callback = None + self._passphrase_helper: Optional[_PassphraseHelper] = None + self._passphrase_callback: Optional[_PassphraseCallback[Any]] = None + self._passphrase_userdata: Optional[Any] = None + self._verify_helper: Optional[_VerifyHelper] = None + self._verify_callback: Optional[_VerifyCallback] = None self._info_callback = None self._keylog_callback = None self._tlsext_servername_callback = None self._app_data = None - self._alpn_select_helper = None - self._alpn_select_callback = None - self._ocsp_helper = None - self._ocsp_callback = None - self._ocsp_data = None - self._cookie_generate_helper = None - self._cookie_verify_helper = None + self._alpn_select_helper: Optional[_ALPNSelectHelper] = None + self._alpn_select_callback: Optional[_ALPNSelectCallback] = None + self._ocsp_helper: typing.Union[ + _OCSPClientCallbackHelper, _OCSPServerCallbackHelper, None + ] = None + self._ocsp_callback: typing.Union[ + _OCSPClientCallback[Any], _OCSPServerCallback[Any], None + ] = None + self._ocsp_data: Optional[Any] = None + self._cookie_generate_helper: Optional[ + _CookieGenerateCallbackHelper + ] = None + self._cookie_verify_helper: Optional[_CookieVerifyCallbackHelper] = ( + None + ) self.set_mode(_lib.SSL_MODE_ENABLE_PARTIAL_WRITE) if version is not None: @@ -934,9 +962,11 @@ def load_verify_locations( if not load_result: _raise_current_error() - def _wrap_callback(self, callback): + def _wrap_callback( + self, callback: _PassphraseCallback[_T] + ) -> _PassphraseHelper: @wraps(callback) - def wrapper(size, verify, userdata): + def wrapper(size: int, verify: bool, userdata: Any) -> bytes: return callback(size, verify, self._passphrase_userdata) return _PassphraseHelper( @@ -945,7 +975,7 @@ def wrapper(size, verify, userdata): def set_passwd_cb( self, - callback: Callable[[int, bool, Optional[_T]], bytes], + callback: _PassphraseCallback[_T], userdata: Optional[_T] = None, ) -> None: """ @@ -1024,7 +1054,7 @@ def set_default_verify_paths(self) -> None: _CERTIFICATE_FILE_LOCATIONS, _CERTIFICATE_PATH_LOCATIONS ) - def _check_env_vars_set(self, dir_env_var, file_env_var): + def _check_env_vars_set(self, dir_env_var: str, file_env_var: str) -> bool: """ Check to see if the default cert dir/file environment vars are present. @@ -1035,7 +1065,9 @@ def _check_env_vars_set(self, dir_env_var, file_env_var): or os.environ.get(dir_env_var) is not None ) - def _fallback_default_verify_paths(self, file_path, dir_path): + def _fallback_default_verify_paths( + self, file_path: List[str], dir_path: List[str] + ) -> None: """ Default verify paths are based on the compiled version of OpenSSL. However, when pyca/cryptography is compiled as a manylinux wheel @@ -1243,7 +1275,7 @@ def get_session_cache_mode(self) -> int: return _lib.SSL_CTX_get_session_cache_mode(self._context) def set_verify( - self, mode: int, callback: Optional[_SetVerifyCallback] = None + self, mode: int, callback: Optional[_VerifyCallback] = None ) -> None: """ Set the verification flags for this Context object to *mode* and @@ -1479,7 +1511,7 @@ def set_info_callback( """ @wraps(callback) - def wrapper(ssl, where, return_code): + def wrapper(ssl, where, return_code): # type: ignore[no-untyped-def] callback(Connection._reverse_mapping[ssl], where, return_code) self._info_callback = _ffi.callback( @@ -1505,7 +1537,7 @@ def set_keylog_callback( """ @wraps(callback) - def wrapper(ssl, line): + def wrapper(ssl, line): # type: ignore[no-untyped-def] line = _ffi.string(line) callback(Connection._reverse_mapping[ssl], line) @@ -1531,7 +1563,7 @@ def set_app_data(self, data: Any) -> None: """ self._app_data = data - def get_cert_store(self) -> X509Store: + def get_cert_store(self) -> Optional[X509Store]: """ Get the certificate store for the context. This can be used to add "trusted" certificates without using the @@ -1588,7 +1620,7 @@ def set_tlsext_servername_callback( """ @wraps(callback) - def wrapper(ssl, alert, arg): + def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def] callback(Connection._reverse_mapping[ssl]) return 0 @@ -1653,9 +1685,7 @@ def set_alpn_protos(self, protos: List[bytes]) -> None: ) @_requires_alpn - def set_alpn_select_callback( - self, callback: Callable[["Connection", List[bytes]], None] - ) -> None: + def set_alpn_select_callback(self, callback: _ALPNSelectCallback) -> None: """ Specify a callback function that will be called on the server when a client offers protocols using ALPN. @@ -1675,7 +1705,13 @@ def set_alpn_select_callback( self._context, self._alpn_select_callback, _ffi.NULL ) - def _set_ocsp_callback(self, helper, data): + def _set_ocsp_callback( + self, + helper: typing.Union[ + _OCSPClientCallbackHelper, _OCSPServerCallbackHelper + ], + data: Optional[Any], + ) -> None: """ This internal helper does the common work for ``set_ocsp_server_callback`` and ``set_ocsp_client_callback``, which is @@ -1697,7 +1733,7 @@ def _set_ocsp_callback(self, helper, data): def set_ocsp_server_callback( self, - callback: Callable[["Connection", Optional[_T]], bytes], + callback: _OCSPServerCallback[_T], data: Optional[_T] = None, ) -> None: """ @@ -1719,7 +1755,7 @@ def set_ocsp_server_callback( def set_ocsp_client_callback( self, - callback: Callable[["Connection", bytes, Optional[_T]], bool], + callback: _OCSPClientCallback[_T], data: Optional[_T] = None, ) -> None: """ @@ -1741,14 +1777,18 @@ def set_ocsp_client_callback( helper = _OCSPClientCallbackHelper(callback) self._set_ocsp_callback(helper, data) - def set_cookie_generate_callback(self, callback): + def set_cookie_generate_callback( + self, callback: _CookieGenerateCallback + ) -> None: self._cookie_generate_helper = _CookieGenerateCallbackHelper(callback) _lib.SSL_CTX_set_cookie_generate_cb( self._context, self._cookie_generate_helper.callback, ) - def set_cookie_verify_callback(self, callback): + def set_cookie_verify_callback( + self, callback: _CookieVerifyCallback + ) -> None: self._cookie_verify_helper = _CookieVerifyCallbackHelper(callback) _lib.SSL_CTX_set_cookie_verify_cb( self._context, @@ -1757,7 +1797,9 @@ def set_cookie_verify_callback(self, callback): class Connection: - _reverse_mapping = WeakValueDictionary() + _reverse_mapping: typing.MutableMapping[Any, "Connection"] = ( + WeakValueDictionary() + ) def __init__( self, context: Context, socket: Optional[socket.socket] = None @@ -1819,7 +1861,7 @@ def __init__( ) _openssl_assert(set_result == 1) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """ Look up attributes on the wrapped socket object if they are not found on the Connection object. @@ -1831,7 +1873,7 @@ def __getattr__(self, name): else: return getattr(self._socket, name) - def _raise_ssl_error(self, ssl, result): + def _raise_ssl_error(self, ssl: Any, result: int) -> None: if self._context._verify_helper is not None: self._context._verify_helper.raise_if_problem() if self._context._alpn_select_helper is not None: @@ -1924,7 +1966,7 @@ def get_servername(self) -> Optional[bytes]: return _ffi.string(name) def set_verify( - self, mode: int, callback: Optional[_SetVerifyCallback] = None + self, mode: int, callback: Optional[_VerifyCallback] = None ) -> None: """ Override the Context object's verification flags for this specific @@ -2115,7 +2157,10 @@ def recv(self, bufsiz: int, flags: Optional[int] = None) -> bytes: read = recv def recv_into( - self, buffer, nbytes: Optional[int] = None, flags: Optional[int] = None + self, + buffer: Any, # collections.abc.Buffer once we use Python 3.12+ + nbytes: Optional[int] = None, + flags: Optional[int] = None, ) -> int: """ Receive data on the connection and copy it directly into the provided @@ -2153,7 +2198,7 @@ def recv_into( return result - def _handle_bio_errors(self, bio, result): + def _handle_bio_errors(self, bio: Any, result: int) -> typing.NoReturn: if _lib.BIO_should_retry(bio): if _lib.BIO_should_read(bio): raise WantReadError() @@ -2258,7 +2303,7 @@ def total_renegotiations(self) -> int: """ return _lib.SSL_total_renegotiations(self._ssl) - def connect(self, addr): + def connect(self, addr: Any) -> None: """ Call the :meth:`connect` method of the underlying socket and set up SSL on the socket, using the :class:`Context` object supplied to this @@ -2268,9 +2313,9 @@ def connect(self, addr): :return: What the socket's connect method returns """ _lib.SSL_set_connect_state(self._ssl) - return self._socket.connect(addr) + return self._socket.connect(addr) # type: ignore[return-value, union-attr] - def connect_ex(self, addr) -> int: + def connect_ex(self, addr: Any) -> int: """ Call the :meth:`connect_ex` method of the underlying socket and set up SSL on the socket, using the Context object supplied to this Connection @@ -2280,7 +2325,7 @@ def connect_ex(self, addr) -> int: :param addr: A remove address :return: What the socket's connect_ex method returns """ - connect_ex = self._socket.connect_ex + connect_ex = self._socket.connect_ex # type: ignore[union-attr] self.set_connect_state() return connect_ex(addr) @@ -2294,7 +2339,7 @@ def accept(self) -> Tuple["Connection", Any]: :class:`Connection` object created, and *address* is as returned by the socket's :meth:`accept`. """ - client, addr = self._socket.accept() + client, addr = self._socket.accept() # type: ignore[union-attr] conn = Connection(self._context, client) conn.set_accept_state() return (conn, addr) @@ -2353,6 +2398,7 @@ def DTLSv1_handle_timeout(self) -> bool: result = _lib.DTLSv1_handle_timeout(self._ssl) if result < 0: self._raise_ssl_error(self._ssl, result) + assert False, "unreachable" else: return bool(result) @@ -2381,6 +2427,7 @@ def shutdown(self) -> bool: result = _lib.SSL_shutdown(self._ssl) if result < 0: self._raise_ssl_error(self._ssl, result) + assert False, "unreachable" elif result > 0: return True else: @@ -2429,7 +2476,7 @@ def get_client_ca_list(self) -> List[X509Name]: result.append(pyname) return result - def makefile(self, *args, **kwargs) -> None: + def makefile(self, *args: Any, **kwargs: Any) -> typing.NoReturn: """ The makefile() method is not implemented, since there is no dup semantics for SSL connections @@ -2566,14 +2613,14 @@ def export_keying_material( _openssl_assert(success == 1) return _ffi.buffer(outp, olen)[:] - def sock_shutdown(self, *args, **kwargs): + def sock_shutdown(self, *args: Any, **kwargs: Any) -> None: """ Call the :meth:`shutdown` method of the underlying socket. See :manpage:`shutdown(2)`. :return: What the socket's shutdown() method returns """ - return self._socket.shutdown(*args, **kwargs) + return self._socket.shutdown(*args, **kwargs) # type: ignore[return-value, union-attr] def get_certificate(self) -> Optional[X509]: """ @@ -2599,7 +2646,7 @@ def get_peer_certificate(self) -> Optional[X509]: return None @staticmethod - def _cert_stack_to_list(cert_stack) -> List[X509]: + def _cert_stack_to_list(cert_stack: Any) -> List[X509]: """ Internal helper to convert a STACK_OF(X509) to a list of X509 instances. @@ -2714,7 +2761,9 @@ def set_session(self, session: Session) -> None: result = _lib.SSL_set_session(self._ssl, session._session) _openssl_assert(result == 1) - def _get_finished_message(self, function) -> Optional[bytes]: + def _get_finished_message( + self, function: Callable[[Any, Any, int], int] + ) -> Optional[bytes]: """ Helper to implement :meth:`get_finished` and :meth:`get_peer_finished`.