Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(client): simplify cleanup #966

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ typecheck = { chain = [
]}
"typecheck:pyright" = "pyright"
"typecheck:verify-types" = "pyright --verifytypes openai --ignoreexternal"
"typecheck:mypy" = "mypy --enable-incomplete-feature=Unpack ."
"typecheck:mypy" = "mypy ."

[build-system]
requires = ["hatchling"]
Expand Down
7 changes: 0 additions & 7 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,6 @@ def _client(self, value: _httpx.Client) -> None: # type: ignore

http_client = value

@override
def __del__(self) -> None:
try:
super().__del__()
except Exception:
pass


class _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore
...
Expand Down
26 changes: 20 additions & 6 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import uuid
import email
import asyncio
import inspect
import logging
import platform
Expand Down Expand Up @@ -672,9 +673,16 @@ def _idempotency_key(self) -> str:
return f"stainless-python-retry-{uuid.uuid4()}"


class SyncHttpxClientWrapper(httpx.Client):
def __del__(self) -> None:
try:
self.close()
except Exception:
pass


class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
_client: httpx.Client
_has_custom_http_client: bool
_default_stream_cls: type[Stream[Any]] | None = None

def __init__(
Expand Down Expand Up @@ -747,15 +755,14 @@ def __init__(
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self._client = http_client or httpx.Client(
self._client = http_client or SyncHttpxClientWrapper(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
proxies=proxies,
transport=transport,
limits=limits,
)
self._has_custom_http_client = bool(http_client)

def is_closed(self) -> bool:
return self._client.is_closed
Expand Down Expand Up @@ -1135,9 +1142,17 @@ def get_api_list(
return self._request_api_list(model, page, opts)


class AsyncHttpxClientWrapper(httpx.AsyncClient):
def __del__(self) -> None:
try:
# TODO(someday): support non asyncio runtimes here
asyncio.get_running_loop().create_task(self.aclose())
except Exception:
pass


class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
_client: httpx.AsyncClient
_has_custom_http_client: bool
_default_stream_cls: type[AsyncStream[Any]] | None = None

def __init__(
Expand Down Expand Up @@ -1210,15 +1225,14 @@ def __init__(
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self._client = http_client or httpx.AsyncClient(
self._client = http_client or AsyncHttpxClientWrapper(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
proxies=proxies,
transport=transport,
limits=limits,
)
self._has_custom_http_client = bool(http_client)

def is_closed(self) -> bool:
return self._client.is_closed
Expand Down
24 changes: 0 additions & 24 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import os
import asyncio
from typing import Any, Union, Mapping
from typing_extensions import Self, override

Expand Down Expand Up @@ -205,16 +204,6 @@ def copy(
# client.with_options(timeout=10).foo.create(...)
with_options = copy

def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close"):
# this can happen if the '__init__' method raised an error
return

if self._has_custom_http_client:
return

self.close()

@override
def _make_status_error(
self,
Expand Down Expand Up @@ -415,19 +404,6 @@ def copy(
# client.with_options(timeout=10).foo.create(...)
with_options = copy

def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close"):
# this can happen if the '__init__' method raised an error
return

if self._has_custom_http_client:
return

try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass

@override
def _make_status_error(
self,
Expand Down
23 changes: 2 additions & 21 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,24 +591,15 @@ def test_absolute_request_url(self, client: OpenAI) -> None:
)
assert request.url == "https://myapi.com/foo"

def test_client_del(self) -> None:
client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
assert not client.is_closed()

client.__del__()

assert client.is_closed()

def test_copied_client_does_not_close_http(self) -> None:
client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
assert not client.is_closed()

copied = client.copy()
assert copied is not client

copied.__del__()
del copied

assert not copied.is_closed()
assert not client.is_closed()

def test_client_context_manager(self) -> None:
Expand Down Expand Up @@ -1325,26 +1316,16 @@ def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
)
assert request.url == "https://myapi.com/foo"

async def test_client_del(self) -> None:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
assert not client.is_closed()

client.__del__()

await asyncio.sleep(0.2)
assert client.is_closed()

async def test_copied_client_does_not_close_http(self) -> None:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
assert not client.is_closed()

copied = client.copy()
assert copied is not client

copied.__del__()
del copied

await asyncio.sleep(0.2)
assert not copied.is_closed()
assert not client.is_closed()

async def test_client_context_manager(self) -> None:
Expand Down