Skip to content

Commit

Permalink
Internal code for performing Key Vault crypto operations locally (#12490
Browse files Browse the repository at this point in the history
)
  • Loading branch information
chlowell authored Sep 14, 2020
1 parent fbee5cb commit 8d9d55e
Show file tree
Hide file tree
Showing 21 changed files with 703 additions and 581 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import codecs
from base64 import b64encode, b64decode

Expand Down Expand Up @@ -67,29 +66,41 @@ def _b64_to_str(b64str):
return _b64_to_bstr(b64str).decode("utf8")


def _int_to_bigendian_8_bytes(i):
def _int_to_fixed_length_bigendian_bytes(i, length):
"""Convert an integer to a bigendian byte string left-padded with zeroes to a fixed length."""

b = _int_to_bytes(i)

if len(b) > 8:
raise ValueError("the specified integer is to large to be represented by 8 bytes")
if len(b) > length:
raise ValueError("{} is too large to be represented by {} bytes".format(i, length))

if len(b) < 8:
b = (b"\0" * (8 - len(b))) + b
if len(b) < length:
b = (b"\0" * (length - len(b))) + b

return b


def encode_key_vault_ecdsa_signature(signature):
"""
ASN.1 DER encode a Key Vault ECDSA signature.
Key Vault returns ECDSA signatures as the concatenated bytes of two equal-size integers. ``cryptography`` expects
ECDSA signatures be ASN.1 DER encoded.
def ecdsa_to_asn1_der(signature):
"""ASN.1 DER encode an ECDSA signature.
:param bytes signature: ECDSA signature returned by Key Vault
:return: signature encoded for use by ``cryptography``
:param bytes signature: ECDSA signature encoded according to RFC 7518, i.e. the concatenated big-endian bytes of
two integers (as produced by Key Vault)
:return: signature, ASN.1 DER encoded (as expected by ``cryptography``)
"""
mid = len(signature) // 2
r = _bytes_to_int(signature[:mid])
s = _bytes_to_int(signature[mid:])
return utils.encode_dss_signature(r, s)


def asn1_der_to_ecdsa(signature, algorithm):
"""Convert an ASN.1 DER encoded signature to ECDSA encoding.
:param bytes signature: an ASN.1 DER encoded ECDSA signature (as produced by ``cryptography``)
:param _Ecdsa algorithm: signing algorithm which produced ``signature``
:return: signature encoded according to RFC 7518 (as expected by Key Vault)
"""
r, s = utils.decode_dss_signature(signature)
r_bytes = _int_to_fixed_length_bigendian_bytes(r, algorithm.coordinate_length)
s_bytes = _int_to_fixed_length_bigendian_bytes(s, algorithm.coordinate_length)
return r_bytes + s_bytes
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..algorithm import AuthenticatedSymmetricEncryptionAlgorithm
from ..transform import AuthenticatedCryptoTransform
from .._internal import _int_to_bigendian_8_bytes
from .._internal import _int_to_fixed_length_bigendian_bytes


class _AesCbcHmacCryptoTransform(AuthenticatedCryptoTransform):
Expand All @@ -24,7 +24,7 @@ def __init__(self, key, iv, auth_data, auth_tag):
self._cipher = Cipher(algorithms.AES(self._aes_key), modes.CBC(iv), backend=default_backend())
self._tag = auth_tag or bytearray()
self._hmac = hmac.HMAC(self._hmac_key, hash_algo, backend=default_backend())
self._auth_data_length = _int_to_bigendian_8_bytes(len(auth_data) * 8)
self._auth_data_length = _int_to_fixed_length_bigendian_bytes(len(auth_data) * 8, 8)

# prime the hash
self._hmac.update(auth_data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..algorithm import AsymmetricEncryptionAlgorithm
from ..transform import CryptoTransform
from ..._enums import KeyWrapAlgorithm


class _AesKeyWrapTransform(CryptoTransform):
Expand Down Expand Up @@ -59,7 +60,7 @@ class AesKw192(_AesKeyWrap):

class AesKw256(_AesKeyWrap):
_key_size = 256
_name = "A256KW"
_name = KeyWrapAlgorithm.aes_256


AesKw128.register()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc
import sys

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import utils

from ..algorithm import SignatureAlgorithm
from ..transform import SignatureTransform
from ..._enums import SignatureAlgorithm as KeyVaultSignatureAlgorithm

if sys.version_info < (3, 3):
abstractproperty = abc.abstractproperty
else: # abc.abstractproperty is deprecated as of 3.3
import functools

abstractproperty = functools.partial(property, abc.abstractmethod)


class _EcdsaSignatureTransform(SignatureTransform):
Expand All @@ -28,25 +39,33 @@ class _Ecdsa(SignatureAlgorithm):
def create_signature_transform(self, key):
return _EcdsaSignatureTransform(key, self.default_hash_algorithm)

@abstractproperty
def coordinate_length(self):
pass


class Ecdsa256(_Ecdsa):
_name = "ES256K"
_name = KeyVaultSignatureAlgorithm.es256_k
_default_hash_algorithm = hashes.SHA256()
coordinate_length = 32


class Es256(_Ecdsa):
_name = "ES256"
_name = KeyVaultSignatureAlgorithm.es256
_default_hash_algorithm = hashes.SHA256()
coordinate_length = 32


class Es384(_Ecdsa):
_name = "ES384"
_name = KeyVaultSignatureAlgorithm.es384
_default_hash_algorithm = hashes.SHA384()
coordinate_length = 48


class Es512(_Ecdsa):
_name = "ES512"
_name = KeyVaultSignatureAlgorithm.es512
_default_hash_algorithm = hashes.SHA512()
coordinate_length = 66


Ecdsa256.register()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..algorithm import AsymmetricEncryptionAlgorithm
from ..transform import CryptoTransform
from ..._enums import EncryptionAlgorithm


class _Rsa1_5Encryptor(CryptoTransform):
Expand All @@ -20,7 +21,7 @@ def transform(self, data):


class Rsa1_5(AsymmetricEncryptionAlgorithm): # pylint:disable=client-incorrect-naming-convention
_name = "RSA1_5"
_name = EncryptionAlgorithm.rsa1_5

def create_encryptor(self, key):
return _Rsa1_5Encryptor(key)
Expand Down Expand Up @@ -54,7 +55,7 @@ def transform(self, data):


class RsaOaep(AsymmetricEncryptionAlgorithm):
_name = "RSA-OAEP"
_name = EncryptionAlgorithm.rsa_oaep

def create_encryptor(self, key):
return _RsaOaepEncryptor(key, hashes.SHA1)
Expand All @@ -64,7 +65,7 @@ def create_decryptor(self, key):


class RsaOaep256(AsymmetricEncryptionAlgorithm):
_name = "RSA-OAEP-256"
_name = EncryptionAlgorithm.rsa_oaep_256

def create_encryptor(self, key):
return _RsaOaepEncryptor(key, hashes.SHA256)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..algorithm import SignatureAlgorithm
from ..transform import SignatureTransform
from ..._enums import SignatureAlgorithm as KeyVaultSignatureAlgorithm


class RsaSignatureTransform(SignatureTransform):
Expand All @@ -17,7 +18,7 @@ def __init__(self, key, padding_function, hash_algorithm):
self._hash_algorithm = hash_algorithm

def sign(self, digest):
return self._key.sign(digest, self._padding_function(digest), self._hash_algorithm)
return self._key.sign(digest, self._padding_function(digest), utils.Prehashed(self._hash_algorithm))

def verify(self, digest, signature):
self._key.verify(signature, digest, self._padding_function(digest), utils.Prehashed(self._hash_algorithm))
Expand All @@ -37,32 +38,32 @@ def _get_padding(self, digest):


class Ps256(RsaSsaPss):
_name = "PS256"
_name = KeyVaultSignatureAlgorithm.ps256
_default_hash_algorithm = hashes.SHA256()


class Ps384(RsaSsaPss):
_name = "PS384"
_name = KeyVaultSignatureAlgorithm.ps384
_default_hash_algorithm = hashes.SHA384()


class Ps512(RsaSsaPss):
_name = "PS512"
_name = KeyVaultSignatureAlgorithm.ps512
_default_hash_algorithm = hashes.SHA512()


class Rs256(RsaSsaPkcs1v15):
_name = "RS256"
_name = KeyVaultSignatureAlgorithm.rs256
_default_hash_algorithm = hashes.SHA256()


class Rs384(RsaSsaPkcs1v15):
_name = "RS384"
_name = KeyVaultSignatureAlgorithm.rs384
_default_hash_algorithm = hashes.SHA384()


class Rs512(RsaSsaPkcs1v15):
_name = "RS512"
_name = KeyVaultSignatureAlgorithm.rs512
_default_hash_algorithm = hashes.SHA512()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,58 @@
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey,
EllipticCurvePrivateNumbers,
EllipticCurvePublicNumbers,
SECP256R1,
SECP384R1,
SECP521R1,
SECP256K1,
)

from ._internal import _bytes_to_int, encode_key_vault_ecdsa_signature
from ._internal import _bytes_to_int, asn1_der_to_ecdsa, ecdsa_to_asn1_der
from .key import Key
from .algorithms.ecdsa import Es256, Es512, Es384, Ecdsa256
from ... import KeyCurveName

_crypto_crv_to_kv_crv = {"secp256r1": "P-256", "secp384r1": "P-384", "secp521r1": "P-521", "secp256k1": "P-256K"}
_crypto_crv_to_kv_crv = {
"secp256r1": KeyCurveName.p_256,
"secp384r1": KeyCurveName.p_384,
"secp521r1": KeyCurveName.p_521,
"secp256k1": KeyCurveName.p_256_k,
}
_kv_crv_to_crypto_cls = {
"P-256": SECP256R1,
"P-256K": SECP256K1,
"P-384": SECP384R1,
"P-521": SECP521R1,
"SECP256K1": SECP256K1,
KeyCurveName.p_256: SECP256R1,
KeyCurveName.p_256_k: SECP256K1,
KeyCurveName.p_384: SECP384R1,
KeyCurveName.p_521: SECP521R1,
"SECP256K1": SECP256K1, # "SECP256K1" is from Key Vault 2016-10-01
}
_curve_to_default_algo = {
"P-256": Es256.name(),
"P-256K": Ecdsa256.name(),
"P-384": Es384.name(),
"P-521": Es512.name(),
"SECP256K1": Ecdsa256.name(),
KeyCurveName.p_256: Es256.name(),
KeyCurveName.p_256_k: Ecdsa256.name(),
KeyCurveName.p_384: Es384.name(),
KeyCurveName.p_521: Es512.name(),
"SECP256K1": Ecdsa256.name(), # "SECP256K1" is from Key Vault 2016-10-01
}


class EllipticCurveKey(Key):
_supported_signature_algorithms = _curve_to_default_algo.values()
_supported_signature_algorithms = frozenset(_curve_to_default_algo.values())

def __init__(self, x, y, kid=None, curve=None):
def __init__(self, x, y, d=None, kid=None, curve=None):
super(EllipticCurveKey, self).__init__()

self._kid = kid or str(uuid.uuid4())
self._default_algo = _curve_to_default_algo[curve]
curve_cls = _kv_crv_to_crypto_cls[curve]
self._ec_impl = EllipticCurvePublicNumbers(x, y, curve_cls()).public_key(default_backend())

public_numbers = EllipticCurvePublicNumbers(x, y, curve_cls())
self._public_key = public_numbers.public_key(default_backend())
self._private_key = None
if d is not None:
private_numbers = EllipticCurvePrivateNumbers(d, public_numbers)
self._private_key = private_numbers.private_key(default_backend())

@classmethod
def from_jwk(cls, jwk):
Expand All @@ -54,33 +68,39 @@ def from_jwk(cls, jwk):
if not jwk.x or not jwk.y:
raise ValueError("jwk must have values for 'x' and 'y'")

return cls(_bytes_to_int(jwk.x), _bytes_to_int(jwk.y), kid=jwk.kid, curve=jwk.crv)
x = _bytes_to_int(jwk.x)
y = _bytes_to_int(jwk.y)
d = _bytes_to_int(jwk.d) if jwk.d is not None else None
return cls(x, y, d, kid=jwk.kid, curve=jwk.crv)

def is_private_key(self):
return False
return isinstance(self._private_key, EllipticCurvePrivateKey)

def decrypt(self, cipher_text, **kwargs):
raise NotImplementedError()
raise NotImplementedError("Local decryption isn't supported with elliptic curve keys")

def encrypt(self, plain_text, **kwargs):
raise NotImplementedError()
raise NotImplementedError("Local encryption isn't supported with elliptic curve keys")

def wrap_key(self, key, **kwargs):
raise NotImplementedError()
raise NotImplementedError("Local key wrapping isn't supported with elliptic curve keys")

def unwrap_key(self, encrypted_key, **kwargs):
raise NotImplementedError()
raise NotImplementedError("Local key unwrapping isn't supported with elliptic curve keys")

def sign(self, digest, **kwargs):
raise NotImplementedError()
algorithm = self._get_algorithm("sign", **kwargs)
signer = algorithm.create_signature_transform(self._private_key)
signature = signer.sign(digest)
ecdsa_signature = asn1_der_to_ecdsa(signature, algorithm)
return ecdsa_signature

def verify(self, digest, signature, **kwargs):
algorithm = self._get_algorithm("verify", **kwargs)
signer = algorithm.create_signature_transform(self._ec_impl)
dss_signature = encode_key_vault_ecdsa_signature(signature)
signer = algorithm.create_signature_transform(self._public_key)
asn1_signature = ecdsa_to_asn1_der(signature)
try:
# cryptography's verify methods return None, and raise when verification fails
signer.verify(digest, dss_signature)
signer.verify(digest, asn1_signature)
return True
except InvalidSignature:
return False
Loading

0 comments on commit 8d9d55e

Please sign in to comment.