From c0af2ddd92823af2c24fe1af6eeac559e4bf9ed4 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Fri, 28 Jul 2023 22:04:15 +0100 Subject: [PATCH] feat(client): add client close handlers (#157) --- README.md | 4 ++ src/modern_treasury/_base_client.py | 43 +++++++++++++++++ src/modern_treasury/_client.py | 10 ++++ tests/test_client.py | 74 ++++++++++++++++------------- 4 files changed, 99 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 3ca50e24..f7541dbe 100644 --- a/README.md +++ b/README.md @@ -323,6 +323,10 @@ client = ModernTreasury( See the httpx documentation for information about the [`proxies`](https://www.python-httpx.org/advanced/#http-proxying) and [`transport`](https://www.python-httpx.org/advanced/#custom-transports) keyword arguments. +## Advanced: Managing HTTP resources + +By default we will close the underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__) is called but you can also manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting. + # Migration guide This section outlines the features that were deprecated in `v0.5.0`, and subsequently removed in `v0.6.0` and how to migrate your code. diff --git a/src/modern_treasury/_base_client.py b/src/modern_treasury/_base_client.py index 6b4cc414..d14adc8d 100644 --- a/src/modern_treasury/_base_client.py +++ b/src/modern_treasury/_base_client.py @@ -5,6 +5,7 @@ import uuid import inspect import platform +from types import TracebackType from random import random from typing import ( Any, @@ -677,6 +678,27 @@ def __init__( headers={"Accept": "application/json"}, ) + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + self._client.close() + + def __enter__(self: _T) -> _T: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + @overload def request( self, @@ -1009,6 +1031,27 @@ def __init__( headers={"Accept": "application/json"}, ) + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: _T) -> _T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + @overload async def request( self, diff --git a/src/modern_treasury/_client.py b/src/modern_treasury/_client.py index 3f1c5959..6c418903 100644 --- a/src/modern_treasury/_client.py +++ b/src/modern_treasury/_client.py @@ -4,6 +4,7 @@ import os import base64 +import asyncio from typing import Union, Mapping, Optional import httpx @@ -256,6 +257,9 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + def __del__(self) -> None: + self.close() + def ping( self, *, @@ -488,6 +492,12 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.close()) + except Exception: + pass + async def ping( self, *, diff --git a/tests/test_client.py b/tests/test_client.py index 8df02762..2a4cca10 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,6 +4,7 @@ import os import json +import asyncio import inspect from typing import Any, Dict, Union, cast @@ -182,22 +183,6 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: - client = ModernTreasury( - base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" - ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) - assert "Basic" in request.headers.get("Authorization") - - with pytest.raises( - Exception, - match="The api_key client option must be set either by passing api_key to the client or by setting the MODERN_TREASURY_API_KEY environment variable", - ): - client2 = ModernTreasury( - base_url=base_url, api_key=None, _strict_response_validation=True, organization_id="my-organization-ID" - ) - client2._build_request(FinalRequestOptions(method="get", url="/foo")) - def test_default_query_option(self) -> None: client = ModernTreasury( base_url=base_url, @@ -416,6 +401,26 @@ def test_base_url_no_trailing_slash(self) -> None: ) assert request.url == "http://localhost:5000/custom/path/foo" + def test_client_del(self) -> None: + client = ModernTreasury( + base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" + ) + assert not client.is_closed() + + client.__del__() + + assert client.is_closed() + + def test_client_context_manager(self) -> None: + client = ModernTreasury( + base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" + ) + with client as c2: + assert c2 is client + assert not c2.is_closed() + assert not client.is_closed() + assert client.is_closed() + class TestAsyncModernTreasury: client = AsyncModernTreasury( @@ -574,22 +579,6 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: - client = AsyncModernTreasury( - base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" - ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) - assert "Basic" in request.headers.get("Authorization") - - with pytest.raises( - Exception, - match="The api_key client option must be set either by passing api_key to the client or by setting the MODERN_TREASURY_API_KEY environment variable", - ): - client2 = AsyncModernTreasury( - base_url=base_url, api_key=None, _strict_response_validation=True, organization_id="my-organization-ID" - ) - client2._build_request(FinalRequestOptions(method="get", url="/foo")) - def test_default_query_option(self) -> None: client = AsyncModernTreasury( base_url=base_url, @@ -807,3 +796,24 @@ def test_base_url_no_trailing_slash(self) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + + async def test_client_del(self) -> None: + client = AsyncModernTreasury( + base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" + ) + assert not client.is_closed() + + client.__del__() + + await asyncio.sleep(0.2) + assert client.is_closed() + + async def test_client_context_manager(self) -> None: + client = AsyncModernTreasury( + base_url=base_url, api_key=api_key, _strict_response_validation=True, organization_id="my-organization-ID" + ) + async with client as c2: + assert c2 is client + assert not c2.is_closed() + assert not client.is_closed() + assert client.is_closed()