Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add set_sigalgs to ssl_client / get_peer_signature_nid to set key exchange params #118

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions nassl/_nassl/nassl_SSL.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,66 @@ static PyObject* nassl_SSL_set_ciphersuites(nassl_SSL_Object *self, PyObject *ar
}


// SSL_set1_sigalgs() is only available in OpenSSL 1.1.1
static PyObject* nassl_SSL_set1_sigalgs(nassl_SSL_Object *self, PyObject *args)
{
int i = 0;
PyObject *pyListOfOpensslNids;
Py_ssize_t nidsCount = 0;
int *listOfNids;

// Parse the Python list
if (!PyArg_ParseTuple(args, "O!", &PyList_Type, &pyListOfOpensslNids))
{
return NULL;
}

// Extract each NID int from the list
nidsCount = PyList_Size(pyListOfOpensslNids);
listOfNids = (int *) PyMem_Malloc(nidsCount * sizeof(int));
if (listOfNids == NULL)
{
return PyErr_NoMemory();
}

for (i=0; i<nidsCount; i++)
{
PyObject *pyNid;
int nid;

pyNid = PyList_GetItem(pyListOfOpensslNids, i);
if ((pyNid == NULL) || (!PyLong_Check(pyNid)))
{
PyMem_Free(listOfNids);
return NULL;
}
nid = PyLong_AsSize_t(pyNid);
listOfNids[i] = nid;
}

if (SSL_set1_sigalgs(self->ssl, listOfNids, nidsCount) != 1)
{
PyMem_Free(listOfNids);
return raise_OpenSSL_error();
}

PyMem_Free(listOfNids);
Py_RETURN_NONE;
}


static PyObject* nassl_get_peer_signature_nid(nassl_SSL_Object *self)
{
int psig_nid;

if(SSL_get_peer_signature_nid(self->ssl, &psig_nid) != 1)
{
return raise_OpenSSL_error();
}
return PyLong_FromUnsignedLong((long)psig_nid);
}


static PyObject* nassl_SSL_get0_verified_chain(nassl_SSL_Object *self, PyObject *args)
{
STACK_OF(X509) *verifiedCertChain = NULL;
Expand Down Expand Up @@ -1181,6 +1241,12 @@ static PyMethodDef nassl_SSL_Object_methods[] =
{"set_ciphersuites", (PyCFunction)nassl_SSL_set_ciphersuites, METH_VARARGS,
"OpenSSL's SSL_set_ciphersuites()."
},
{"set1_sigalgs", (PyCFunction)nassl_SSL_set1_sigalgs, METH_VARARGS,
"OpenSSL's SSL_set1_sigalgs()."
},
{"get_peer_signature_nid", (PyCFunction)nassl_get_peer_signature_nid, METH_NOARGS,
"OpenSSL's get_peer_signature_nid(). Returns a digest NID"
},
{"get0_verified_chain", (PyCFunction)nassl_SSL_get0_verified_chain, METH_NOARGS,
"OpenSSL's SSL_get0_verified_chain(). Returns an array of _nassl.X509 objects."
},
Expand Down
3 changes: 3 additions & 0 deletions nassl/ephemeral_key_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class OpenSslEvpPkeyEnum(IntEnum):
EC = 408
X25519 = 1034
X448 = 1035
RSA = 6
DSA = 116
RSA_PSS = 912


class OpenSslEcNidEnum(IntEnum):
Expand Down
22 changes: 21 additions & 1 deletion nassl/ssl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nassl._nassl import WantReadError, OpenSSLError, WantX509LookupError

from enum import IntEnum
from typing import List, Any
from typing import List, Any, Tuple

from typing import Protocol

Expand All @@ -32,6 +32,17 @@ class OpenSslVerifyEnum(IntEnum):
CLIENT_ONCE = 4


class OpenSslDigestNidEnum(IntEnum):
"""SSL digest algorithms used for the signature algorithm, per obj_mac.h."""

MD5 = 4
SHA1 = 64
SHA224 = 675
SHA256 = 672
SHA384 = 673
SHA512 = 674


class OpenSslVersionEnum(IntEnum):
"""SSL version constants."""

Expand Down Expand Up @@ -449,6 +460,15 @@ def set_ciphersuites(self, cipher_suites: str) -> None:
# TODO(AD): Eventually merge this method with get/set_cipher_list()
self._ssl.set_ciphersuites(cipher_suites)

def set_sigalgs(self, sigalgs: List[Tuple[OpenSslDigestNidEnum, OpenSslEvpPkeyEnum]]) -> None:
"""Set the enabled signature algorithms for the key exchange."""
flattened_sigalgs = [item for sublist in sigalgs for item in sublist]
self._ssl.set1_sigalgs(flattened_sigalgs)

def get_peer_signature_nid(self) -> OpenSslDigestNidEnum:
"""Get the digest used for TLS message signing."""
return OpenSslDigestNidEnum(self._ssl.get_peer_signature_nid())

def set_groups(self, supported_groups: List[OpenSslEcNidEnum]) -> None:
"""Specify elliptic curves or DH groups that are supported by the client in descending order."""
self._ssl.set1_groups(supported_groups)
Expand Down
24 changes: 23 additions & 1 deletion tests/ssl_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OpenSslVerifyEnum,
SslClient,
OpenSSLError,
OpenSslEarlyDataStatusEnum,
OpenSslEarlyDataStatusEnum, OpenSslDigestNidEnum,
)
from nassl.ephemeral_key_info import (
OpenSslEvpPkeyEnum,
Expand Down Expand Up @@ -217,6 +217,8 @@ def test_get_verified_chain(self):

# And when requesting the verified certificate chain, it returns it
assert ssl_client.get_verified_chain()

assert ssl_client.get_peer_signature_nid() == OpenSslDigestNidEnum.SHA256
finally:
ssl_client.shutdown()

Expand Down Expand Up @@ -424,6 +426,26 @@ def test_set_ciphersuites(self):
# And client's cipher suite was used
assert "TLS_CHACHA20_POLY1305_SHA256" == ssl_client.get_current_cipher_name()

def test_set_sigalgs(self):
with ModernOpenSslServer() as server:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
sock.connect((server.hostname, server.port))

ssl_client = SslClient(
ssl_version=OpenSslVersionEnum.TLSV1_3,
underlying_socket=sock,
ssl_verify=OpenSslVerifyEnum.NONE,
)
# These signature algorithms are unsupported
ssl_client.set_sigalgs([
(OpenSslDigestNidEnum.SHA512, OpenSslEvpPkeyEnum.EC)
])

with pytest.raises(OpenSSLError):
ssl_client.do_handshake()
ssl_client.shutdown()

@staticmethod
def _create_tls_1_3_session(server_host: str, server_port: int) -> _nassl.SSL_SESSION:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand Down
Loading