-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecrypt_and_verify.py
83 lines (69 loc) · 2.96 KB
/
decrypt_and_verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
from base64 import b64decode
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
class JWEJWSDecryptor:
    """Verify a signed token and decrypt the JWE carried inside it.

    The expected token layout is ``<header>.<JWE>.<signature>``, where
    ``<JWE>`` is itself five dot-separated segments:
    ``header.encrypted_key.iv.ciphertext.auth_tag``.

    * The signature is checked with RSA PKCS#1 v1.5 / SHA-256 over the
      text preceding the last dot.
    * The content-encryption key (CEK) is unwrapped with RSA-OAEP
      (SHA-256, MGF1-SHA-256, no label).
    * The payload is decrypted with AES-GCM using the decoded IV and
      authentication tag; no additional authenticated data is supplied.

    Every Base64 segment may use either the standard or the URL-safe
    alphabet, with or without ``=`` padding.
    """

    # Maps the URL-safe Base64 alphabet onto the standard one so a single
    # ``b64decode`` call can handle both encodings.
    _URLSAFE_TO_STANDARD = str.maketrans("-_", "+/")

    def __init__(self, public_key: rsa.RSAPublicKey, private_key: rsa.RSAPrivateKey):
        """Store the keys used for verification and decryption.

        Args:
            public_key: Sender's public key (for signature verification).
            private_key: Receiver's private key (for decryption).
        """
        self.public_key = public_key
        self.private_key = private_key

    @classmethod
    def _b64decode(cls, segment: str) -> bytes:
        """Decode one Base64 segment, tolerating base64url and missing padding.

        Plain ``b64decode`` silently discards ``-`` and ``_`` and rejects
        input whose padding was stripped, so URL-safe segments would be
        corrupted or refused.  Translating the alphabet and restoring the
        padding first avoids both problems; valid standard Base64 input
        decodes exactly as before.
        """
        normalized = segment.translate(cls._URLSAFE_TO_STANDARD)
        normalized += "=" * (-len(normalized) % 4)
        return b64decode(normalized)

    def verify_and_decrypt(self, token: str) -> str:
        """Verify the token's signature, then decrypt the JWE it wraps.

        Args:
            token: ``<header>.<JWE>.<signature>`` string.

        Returns:
            The decrypted payload decoded as text.

        Raises:
            ValueError: If the token is malformed, the signature does not
                verify, or decryption of the inner JWE fails.
        """
        # 1. Split off the signature at the LAST dot; the inner JWE contains
        #    dots of its own, so the header ends at the FIRST dot.
        signing_input, last_sep, signature_b64 = token.rpartition(".")
        _, first_sep, jws_payload = signing_input.partition(".")
        if not last_sep or not first_sep:
            raise ValueError(
                "Malformed token: expected '<header>.<JWE>.<signature>'"
            )

        # 2. Verify signature using sender's public key. The signed bytes are
        #    the raw "<header>.<JWE>" text, exactly as received.
        signature = self._b64decode(signature_b64)
        try:
            self.public_key.verify(
                signature,
                signing_input.encode(),
                asymmetric_padding.PKCS1v15(),
                hashes.SHA256(),
            )
        except Exception as exc:
            # Kept broad so callers continue to see ValueError for any
            # verification failure; the cause is chained for debugging.
            raise ValueError("Signature verification failed!") from exc

        # 3. Signature is valid: decrypt the JWE (which is the JWS payload).
        return self.decrypt_jwe(jws_payload)

    def decrypt_jwe(self, jwe: str) -> str:
        """Decrypt a five-segment JWE string and return its plaintext.

        Args:
            jwe: ``header.encrypted_key.iv.ciphertext.auth_tag``, each
                segment Base64-encoded.

        Returns:
            The decrypted payload decoded as text.

        Raises:
            ValueError: If the JWE does not have five segments, a segment
                cannot be decoded, the key cannot be unwrapped, or the
                authentication tag check / text decoding fails.
        """
        # 1. Split JWE into components.
        segments = jwe.split(".")
        if len(segments) != 5:
            raise ValueError(
                "Malformed JWE: expected 5 dot-separated segments, "
                f"got {len(segments)}"
            )
        header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, auth_tag_b64 = segments

        # 2. Decode all components. The header is parsed only to reject a
        #    non-JSON header; its contents are not otherwise used here.
        json.loads(self._b64decode(header_b64))
        encrypted_key = self._b64decode(encrypted_key_b64)
        iv = self._b64decode(iv_b64)
        ciphertext = self._b64decode(ciphertext_b64)
        auth_tag = self._b64decode(auth_tag_b64)

        # 3. Decrypt the CEK using receiver's private key.
        cek = self.private_key.decrypt(
            encrypted_key,
            asymmetric_padding.OAEP(
                mgf=asymmetric_padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )

        # 4. Set up AES-GCM with the CEK; the tag is checked on finalize().
        cipher = Cipher(algorithms.AES(cek), modes.GCM(iv, auth_tag))
        decryptor = cipher.decryptor()

        # 5. Decrypt and verify authentication tag.
        try:
            plaintext = decryptor.update(ciphertext) + decryptor.finalize()
            return plaintext.decode()
        except Exception as exc:
            # Kept broad so callers continue to see ValueError for both a
            # failed tag check and undecodable plaintext; cause is chained.
            raise ValueError(
                "Decryption failed! Message may have been tampered with."
            ) from exc