Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port RSA to rust #9152

Merged
merged 1 commit into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 15 additions & 255 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.ciphers import _CipherContext
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
from cryptography.hazmat.backends.openssl.rsa import (
_RSAPrivateKey,
_RSAPublicKey,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.bindings.openssl import binding
from cryptography.hazmat.primitives import hashes, serialization
Expand Down Expand Up @@ -63,7 +59,6 @@
XTS,
Mode,
)
from cryptography.hazmat.primitives.serialization import ssh
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PBES,
PKCS12Certificate,
Expand Down Expand Up @@ -358,24 +353,7 @@ def generate_rsa_private_key(
self, public_exponent: int, key_size: int
) -> rsa.RSAPrivateKey:
rsa._verify_rsa_parameters(public_exponent, key_size)

rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)

bn = self._int_to_bn(public_exponent)
bn = self._ffi.gc(bn, self._lib.BN_free)

res = self._lib.RSA_generate_key_ex(
rsa_cdata, key_size, bn, self._ffi.NULL
)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

# We can skip RSA key validation here since we just generated the key
return _RSAPrivateKey(
self, rsa_cdata, evp_pkey, unsafe_skip_rsa_key_validation=True
)
return rust_openssl.rsa.generate_private_key(public_exponent, key_size)

def generate_rsa_parameters_supported(
self, public_exponent: int, key_size: int
Expand All @@ -401,46 +379,15 @@ def load_rsa_private_numbers(
numbers.public_numbers.e,
numbers.public_numbers.n,
)
rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
p = self._int_to_bn(numbers.p)
q = self._int_to_bn(numbers.q)
d = self._int_to_bn(numbers.d)
dmp1 = self._int_to_bn(numbers.dmp1)
dmq1 = self._int_to_bn(numbers.dmq1)
iqmp = self._int_to_bn(numbers.iqmp)
e = self._int_to_bn(numbers.public_numbers.e)
n = self._int_to_bn(numbers.public_numbers.n)
res = self._lib.RSA_set0_factors(rsa_cdata, p, q)
self.openssl_assert(res == 1)
res = self._lib.RSA_set0_key(rsa_cdata, n, e, d)
self.openssl_assert(res == 1)
res = self._lib.RSA_set0_crt_params(rsa_cdata, dmp1, dmq1, iqmp)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

return _RSAPrivateKey(
self,
rsa_cdata,
evp_pkey,
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
return rust_openssl.rsa.from_private_numbers(
numbers, unsafe_skip_rsa_key_validation
)

def load_rsa_public_numbers(
self, numbers: rsa.RSAPublicNumbers
) -> rsa.RSAPublicKey:
rsa._check_public_key_components(numbers.e, numbers.n)
rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
e = self._int_to_bn(numbers.e)
n = self._int_to_bn(numbers.n)
res = self._lib.RSA_set0_key(rsa_cdata, n, e, self._ffi.NULL)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.from_public_numbers(numbers)

def _create_evp_pkey_gc(self):
evp_pkey = self._lib.EVP_PKEY_new()
Expand Down Expand Up @@ -500,13 +447,8 @@ def _evp_pkey_to_private_key(
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
return _RSAPrivateKey(
self,
rsa_cdata,
evp_pkey,
return rust_openssl.rsa.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
elif (
Expand Down Expand Up @@ -573,10 +515,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif (
key_type == self._lib.EVP_PKEY_RSA_PSS
and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL
Expand Down Expand Up @@ -733,7 +674,9 @@ def load_pem_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
self._handle_key_loading_error()

Expand Down Expand Up @@ -796,7 +739,9 @@ def load_der_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
self._handle_key_loading_error()

Expand Down Expand Up @@ -984,191 +929,6 @@ def elliptic_curve_exchange_algorithm_supported(
algorithm, ec.ECDH
)

def _private_key_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PrivateFormat,
encryption_algorithm: serialization.KeySerializationEncryption,
key,
evp_pkey,
cdata,
) -> bytes:
# validate argument types
if not isinstance(encoding, serialization.Encoding):
raise TypeError("encoding must be an item from the Encoding enum")
if not isinstance(format, serialization.PrivateFormat):
raise TypeError(
"format must be an item from the PrivateFormat enum"
)
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
):
raise TypeError(
"Encryption algorithm must be a KeySerializationEncryption "
"instance"
)

# validate password
if isinstance(encryption_algorithm, serialization.NoEncryption):
password = b""
elif isinstance(
encryption_algorithm, serialization.BestAvailableEncryption
):
password = encryption_algorithm.password
if len(password) > 1023:
raise ValueError(
"Passwords longer than 1023 bytes are not supported by "
"this backend"
)
elif (
isinstance(
encryption_algorithm, serialization._KeySerializationEncryption
)
and encryption_algorithm._format
is format
is serialization.PrivateFormat.OpenSSH
):
password = encryption_algorithm.password
else:
raise ValueError("Unsupported encryption type")

# PKCS8 + PEM/DER
if format is serialization.PrivateFormat.PKCS8:
if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_PKCS8PrivateKey
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_PKCS8PrivateKey_bio
else:
raise ValueError("Unsupported encoding for PKCS8")
return self._private_key_bytes_via_bio(
write_bio, evp_pkey, password
)

# TraditionalOpenSSL + PEM/DER
if format is serialization.PrivateFormat.TraditionalOpenSSL:
if self._fips_enabled and not isinstance(
encryption_algorithm, serialization.NoEncryption
):
raise ValueError(
"Encrypted traditional OpenSSL format is not "
"supported in FIPS mode."
)
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if encoding is serialization.Encoding.PEM:
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
return self._private_key_bytes_via_bio(
write_bio, cdata, password
)

if encoding is serialization.Encoding.DER:
if password:
raise ValueError(
"Encryption is not supported for DER encoded "
"traditional OpenSSL keys"
)
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.i2d_RSAPrivateKey_bio
return self._bio_func_output(write_bio, cdata)

raise ValueError("Unsupported encoding for TraditionalOpenSSL")

# OpenSSH + PEM
if format is serialization.PrivateFormat.OpenSSH:
if encoding is serialization.Encoding.PEM:
return ssh._serialize_ssh_private_key(
key, password, encryption_algorithm
)

raise ValueError(
"OpenSSH private key format can only be used"
" with PEM encoding"
)

# Anything that key-specific code was supposed to handle earlier,
# like Raw.
raise ValueError("format is invalid with this key")

def _private_key_bytes_via_bio(
self, write_bio, evp_pkey, password
) -> bytes:
if not password:
evp_cipher = self._ffi.NULL
else:
# This is a curated value that we will update over time.
evp_cipher = self._lib.EVP_get_cipherbyname(b"aes-256-cbc")

return self._bio_func_output(
write_bio,
evp_pkey,
evp_cipher,
password,
len(password),
self._ffi.NULL,
self._ffi.NULL,
)

def _bio_func_output(self, write_bio, *args) -> bytes:
bio = self._create_mem_bio_gc()
res = write_bio(bio, *args)
self.openssl_assert(res == 1)
return self._read_mem_bio(bio)

def _public_key_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PublicFormat,
key,
evp_pkey,
cdata,
) -> bytes:
if not isinstance(encoding, serialization.Encoding):
raise TypeError("encoding must be an item from the Encoding enum")
if not isinstance(format, serialization.PublicFormat):
raise TypeError(
"format must be an item from the PublicFormat enum"
)

# SubjectPublicKeyInfo + PEM/DER
if format is serialization.PublicFormat.SubjectPublicKeyInfo:
if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_PUBKEY
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_PUBKEY_bio
else:
raise ValueError(
"SubjectPublicKeyInfo works only with PEM or DER encoding"
)
return self._bio_func_output(write_bio, evp_pkey)

# PKCS1 + PEM/DER
if format is serialization.PublicFormat.PKCS1:
# Only RSA is supported here.
key_type = self._lib.EVP_PKEY_id(evp_pkey)
self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA)

if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_RSAPublicKey
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_RSAPublicKey_bio
else:
raise ValueError("PKCS1 works only with PEM or DER encoding")
return self._bio_func_output(write_bio, cdata)

# OpenSSH + OpenSSH
if format is serialization.PublicFormat.OpenSSH:
if encoding is serialization.Encoding.OpenSSH:
return ssh.serialize_ssh_public_key(key)

raise ValueError(
"OpenSSH format must be used with OpenSSH encoding"
)

# Anything that key-specific code was supposed to handle earlier,
# like Raw, CompressedPoint, UncompressedPoint
raise ValueError("format is invalid with this key")

def dh_supported(self) -> bool:
return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL

Expand Down
Loading
Loading