Skip to content

Commit

Permalink
refactored jwt bearer flow
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebv committed Apr 16, 2024
1 parent e2374da commit 8913bea
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/aiosalesforce/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__all__ = [
"Auth",
"ClientCredentialsFlow",
"JwtFlow",
"JwtBearerFlow",
"SoapLogin",
]

from .base import Auth
from .client_credentials_flow import ClientCredentialsFlow
from .jwt_flow import JwtFlow
from .jwt_bearer_flow import JwtBearerFlow
from .soap import SoapLogin
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import pathlib
import time
from html import unescape

from typing import TYPE_CHECKING

import jwt
try:
import jwt

from cryptography.hazmat.primitives import serialization
except ImportError: # pragma: no cover
jwt = None # type: ignore
serialization = None # type: ignore

from aiosalesforce.auth.base import Auth
from aiosalesforce.events import RequestEvent, ResponseEvent
Expand All @@ -13,76 +20,81 @@
from aiosalesforce.client import Salesforce


class JwtFlow(Auth):
class JwtBearerFlow(Auth):
"""
Authenticate using JWT
Authenticate using the OAuth 2.0 JWT Bearer Flow.
https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_jwt_flow.htm&type=5
Parameters
----------
client_id : str
Client ID.
subject: str
Username to authenticate with.
priv_key_file: str
Private key file.
priv_key_passphrase: str | None
Passphrase for private key file, if required
username: str
Username.
private_key_file: str | pathlib.Path
Path to private key file.
private_key_passphrase: str, optional
Passphrase for private key file.
By default assumed to be unencrypted.
timeout : float, optional
Timeout for the access token in seconds.
By default assumed to never expire.
"""

def __init__(
self,
client_id: str,
subject: str,
priv_key_file: str,
priv_key_passphrase: str | None = None,
username: str,
private_key_file: str | pathlib.Path,
private_key_passphrase: str | None = None,
timeout: float | None = None,
) -> None:
super().__init__()
self.client_id = client_id
self.subject = subject
self.priv_key_file = priv_key_file
self.priv_key_passphrase = priv_key_passphrase
self.username = username
self.private_key_file = private_key_file
self.private_key_passphrase = private_key_passphrase
self.timeout = timeout

self._expiration_time: float | None = None

if jwt is None or serialization is None: # pragma: no cover
raise ImportError("Install aiosalesforce[jwt] to use the JwtBearerFlow.")

async def _acquire_new_access_token(self, client: "Salesforce") -> str:
expiration = int(time.time()) + 300
sandbox = "sandbox." in client.base_url.lower()
payload = {
"iss": self.client_id,
"sub": unescape(self.subject),
"aud": f"https://{'test' if sandbox else 'login'}.salesforce.com",
"exp": f"{expiration:.0f}",
"aud": "https://sandbox.salesforce.com"
if "sandbox" in client.base_url.lower()
else "https://login.salesforce.com",
"sub": self.username,
"exp": int(time.time()) + 300,
}
with open(self.priv_key_file, "rb") as file:
priv_key = file.read()

if self.priv_key_passphrase:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

passphrase = self.priv_key_passphrase.encode("utf-8")
with open(self.private_key_file, "rb") as file:
private_key = serialization.load_pem_private_key(
priv_key, password=passphrase, backend=default_backend()
data=file.read(),
password=self.private_key_passphrase.encode("utf-8")
if self.private_key_passphrase is not None
else None,
)
else:
private_key = priv_key

assertion = jwt.encode(
payload, private_key, algorithm="RS256", headers={"alg": "RS256"}
payload,
private_key, # type: ignore
algorithm="RS256",
headers={"alg": "RS256"},
)
token_data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": assertion,
}

request = client.httpx_client.build_request(
"POST",
f"{client.base_url}/services/oauth2/token",
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
data=token_data,
data={
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": assertion,
},
)
await client.event_bus.publish_event(
RequestEvent(
Expand Down Expand Up @@ -117,4 +129,13 @@ async def _acquire_new_access_token(self, client: "Salesforce") -> str:
response=response,
)
)
if self.timeout is not None:
self._expiration_time = time.time() + self.timeout
return json_loads(response.content)["access_token"]

@property
def expired(self) -> bool:
super().expired
if self._expiration_time is None: # pragma: no cover
return False
return self._expiration_time <= time.time()

0 comments on commit 8913bea

Please sign in to comment.