Skip to content

Commit

Permalink
Enable non-encrypted certificates (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
olehb007 authored Jun 22, 2023
1 parent c20c0f0 commit 95089cd
Showing 1 changed file with 66 additions and 23 deletions.
89 changes: 66 additions & 23 deletions requests_pkcs12.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,18 @@ def create_sslcontext(pkcs12_data, pkcs12_password_bytes, ssl_protocol=default_s
ssl_context = ssl.SSLContext(ssl_protocol)
with tempfile.NamedTemporaryFile(delete=False) as c:
try:
pk_buf = private_key.private_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL,
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(password=pkcs12_password_bytes)
)
if pkcs12_password_bytes is not None:
pk_buf = private_key.private_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL,
cryptography.hazmat.primitives.serialization.BestAvailableEncryption(password=pkcs12_password_bytes)
)
else:
pk_buf = private_key.private_bytes(
cryptography.hazmat.primitives.serialization.Encoding.PEM,
cryptography.hazmat.primitives.serialization.PrivateFormat.TraditionalOpenSSL,
cryptography.hazmat.primitives.serialization.NoEncryption()
)
c.write(pk_buf)
buf = cert.public_bytes(cryptography.hazmat.primitives.serialization.Encoding.PEM)
c.write(buf)
Expand All @@ -77,15 +84,18 @@ def __init__(self, *args, **kwargs):
raise ValueError('Both arguments "pkcs12_data" and "pkcs12_filename" are missing')
if pkcs12_data is not None and pkcs12_filename is not None:
raise ValueError('Argument "pkcs12_data" conflicts with "pkcs12_filename"')
if pkcs12_password is None:
raise ValueError('Argument "pkcs12_password" is missing')
if pkcs12_filename is not None:
with open(pkcs12_filename, 'rb') as pkcs12_file:
pkcs12_data = pkcs12_file.read()
if isinstance(pkcs12_password, bytes):
if pkcs12_password is None:
pkcs12_password_bytes = None
elif isinstance(pkcs12_password, bytes):
pkcs12_password_bytes = pkcs12_password
else:
elif isinstance(pkcs12_password, str):
pkcs12_password_bytes = pkcs12_password.encode('utf8')
else:
raise TypeError('Password must be a None, string or bytes.')

self.ssl_context = create_sslcontext(pkcs12_data, pkcs12_password_bytes, ssl_protocol)
super(Pkcs12Adapter, self).__init__(*args, **kwargs)

Expand Down Expand Up @@ -160,6 +170,31 @@ def post(*args, **kwargs):
def put(*args, **kwargs):
return request('put', *args, **kwargs)

def execute_test_case(test_case_name, test_case, key, cert):
print(f"Testing {test_case_name}")
password = test_case['pkcs12_password']
try:
algorithm = cryptography.hazmat.primitives.serialization.BestAvailableEncryption(password) \
if test_case['pkcs12_password'] is not None else cryptography.hazmat.primitives.serialization.NoEncryption()
pkcs12_data = cryptography.hazmat.primitives.serialization.pkcs12.serialize_key_and_certificates(
name=b'test',
key=key,
cert=cert,
cas=[cert, cert, cert],
encryption_algorithm=algorithm
)
response = get(
'https://example.com/',
pkcs12_data=pkcs12_data,
pkcs12_password=test_case['pkcs12_password']
)
if response.status_code != test_case['expected_status_code']:
raise Exception('Unexpected response: {response!r}'.format(**locals()))
except ValueError as e:
if test_case['expected_exception_message'] is None or str(e) != test_case['expected_exception_message']:
raise(e)


def selftest():
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=4096)
cert = cryptography.x509.CertificateBuilder().subject_name(
Expand All @@ -183,20 +218,28 @@ def selftest():
cryptography.hazmat.primitives.hashes.SHA512(),
cryptography.hazmat.backends.default_backend()
)
pkcs12_data = cryptography.hazmat.primitives.serialization.pkcs12.serialize_key_and_certificates(
name=b'test',
key=key,
cert=cert,
cas=[cert, cert, cert],
encryption_algorithm=cryptography.hazmat.primitives.serialization.BestAvailableEncryption(b'correcthorsebatterystaple')
)
response = get(
'https://example.com/',
pkcs12_data=pkcs12_data,
pkcs12_password='correcthorsebatterystaple'
)
if response.status_code != 200:
raise Exception('Unexpected response: {response!r}'.format(**locals()))

test_cases = {
"withEncryption": {
"pkcs12_password": b"correcthorsebatterystaple",
"expected_status_code": 200,
"expected_exception_message": None,
},
"withEmptyPassword": {
"pkcs12_password": b"",
"expected_status_code": 200,
"expected_exception_message": "Password must be 1 or more bytes.",
},
"withoutEncryption": {
"pkcs12_password": None,
"expected_status_code": 200,
"expected_exception_message": None,
},
}

for test_case_name, test_case in test_cases.items():
execute_test_case(test_case_name, test_case, key, cert)

print('Selftest succeeded.')

if __name__ == '__main__':
Expand Down

0 comments on commit 95089cd

Please sign in to comment.