Skip to content

Commit

Permalink
fix: avoid leaking memory when Client.with_options is used (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Jan 2, 2024
1 parent f80056b commit abf6406
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 17 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ select = [
"T203",
]
ignore = [
# lru_cache in methods, will be fixed separately
"B019",
# mutable defaults
"B006",
]
Expand Down
28 changes: 15 additions & 13 deletions src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,12 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)

# headers are case-insensitive while dictionaries are not.
headers = httpx.Headers(headers_dict)

idempotency_header = self._idempotency_header
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
if not options.idempotency_key:
options.idempotency_key = self._idempotency_key()

headers[idempotency_header] = options.idempotency_key
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

return headers

Expand Down Expand Up @@ -594,16 +592,8 @@ def base_url(self) -> URL:
def base_url(self, url: URL | str) -> None:
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))

@lru_cache(maxsize=None)
def platform_headers(self) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": self._version,
"X-Stainless-OS": str(get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": platform.python_implementation(),
"X-Stainless-Runtime-Version": platform.python_version(),
}
return platform_headers(self._version)

def _calculate_retry_timeout(
self,
Expand Down Expand Up @@ -1691,6 +1681,18 @@ def get_platform() -> Platform:
return "Unknown"


@lru_cache(maxsize=None)
def platform_headers(version: str) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": version,
"X-Stainless-OS": str(get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": platform.python_implementation(),
"X-Stainless-Runtime-Version": platform.python_version(),
}


class OtherArch:
def __init__(self, name: str) -> None:
self.name = name
Expand Down
4 changes: 2 additions & 2 deletions src/modern_treasury/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def copy(
api_key=api_key or self.api_key,
organization_id=organization_id or self.organization_id,
webhook_key=webhook_key or self.webhook_key,
base_url=base_url or str(self.base_url),
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
connection_pool_limits=connection_pool_limits,
Expand Down Expand Up @@ -609,7 +609,7 @@ def copy(
api_key=api_key or self.api_key,
organization_id=organization_id or self.organization_id,
webhook_key=webhook_key or self.webhook_key,
base_url=base_url or str(self.base_url),
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
connection_pool_limits=connection_pool_limits,
Expand Down
124 changes: 124 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import gc
import os
import json
import asyncio
import inspect
import tracemalloc
from typing import Any, Union, cast
from unittest import mock

Expand Down Expand Up @@ -213,6 +215,67 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"

def test_copy_build_request(self) -> None:
options = FinalRequestOptions(method="get", url="/foo")

def build_request(options: FinalRequestOptions) -> None:
client = self.client.copy()
client._build_request(options)

# ensure that the machinery is warmed up before tracing starts.
build_request(options)
gc.collect()

tracemalloc.start(1000)

snapshot_before = tracemalloc.take_snapshot()

ITERATIONS = 10
for _ in range(ITERATIONS):
build_request(options)
gc.collect()

snapshot_after = tracemalloc.take_snapshot()

tracemalloc.stop()

def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
if diff.count == 0:
# Avoid false positives by considering only leaks (i.e. allocations that persist).
return

if diff.count % ITERATIONS != 0:
# Avoid false positives by considering only leaks that appear per iteration.
return

for frame in diff.traceback:
if any(
frame.filename.endswith(fragment)
for fragment in [
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
#
# removing the decorator fixes the leak for reasons we don't understand.
"modern_treasury/_response.py",
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
"modern_treasury/_compat.py",
# Standard library leaks we don't care about.
"/logging/__init__.py",
]
):
return

leaks.append(diff)

leaks: list[tracemalloc.StatisticDiff] = []
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
add_leak(leaks, diff)
if leaks:
for leak in leaks:
print("MEMORY LEAK:", leak)
for frame in leak.traceback:
print(frame)
raise AssertionError()

def test_request_timeout(self) -> None:
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down Expand Up @@ -1061,6 +1124,67 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"

def test_copy_build_request(self) -> None:
options = FinalRequestOptions(method="get", url="/foo")

def build_request(options: FinalRequestOptions) -> None:
client = self.client.copy()
client._build_request(options)

# ensure that the machinery is warmed up before tracing starts.
build_request(options)
gc.collect()

tracemalloc.start(1000)

snapshot_before = tracemalloc.take_snapshot()

ITERATIONS = 10
for _ in range(ITERATIONS):
build_request(options)
gc.collect()

snapshot_after = tracemalloc.take_snapshot()

tracemalloc.stop()

def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
if diff.count == 0:
# Avoid false positives by considering only leaks (i.e. allocations that persist).
return

if diff.count % ITERATIONS != 0:
# Avoid false positives by considering only leaks that appear per iteration.
return

for frame in diff.traceback:
if any(
frame.filename.endswith(fragment)
for fragment in [
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
#
# removing the decorator fixes the leak for reasons we don't understand.
"modern_treasury/_response.py",
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
"modern_treasury/_compat.py",
# Standard library leaks we don't care about.
"/logging/__init__.py",
]
):
return

leaks.append(diff)

leaks: list[tracemalloc.StatisticDiff] = []
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
add_leak(leaks, diff)
if leaks:
for leak in leaks:
print("MEMORY LEAK:", leak)
for frame in leak.traceback:
print(frame)
raise AssertionError()

async def test_request_timeout(self) -> None:
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down

0 comments on commit abf6406

Please sign in to comment.