Skip to content

Commit

Permalink
refactor(test): refactor authentication tests (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Oct 12, 2023
1 parent 384203f commit dd134aa
Show file tree
Hide file tree
Showing 44 changed files with 381 additions and 310 deletions.
44 changes: 16 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ The full API of this library can be found in [api.md](https://www.github.com/Mod
from modern_treasury import ModernTreasury

client = ModernTreasury(
# defaults to os.environ.get("MODERN_TREASURY_API_KEY")
api_key="my api key",
# defaults to os.environ.get("MODERN_TREASURY_ORGANIZATION_ID")
organization_id="my-organization-ID",
# defaults to os.environ.get("MODERN_TREASURY_API_KEY")
api_key="My API Key",
)

external_account = client.external_accounts.create(
Expand All @@ -47,8 +48,10 @@ external_account = client.external_accounts.create(
print(external_account.id)
```

While you can provide an `api_key` keyword argument, we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
and adding `MODERN_TREASURY_API_KEY="my api key"` to your `.env` file so that your API Key is not stored in source control.
While you can provide a `organization_id` keyword argument,
we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
to add `MODERN_TREASURY_ORGANIZATION_ID="my-organization-ID"` to your `.env` file
so that your Organization ID is not stored in source control.

## Async usage

Expand All @@ -58,9 +61,10 @@ Simply import `AsyncModernTreasury` instead of `ModernTreasury` and use `await`
from modern_treasury import AsyncModernTreasury

client = AsyncModernTreasury(
# defaults to os.environ.get("MODERN_TREASURY_API_KEY")
api_key="my api key",
# defaults to os.environ.get("MODERN_TREASURY_ORGANIZATION_ID")
organization_id="my-organization-ID",
# defaults to os.environ.get("MODERN_TREASURY_API_KEY")
api_key="My API Key",
)


Expand Down Expand Up @@ -92,9 +96,7 @@ This library provides auto-paginating iterators with each list response, so you
```python
import modern_treasury

client = ModernTreasury(
organization_id="my-organization-ID",
)
client = ModernTreasury()

all_external_accounts = []
# Automatically fetches more pages as needed.
Expand All @@ -110,9 +112,7 @@ Or, asynchronously:
import asyncio
import modern_treasury

client = AsyncModernTreasury(
organization_id="my-organization-ID",
)
client = AsyncModernTreasury()


async def main() -> None:
Expand Down Expand Up @@ -157,9 +157,7 @@ Nested parameters are dictionaries, typed using `TypedDict`, for example:
```python
from modern_treasury import ModernTreasury

client = ModernTreasury(
organization_id="my-organization-ID",
)
client = ModernTreasury()

client.external_accounts.create(
foo={
Expand All @@ -176,9 +174,7 @@ Request parameters that correspond to file uploads can be passed as `bytes` or a
from pathlib import Path
from modern_treasury import ModernTreasury

client = ModernTreasury(
organization_id="my-organization-ID",
)
client = ModernTreasury()

contents = Path("my/file.txt").read_bytes()
client.documents.create(
Expand All @@ -194,9 +190,7 @@ The async client uses the exact same interface. This example uses `aiofiles` to
import aiofiles
from modern_treasury import ModernTreasury

client = ModernTreasury(
organization_id="my-organization-ID",
)
client = ModernTreasury()

async with aiofiles.open("my/file.txt", mode="rb") as f:
contents = await f.read()
Expand All @@ -221,9 +215,7 @@ All errors inherit from `modern_treasury.APIError`.
import modern_treasury
from modern_treasury import ModernTreasury

client = ModernTreasury(
organization_id="my-organization-ID",
)
client = ModernTreasury()

try:
client.external_accounts.create(
Expand Down Expand Up @@ -268,7 +260,6 @@ from modern_treasury import ModernTreasury
client = ModernTreasury(
# default is 2
max_retries=0,
organization_id="my-organization-ID",
)

# Or, configure per-request:
Expand All @@ -287,13 +278,11 @@ from modern_treasury import ModernTreasury
client = ModernTreasury(
# default is 60s
timeout=20.0,
organization_id="my-organization-ID",
)

# More granular control:
client = ModernTreasury(
timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0),
organization_id="my-organization-ID",
)

# Override per-request:
Expand Down Expand Up @@ -338,7 +327,6 @@ client = ModernTreasury(
proxies="http://my.test.proxy.example.com",
transport=httpx.HTTPTransport(local_address="0.0.0.0"),
),
organization_id="my-organization-ID",
)
```

Expand Down
5 changes: 4 additions & 1 deletion src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,10 @@ def close(self) -> None:
The client will *not* be usable after this.
"""
self._client.close()
# If an error is thrown while constructing a client, self._client
# may not be present
if hasattr(self, "_client"):
self._client.close()

def __enter__(self: _T) -> _T:
return self
Expand Down
40 changes: 16 additions & 24 deletions src/modern_treasury/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ class ModernTreasury(SyncAPIClient):
def __init__(
self,
*,
organization_id: str | None = None,
webhook_key: str | None = None,
api_key: str | None = os.environ.get("MODERN_TREASURY_API_KEY", None),
organization_id: str | None = os.environ.get("MODERN_TREASURY_ORGANIZATION_ID", None),
webhook_key: str | None = os.environ.get("MODERN_TREASURY_WEBHOOK_KEY", None),
base_url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -125,23 +125,19 @@ def __init__(
- `organization_id` from `MODERN_TREASURY_ORGANIZATION_ID`
- `webhook_key` from `MODERN_TREASURY_WEBHOOK_KEY`
"""
api_key = api_key or os.environ.get("MODERN_TREASURY_API_KEY", None)
if not api_key:
if api_key is None:
raise ModernTreasuryError(
"The api_key client option must be set either by passing api_key to the client or by setting the MODERN_TREASURY_API_KEY environment variable"
)
self.api_key = api_key

organization_id_envvar = os.environ.get("MODERN_TREASURY_ORGANIZATION_ID", None)
organization_id = organization_id or organization_id_envvar or None
if organization_id is None:
raise ValueError(
raise ModernTreasuryError(
"The organization_id client option must be set either by passing organization_id to the client or by setting the MODERN_TREASURY_ORGANIZATION_ID environment variable"
)
self.organization_id = organization_id

webhook_key_envvar = os.environ.get("MODERN_TREASURY_WEBHOOK_KEY", None)
self.webhook_key = webhook_key or webhook_key_envvar or None
self.webhook_key = webhook_key

if base_url is None:
base_url = f"https://app.moderntreasury.com"
Expand Down Expand Up @@ -208,9 +204,9 @@ def auth_headers(self) -> dict[str, str]:
def copy(
self,
*,
api_key: str | None = None,
organization_id: str | None = None,
webhook_key: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
Expand Down Expand Up @@ -264,10 +260,10 @@ def copy(
http_client = http_client or self._client

return self.__class__(
api_key=api_key or self.api_key,
organization_id=organization_id or self.organization_id,
webhook_key=webhook_key or self.webhook_key,
base_url=base_url or str(self.base_url),
api_key=api_key or self.api_key,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
connection_pool_limits=connection_pool_limits,
Expand Down Expand Up @@ -387,10 +383,10 @@ class AsyncModernTreasury(AsyncAPIClient):
def __init__(
self,
*,
organization_id: str | None = None,
webhook_key: str | None = None,
api_key: str | None = os.environ.get("MODERN_TREASURY_API_KEY", None),
organization_id: str | None = os.environ.get("MODERN_TREASURY_ORGANIZATION_ID", None),
webhook_key: str | None = os.environ.get("MODERN_TREASURY_WEBHOOK_KEY", None),
base_url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -420,23 +416,19 @@ def __init__(
- `organization_id` from `MODERN_TREASURY_ORGANIZATION_ID`
- `webhook_key` from `MODERN_TREASURY_WEBHOOK_KEY`
"""
api_key = api_key or os.environ.get("MODERN_TREASURY_API_KEY", None)
if not api_key:
if api_key is None:
raise ModernTreasuryError(
"The api_key client option must be set either by passing api_key to the client or by setting the MODERN_TREASURY_API_KEY environment variable"
)
self.api_key = api_key

organization_id_envvar = os.environ.get("MODERN_TREASURY_ORGANIZATION_ID", None)
organization_id = organization_id or organization_id_envvar or None
if organization_id is None:
raise ValueError(
raise ModernTreasuryError(
"The organization_id client option must be set either by passing organization_id to the client or by setting the MODERN_TREASURY_ORGANIZATION_ID environment variable"
)
self.organization_id = organization_id

webhook_key_envvar = os.environ.get("MODERN_TREASURY_WEBHOOK_KEY", None)
self.webhook_key = webhook_key or webhook_key_envvar or None
self.webhook_key = webhook_key

if base_url is None:
base_url = f"https://app.moderntreasury.com"
Expand Down Expand Up @@ -503,9 +495,9 @@ def auth_headers(self) -> dict[str, str]:
def copy(
self,
*,
api_key: str | None = None,
organization_id: str | None = None,
webhook_key: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
Expand Down Expand Up @@ -559,10 +551,10 @@ def copy(
http_client = http_client or self._client

return self.__class__(
api_key=api_key or self.api_key,
organization_id=organization_id or self.organization_id,
webhook_key=webhook_key or self.webhook_key,
base_url=base_url or str(self.base_url),
api_key=api_key or self.api_key,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
connection_pool_limits=connection_pool_limits,
Expand Down
3 changes: 3 additions & 0 deletions src/modern_treasury/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from ._utils import extract_type_arg as extract_type_arg
from ._utils import is_required_type as is_required_type
from ._utils import is_annotated_type as is_annotated_type
from ._utils import maybe_coerce_float as maybe_coerce_float
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
from ._utils import maybe_coerce_integer as maybe_coerce_integer
from ._utils import strip_annotated_type as strip_annotated_type
from ._transform import PropertyInfo as PropertyInfo
from ._transform import transform as transform
Expand Down
18 changes: 18 additions & 0 deletions src/modern_treasury/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,24 @@ def coerce_boolean(val: str) -> bool:
return val == "true" or val == "1" or val == "on"


def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)


def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)


def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)


def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Expand Down
11 changes: 6 additions & 5 deletions tests/api_resources/internal_accounts/test_balance_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
from modern_treasury.types.internal_accounts import BalanceReport

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = os.environ.get("API_KEY", "something1234")
api_key = "My API Key"
organization_id = "my-organization-ID"


class TestBalanceReports:
strict_client = ModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=True
)
loose_client = ModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=False, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=False
)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

Expand Down Expand Up @@ -54,10 +55,10 @@ def test_method_list_with_all_params(self, client: ModernTreasury) -> None:

class TestAsyncBalanceReports:
strict_client = AsyncModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=True
)
loose_client = AsyncModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=False, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=False
)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

Expand Down
11 changes: 6 additions & 5 deletions tests/api_resources/invoices/test_line_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from modern_treasury.types.invoices import InvoiceLineItem

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = os.environ.get("API_KEY", "something1234")
api_key = "My API Key"
organization_id = "my-organization-ID"


class TestLineItems:
strict_client = ModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=True
)
loose_client = ModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=False, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=False
)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

Expand Down Expand Up @@ -101,10 +102,10 @@ def test_method_delete(self, client: ModernTreasury) -> None:

class TestAsyncLineItems:
strict_client = AsyncModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=True
)
loose_client = AsyncModernTreasury(
base_url=base_url, api_key=api_key, _strict_response_validation=False, organization_id="my-organization-ID"
base_url=base_url, api_key=api_key, organization_id=organization_id, _strict_response_validation=False
)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

Expand Down
Loading

0 comments on commit dd134aa

Please sign in to comment.