diff --git a/src/openai/_models.py b/src/openai/_models.py index d6f42d3d4d..710401defd 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -2,7 +2,7 @@ import os import inspect -from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( Unpack, @@ -10,6 +10,7 @@ ClassVar, Protocol, Required, + Sequence, ParamSpec, TypedDict, TypeGuard, @@ -72,6 +73,8 @@ P = ParamSpec("P") +ReprArgs = Sequence[Tuple[Optional[str], Any]] + @runtime_checkable class _ConfigProtocol(Protocol): @@ -94,6 +97,11 @@ def model_fields_set(self) -> set[str]: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore + @override + def __repr_args__(self) -> ReprArgs: + # we don't want these attributes to be included when something like `rich.print` is used + return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}] + if TYPE_CHECKING: _request_id: Optional[str] = None """The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI. diff --git a/tests/lib/chat/_utils.py b/tests/lib/chat/_utils.py index dcc32b17fd..af08db417c 100644 --- a/tests/lib/chat/_utils.py +++ b/tests/lib/chat/_utils.py @@ -1,14 +1,14 @@ from __future__ import annotations -import io import inspect from typing import Any, Iterable from typing_extensions import TypeAlias -import rich import pytest import pydantic +from ...utils import rich_print_str + ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]" @@ -26,12 +26,7 @@ def __repr_args__(self: pydantic.BaseModel) -> ReprArgs: with monkeypatch.context() as m: m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__) - buf = io.StringIO() - - console = rich.console.Console(file=buf, width=120) - console.print(obj) - - string = buf.getvalue() + string = rich_print_str(obj) # we remove all `fn_name..` occurences # so that we can share the same snapshots between diff --git a/tests/test_legacy_response.py b/tests/test_legacy_response.py index a6fec9f2de..f50a77c24d 100644 --- a/tests/test_legacy_response.py +++ b/tests/test_legacy_response.py @@ -11,6 +11,8 @@ from openai._base_client import FinalRequestOptions from openai._legacy_response import LegacyAPIResponse +from .utils import rich_print_str + class PydanticModel(pydantic.BaseModel): ... @@ -85,6 +87,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None: assert obj.foo == "hello!" assert obj.bar == 2 assert obj.to_dict() == {"foo": "hello!", "bar": 2} + assert "_request_id" not in rich_print_str(obj) + assert "__exclude_fields__" not in rich_print_str(obj) def test_response_parse_annotated_type(client: OpenAI) -> None: diff --git a/tests/test_response.py b/tests/test_response.py index 97c56e0035..e1fe332f2f 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -18,6 +18,8 @@ from openai._streaming import Stream from openai._base_client import FinalRequestOptions +from .utils import rich_print_str + class ConcreteBaseAPIResponse(APIResponse[bytes]): ... @@ -175,6 +177,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None: assert obj.foo == "hello!" assert obj.bar == 2 assert obj.to_dict() == {"foo": "hello!", "bar": 2} + assert "_request_id" not in rich_print_str(obj) + assert "__exclude_fields__" not in rich_print_str(obj) @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index 165f4e5bfd..8d5397f28e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import os import inspect import traceback @@ -8,6 +9,8 @@ from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type +import rich + from openai._types import Omit, NoneType from openai._utils import ( is_dict, @@ -138,6 +141,16 @@ def _assert_list_type(type_: type[object], value: object) -> None: assert_type(inner_type, entry) # type: ignore +def rich_print_str(obj: object) -> str: + """Like `rich.print()` but returns the string instead""" + buf = io.StringIO() + + console = rich.console.Console(file=buf, width=120) + console.print(obj) + + return buf.getvalue() + + @contextlib.contextmanager def update_env(**new_env: str | Omit) -> Iterator[None]: old = os.environ.copy()