Skip to content

Commit

Permalink
added jwt bearer flow tests
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebv committed Apr 17, 2024
1 parent 96cf56e commit 29be988
Showing 1 changed file with 201 additions and 1 deletion.
202 changes: 201 additions & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import asyncio
import pathlib
import time

from unittest.mock import AsyncMock
from urllib.parse import parse_qs

import httpx
import jwt
import pytest
import respx
import time_machine

from aiosalesforce.auth import Auth, ClientCredentialsFlow, SoapLogin
from aiosalesforce.auth import Auth, ClientCredentialsFlow, JwtBearerFlow, SoapLogin
from aiosalesforce.client import Salesforce
from aiosalesforce.events import EventBus
from aiosalesforce.exceptions import AuthenticationError
from aiosalesforce.retries import RetryPolicy
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -285,3 +290,198 @@ async def test_invalid_credentials(
match=r"\[invalid_client_id\] client identifier invalid",
):
await auth.get_access_token(pseudo_client)


class TestJwtBearerFlow:
async def test_auth(
self,
config: dict[str, str],
pseudo_client: Salesforce,
httpx_mock_router: respx.MockRouter,
tmp_path: pathlib.Path,
):
rsa_private_key = generate_private_key(public_exponent=65537, key_size=2048)
private_key_path = tmp_path / "private.pem"
with open(private_key_path, "wb") as f:
f.write(
rsa_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)

client_id = "somewhat-secret-client-id"
expected_access_token = "SUPER-SECRET-ACCESS-TOKEN" # noqa: S105

async def side_effect(
request: httpx.Request,
route: respx.Route,
) -> httpx.Response:
data = parse_qs(request.content.decode("utf-8"))
assert data["grant_type"] == ["urn:ietf:params:oauth:grant-type:jwt-bearer"]
assertion = data["assertion"][0]
payload = jwt.decode(
assertion,
rsa_private_key.public_key(),
algorithms=["RS256"],
verify=True,
audience="https://login.salesforce.com",
issuer=client_id,
)
assert payload["sub"] == config["username"]
return httpx.Response(
status_code=200,
json={
"access_token": expected_access_token,
"scope": "full",
"instance_url": "https://example.salesforce.com",
"id": (
"https://login.salesforce.com/id"
"/00Dxx0000000000AAA/005xx0000000xxxAAA"
),
"token_type": "Bearer",
},
)

httpx_mock_router.post(f"{config['base_url']}/services/oauth2/token").mock(
side_effect=side_effect
)

event_hook = AsyncMock()
pseudo_client.event_bus.subscribe_callback(event_hook)
auth = JwtBearerFlow(
client_id=client_id,
username=config["username"],
private_key_file=private_key_path,
)
access_token = await auth.get_access_token(pseudo_client)
assert access_token == expected_access_token
assert not auth.expired
with time_machine.travel(
time.time() + 1e9,
tick=False,
):
# Access token for the JWT Bearer Flow never expires if timeout is not set
assert not auth.expired
assert event_hook.await_count == 3

async def test_expiration(
self,
config: dict[str, str],
pseudo_client: Salesforce,
httpx_mock_router: respx.MockRouter,
tmp_path: pathlib.Path,
):
rsa_private_key = generate_private_key(public_exponent=65537, key_size=2048)
private_key_path = tmp_path / "private.pem"
with open(private_key_path, "wb") as f:
f.write(
rsa_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)

client_id = "somewhat-secret-client-id"
expected_access_token = "SUPER-SECRET-ACCESS-TOKEN" # noqa: S105

async def side_effect(
request: httpx.Request,
route: respx.Route,
) -> httpx.Response:
data = parse_qs(request.content.decode("utf-8"))
assert data["grant_type"] == ["urn:ietf:params:oauth:grant-type:jwt-bearer"]
assertion = data["assertion"][0]
payload = jwt.decode(
assertion,
rsa_private_key.public_key(),
algorithms=["RS256"],
verify=True,
audience="https://login.salesforce.com",
issuer=client_id,
)
assert payload["sub"] == config["username"]
return httpx.Response(
status_code=200,
json={
"access_token": expected_access_token,
"scope": "full",
"instance_url": "https://example.salesforce.com",
"id": (
"https://login.salesforce.com/id"
"/00Dxx0000000000AAA/005xx0000000xxxAAA"
),
"token_type": "Bearer",
},
)

httpx_mock_router.post(f"{config['base_url']}/services/oauth2/token").mock(
side_effect=side_effect
)

event_hook = AsyncMock()
pseudo_client.event_bus.subscribe_callback(event_hook)
auth = JwtBearerFlow(
client_id=client_id,
username=config["username"],
private_key_file=private_key_path,
timeout=15 * 60, # 15 minutes
)
access_token = await auth.get_access_token(pseudo_client)
assert event_hook.await_count == 3
assert access_token == expected_access_token
assert auth._expiration_time is not None
assert auth._expiration_time > time.time()
assert not auth.expired
with time_machine.travel(
time.time() + 1e9,
tick=False,
):
assert auth.expired
access_token = await auth.get_access_token(pseudo_client)
assert access_token == expected_access_token
assert not auth.expired
assert event_hook.await_count == 6

async def test_invalid_credentials(
self,
config: dict[str, str],
pseudo_client: Salesforce,
httpx_mock_router: respx.MockRouter,
tmp_path: pathlib.Path,
):
rsa_private_key = generate_private_key(public_exponent=65537, key_size=2048)
private_key_path = tmp_path / "private.pem"
with open(private_key_path, "wb") as f:
f.write(
rsa_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)

httpx_mock_router.post(
f"{config['base_url']}/services/oauth2/token",
).mock(
httpx.Response(
status_code=401,
json={
"error": "invalid_client_id",
"error_description": "client identifier invalid",
},
)
)

auth = JwtBearerFlow(
client_id="somewhat-secret-client-id",
username=config["username"],
private_key_file=private_key_path,
)
with pytest.raises(
AuthenticationError,
match=r"\[invalid_client_id\] client identifier invalid",
):
await auth.get_access_token(pseudo_client)

0 comments on commit 29be988

Please sign in to comment.