Skip to content

Commit

Permalink
Improved implementation for set_sigalgs
Browse files Browse the repository at this point in the history
  • Loading branch information
mxsasha committed Apr 1, 2024
1 parent 60a9092 commit feddb7d
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 12 deletions.
44 changes: 37 additions & 7 deletions nassl/_nassl/nassl_SSL.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,20 +813,50 @@ static PyObject* nassl_SSL_set_ciphersuites(nassl_SSL_Object *self, PyObject *ar
}


// SSL_set1_sigalgs_list() is only available in OpenSSL 1.1.1
static PyObject* nassl_SSL_set1_sigalgs_list(nassl_SSL_Object *self, PyObject *args)
// SSL_set1_sigalgs() is only available in OpenSSL 1.1.1
static PyObject* nassl_SSL_set1_sigalgs(nassl_SSL_Object *self, PyObject *args)
{
char *sigalgList;
if (!PyArg_ParseTuple(args, "s", &sigalgList))
int i = 0;
PyObject *pyListOfOpensslNids;
Py_ssize_t nidsCount = 0;
int *listOfNids;

// Parse the Python list
if (!PyArg_ParseTuple(args, "O!", &PyList_Type, &pyListOfOpensslNids))
{
return NULL;
}

if (!SSL_set1_sigalgs_list(self->ssl, sigalgList))
// Extract each NID int from the list
nidsCount = PyList_Size(pyListOfOpensslNids);
listOfNids = (int *) PyMem_Malloc(nidsCount * sizeof(int));
if (listOfNids == NULL)
{
return PyErr_NoMemory();
}

for (i=0; i<nidsCount; i++)
{
PyObject *pyNid;
int nid;

pyNid = PyList_GetItem(pyListOfOpensslNids, i);
if ((pyNid == NULL) || (!PyLong_Check(pyNid)))
{
PyMem_Free(listOfNids);
return NULL;
}
nid = PyLong_AsSize_t(pyNid);
listOfNids[i] = nid;
}

if (SSL_set1_sigalgs(self->ssl, listOfNids, nidsCount) != 1)
{
PyMem_Free(listOfNids);
return raise_OpenSSL_error();
}

PyMem_Free(listOfNids);
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -1199,8 +1229,8 @@ static PyMethodDef nassl_SSL_Object_methods[] =
{"set_ciphersuites", (PyCFunction)nassl_SSL_set_ciphersuites, METH_VARARGS,
"OpenSSL's SSL_set_ciphersuites()."
},
{"set1_sigalgs_list", (PyCFunction)nassl_SSL_set1_sigalgs_list, METH_VARARGS,
"OpenSSL's SSL_set1_sigalgs_list()."
{"set1_sigalgs", (PyCFunction)nassl_SSL_set1_sigalgs, METH_VARARGS,
"OpenSSL's SSL_set1_sigalgs()."
},
{"get0_verified_chain", (PyCFunction)nassl_SSL_get0_verified_chain, METH_NOARGS,
"OpenSSL's SSL_get0_verified_chain(). Returns an array of _nassl.X509 objects."
Expand Down
3 changes: 3 additions & 0 deletions nassl/ephemeral_key_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class OpenSslEvpPkeyEnum(IntEnum):
EC = 408
X25519 = 1034
X448 = 1035
RSA = 6
DSA = 116
RSA_PSS = 912


class OpenSslEcNidEnum(IntEnum):
Expand Down
20 changes: 16 additions & 4 deletions nassl/ssl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nassl._nassl import WantReadError, OpenSSLError, WantX509LookupError

from enum import IntEnum
from typing import List, Any
from typing import List, Any, Tuple

from typing import Protocol

Expand All @@ -32,6 +32,17 @@ class OpenSslVerifyEnum(IntEnum):
CLIENT_ONCE = 4


class OpenSslDigestNidEnum(IntEnum):
"""SSL digest algorithms used for the signature algorithm, per obj_mac.h."""

MD5 = 4
SHA1 = 64
SHA224 = 675
SHA256 = 672
SHA384 = 673
SHA512 = 674


class OpenSslVersionEnum(IntEnum):
"""SSL version constants."""

Expand Down Expand Up @@ -449,9 +460,10 @@ def set_ciphersuites(self, cipher_suites: str) -> None:
# TODO(AD): Eventually merge this method with get/set_cipher_list()
self._ssl.set_ciphersuites(cipher_suites)

def set_sigalgs(self, cipher_suites: str) -> None:
"""Set the enabled signature algorithms, e.g. 'ECDSA+SHA256:RSA+SHA256'"""
self._ssl.set1_sigalgs_list(cipher_suites)
def set_sigalgs(self, sigalgs: List[Tuple[OpenSslDigestNidEnum, OpenSslEvpPkeyEnum]]) -> None:
"""Set the enabled signature algorithms."""
flattened_sigalgs = [item for sublist in sigalgs for item in sublist]
self._ssl.set1_sigalgs(flattened_sigalgs)

def set_groups(self, supported_groups: List[OpenSslEcNidEnum]) -> None:
"""Specify elliptic curves or DH groups that are supported by the client in descending order."""
Expand Down
22 changes: 21 additions & 1 deletion tests/ssl_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OpenSslVerifyEnum,
SslClient,
OpenSSLError,
OpenSslEarlyDataStatusEnum,
OpenSslEarlyDataStatusEnum, OpenSslDigestNidEnum,
)
from nassl.ephemeral_key_info import (
OpenSslEvpPkeyEnum,
Expand Down Expand Up @@ -424,6 +424,26 @@ def test_set_ciphersuites(self):
# And client's cipher suite was used
assert "TLS_CHACHA20_POLY1305_SHA256" == ssl_client.get_current_cipher_name()

def test_set_sigalgs(self):
with ModernOpenSslServer() as server:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
sock.connect((server.hostname, server.port))

ssl_client = SslClient(
ssl_version=OpenSslVersionEnum.TLSV1_3,
underlying_socket=sock,
ssl_verify=OpenSslVerifyEnum.NONE,
)
# These signature algorithms are unsupported
ssl_client.set_sigalgs([
(OpenSslDigestNidEnum.SHA512, OpenSslEvpPkeyEnum.EC)
])

with pytest.raises(OpenSSLError):
ssl_client.do_handshake()
ssl_client.shutdown()

@staticmethod
def _create_tls_1_3_session(server_host: str, server_port: int) -> _nassl.SSL_SESSION:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand Down

0 comments on commit feddb7d

Please sign in to comment.