orjson optional wip #1223

Merged · Nov 19, 2024 · 11 commits (changes shown from 10 commits)
23 changes: 17 additions & 6 deletions python/langsmith/_internal/_background_thread.py
@@ -1,4 +1,4 @@
from __future__ import annotations

Check notice on line 1 in python/langsmith/_internal/_background_thread.py (GitHub Actions / benchmark)

Benchmark results:

create_5_000_run_trees: Mean +- std dev: 545 ms +- 36 ms
create_10_000_run_trees: Mean +- std dev: 1.05 sec +- 0.04 sec
create_20_000_run_trees: Mean +- std dev: 1.05 sec +- 0.04 sec
dumps_class_nested_py_branch_and_leaf_200x400: Mean +- std dev: 703 us +- 10 us
dumps_class_nested_py_leaf_50x100: Mean +- std dev: 25.0 ms +- 0.3 ms
dumps_class_nested_py_leaf_100x200: Mean +- std dev: 104 ms +- 2 ms
dumps_dataclass_nested_50x100: Mean +- std dev: 25.3 ms +- 0.3 ms
dumps_pydantic_nested_50x100: Mean +- std dev: 66.3 ms +- 16.3 ms
  WARNING: the result may be unstable; the standard deviation (16.3 ms) is 25% of the mean (66.3 ms)
dumps_pydanticv1_nested_50x100: Mean +- std dev: 222 ms +- 31 ms
  WARNING: the result may be unstable; the standard deviation (30.8 ms) is 14% of the mean (222 ms)

Check notice on line 1 in python/langsmith/_internal/_background_thread.py (GitHub Actions / benchmark)

Comparison against main:

+-----------------------------------------------+----------+------------------------+
| Benchmark                                     | main     | changes                |
+===============================================+==========+========================+
| create_10_000_run_trees                       | 1.43 sec | 1.05 sec: 1.36x faster |
+-----------------------------------------------+----------+------------------------+
| create_20_000_run_trees                       | 1.40 sec | 1.05 sec: 1.33x faster |
+-----------------------------------------------+----------+------------------------+
| create_5_000_run_trees                        | 723 ms   | 545 ms: 1.33x faster   |
+-----------------------------------------------+----------+------------------------+
| dumps_pydantic_nested_50x100                  | 71.3 ms  | 66.3 ms: 1.07x faster  |
+-----------------------------------------------+----------+------------------------+
| dumps_dataclass_nested_50x100                 | 25.7 ms  | 25.3 ms: 1.02x faster  |
+-----------------------------------------------+----------+------------------------+
| dumps_class_nested_py_leaf_50x100             | 25.2 ms  | 25.0 ms: 1.01x faster  |
+-----------------------------------------------+----------+------------------------+
| dumps_class_nested_py_branch_and_leaf_200x400 | 700 us   | 703 us: 1.00x slower   |
+-----------------------------------------------+----------+------------------------+
| Geometric mean                                | (ref)    | 1.12x faster           |
+-----------------------------------------------+----------+------------------------+

Benchmark hidden because not significant (2): dumps_pydanticv1_nested_50x100, dumps_class_nested_py_leaf_100x200

import functools
import logging
@@ -155,13 +155,24 @@
     # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
     num_known_refs = 3
 
+    def keep_thread_active() -> bool:
+        # if `client.cleanup()` was called, stop the thread
+        if client and client._manual_cleanup:
+            return False
+        if not threading.main_thread().is_alive():
+            # main thread is dead; this thread should not stay active
+            return False
+        try:
+            # check if the client's refcount indicates we're the only
+            # remaining reference to it
+            return sys.getrefcount(client) > num_known_refs + len(sub_threads)
+        except AttributeError:
+            # in PyPy, there is no sys.getrefcount attribute;
+            # for now, keep the thread alive
+            return True
+
-    # loop until
-    while (
-        # the main thread dies
-        threading.main_thread().is_alive()
-        # or we're the only remaining reference to the client
-        and sys.getrefcount(client) > num_known_refs + len(sub_threads)
-    ):
+    while keep_thread_active():
         for thread in sub_threads:
             if not thread.is_alive():
                 sub_threads.remove(thread)
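
The loop's liveness check combines three signals: the new manual-cleanup flag, the main thread's status, and a refcount heuristic. A minimal standalone sketch of the same logic (FakeClient and worker are hypothetical names, not the library's internals; the real loop drains the tracing queue instead of sleeping):

import sys
import threading
import time

class FakeClient:
    """Stand-in for langsmith.Client; only the cleanup flag matters here."""

    def __init__(self) -> None:
        self._manual_cleanup = False

def worker(client: FakeClient, sub_threads: list) -> None:
    num_known_refs = 3  # mirrors the constant in the diff above

    def keep_thread_active() -> bool:
        if client._manual_cleanup:
            return False
        if not threading.main_thread().is_alive():
            return False
        try:
            # more references than the bookkeeping ones -> user code
            # still holds the client, so keep serving it
            return sys.getrefcount(client) > num_known_refs + len(sub_threads)
        except AttributeError:
            # PyPy has no sys.getrefcount; err on the side of staying alive
            return True

    while keep_thread_active():
        time.sleep(0.05)

client = FakeClient()
t = threading.Thread(target=worker, args=(client, []), daemon=True)
t.start()
client._manual_cleanup = True  # what Client.cleanup() does later in this PR
t.join(timeout=1)
assert not t.is_alive()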
9 changes: 4 additions & 5 deletions python/langsmith/_internal/_operations.py
@@ -5,9 +5,8 @@
 import uuid
 from typing import Literal, Optional, Union, cast
 
-import orjson
-
 from langsmith import schemas as ls_schemas
+from langsmith._internal import _orjson
 from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
 from langsmith._internal._serde import dumps_json as _dumps_json
 
@@ -169,12 +168,12 @@ def combine_serialized_queue_operations(
             if op._none is not None and op._none != create_op._none:
                 # TODO optimize this more - this would currently be slowest
                 # for large payloads
-                create_op_dict = orjson.loads(create_op._none)
+                create_op_dict = _orjson.loads(create_op._none)
                 op_dict = {
-                    k: v for k, v in orjson.loads(op._none).items() if v is not None
+                    k: v for k, v in _orjson.loads(op._none).items() if v is not None
                 }
                 create_op_dict.update(op_dict)
-                create_op._none = orjson.dumps(create_op_dict)
+                create_op._none = _orjson.dumps(create_op_dict)
 
             if op.inputs is not None:
                 create_op.inputs = op.inputs
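
The merge above layers a patch operation's non-None fields over the original create payload. A small illustrative sketch of the same three lines (made-up payloads; _orjson is the shim introduced in the next file):

from langsmith._internal import _orjson

create_none = _orjson.dumps({"name": "my-run", "extra": None, "tags": ["a"]})
patch_none = _orjson.dumps({"extra": {"tokens": 12}, "tags": None})

create_op_dict = _orjson.loads(create_none)
op_dict = {k: v for k, v in _orjson.loads(patch_none).items() if v is not None}
create_op_dict.update(op_dict)

# extra is filled in by the patch; tags survives because None values are dropped
assert _orjson.loads(_orjson.dumps(create_op_dict)) == {
    "name": "my-run",
    "extra": {"tokens": 12},
    "tags": ["a"],
}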
84 changes: 84 additions & 0 deletions python/langsmith/_internal/_orjson.py
@@ -0,0 +1,84 @@
"""Stubs for orjson operations, compatible with PyPy via a json fallback."""

try:
from orjson import (
OPT_NON_STR_KEYS,
OPT_SERIALIZE_DATACLASS,
OPT_SERIALIZE_NUMPY,
OPT_SERIALIZE_UUID,
Fragment,
JSONDecodeError,
dumps,
loads,
)

except ImportError:
import dataclasses
import json
import uuid
from typing import Any, Callable, Optional

OPT_NON_STR_KEYS = 1
OPT_SERIALIZE_DATACLASS = 2
OPT_SERIALIZE_NUMPY = 4
OPT_SERIALIZE_UUID = 8

class Fragment: # type: ignore
def __init__(self, payloadb: bytes):
self.payloadb = payloadb

from json import JSONDecodeError # type: ignore

def dumps( # type: ignore
obj: Any,
/,
default: Optional[Callable[[Any], Any]] = None,
option: int = 0,
) -> bytes: # type: ignore
# for now, don't do anything for this case because `json.dumps`
# automatically encodes non-str keys as str by default, unlike orjson
# enable_non_str_keys = bool(option & OPT_NON_STR_KEYS)

enable_serialize_numpy = bool(option & OPT_SERIALIZE_NUMPY)
enable_serialize_dataclass = bool(option & OPT_SERIALIZE_DATACLASS)
enable_serialize_uuid = bool(option & OPT_SERIALIZE_UUID)

class CustomEncoder(json.JSONEncoder): # type: ignore
def encode(self, o: Any) -> str:
if isinstance(o, Fragment):
return o.payloadb.decode("utf-8") # type: ignore
return super().encode(o)

def default(self, o: Any) -> Any:
if enable_serialize_uuid and isinstance(o, uuid.UUID):
return str(o)
if enable_serialize_numpy and hasattr(o, "tolist"):
# even objects like np.uint16(15) have a .tolist() function
return o.tolist()
if (
enable_serialize_dataclass
and dataclasses.is_dataclass(o)
and not isinstance(o, type)
):
return dataclasses.asdict(o)
if default is not None:
return default(o)

return super().default(o)

return json.dumps(obj, cls=CustomEncoder).encode("utf-8")

def loads(payload: bytes, /) -> Any: # type: ignore
return json.loads(payload)


__all__ = [
"loads",
"dumps",
"Fragment",
"JSONDecodeError",
"OPT_SERIALIZE_NUMPY",
"OPT_SERIALIZE_DATACLASS",
"OPT_SERIALIZE_UUID",
"OPT_NON_STR_KEYS",
]
18 changes: 9 additions & 9 deletions python/langsmith/_internal/_serde.py
@@ -12,7 +12,7 @@
 import uuid
 from typing import Any
 
-import orjson
+from langsmith._internal import _orjson
 
 try:
     from zoneinfo import ZoneInfo  # type: ignore[import-not-found]
@@ -133,13 +133,13 @@ def dumps_json(obj: Any) -> bytes:
         The JSON formatted string.
     """
     try:
-        return orjson.dumps(
+        return _orjson.dumps(
             obj,
             default=_serialize_json,
-            option=orjson.OPT_SERIALIZE_NUMPY
-            | orjson.OPT_SERIALIZE_DATACLASS
-            | orjson.OPT_SERIALIZE_UUID
-            | orjson.OPT_NON_STR_KEYS,
+            option=_orjson.OPT_SERIALIZE_NUMPY
+            | _orjson.OPT_SERIALIZE_DATACLASS
+            | _orjson.OPT_SERIALIZE_UUID
+            | _orjson.OPT_NON_STR_KEYS,
         )
     except TypeError as e:
         # Usually caused by UTF surrogate characters
@@ -150,9 +150,9 @@
             ensure_ascii=True,
         ).encode("utf-8")
         try:
-            result = orjson.dumps(
-                orjson.loads(result.decode("utf-8", errors="surrogateescape"))
+            result = _orjson.dumps(
+                _orjson.loads(result.decode("utf-8", errors="surrogateescape"))
             )
-        except orjson.JSONDecodeError:
+        except _orjson.JSONDecodeError:
             result = _elide_surrogates(result)
         return result
4 changes: 2 additions & 2 deletions python/langsmith/_testing.py
@@ -12,7 +12,6 @@
 from pathlib import Path
 from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload
 
-import orjson
 from typing_extensions import TypedDict
 
 from langsmith import client as ls_client
@@ -21,6 +20,7 @@
 from langsmith import run_trees as rt
 from langsmith import schemas as ls_schemas
 from langsmith import utils as ls_utils
+from langsmith._internal import _orjson
 
 try:
     import pytest  # type: ignore
@@ -374,7 +374,7 @@ def _serde_example_values(values: VT) -> VT:
     if values is None:
         return values
     bts = ls_client._dumps_json(values)
-    return orjson.loads(bts)
+    return _orjson.loads(bts)
 
 
 class _LangSmithTestSuite:
31 changes: 19 additions & 12 deletions python/langsmith/client.py
@@ -55,7 +55,6 @@
 )
 from urllib import parse as urllib_parse
 
-import orjson
 import requests
 from requests import adapters as requests_adapters
 from requests_toolbelt import (  # type: ignore[import-untyped]
@@ -69,6 +68,7 @@
 from langsmith import env as ls_env
 from langsmith import schemas as ls_schemas
 from langsmith import utils as ls_utils
+from langsmith._internal import _orjson
 from langsmith._internal._background_thread import (
     TracingQueueItem,
 )
@@ -368,6 +368,7 @@ class Client:
         "_info",
         "_write_api_urls",
         "_settings",
+        "_manual_cleanup",
     ]
 
     def __init__(
@@ -516,6 +517,8 @@ def __init__(
 
         self._settings: Union[ls_schemas.LangSmithSettings, None] = None
 
+        self._manual_cleanup = False
+
     def _repr_html_(self) -> str:
         """Return an HTML representation of the instance with a link to the URL.
 
@@ -1252,7 +1255,7 @@ def _hide_run_inputs(self, inputs: dict):
         if self._hide_inputs is True:
            return {}
         if self._anonymizer:
-            json_inputs = orjson.loads(_dumps_json(inputs))
+            json_inputs = _orjson.loads(_dumps_json(inputs))
             return self._anonymizer(json_inputs)
         if self._hide_inputs is False:
             return inputs
@@ -1262,7 +1265,7 @@ def _hide_run_outputs(self, outputs: dict):
         if self._hide_outputs is True:
             return {}
         if self._anonymizer:
-            json_outputs = orjson.loads(_dumps_json(outputs))
+            json_outputs = _orjson.loads(_dumps_json(outputs))
             return self._anonymizer(json_outputs)
         if self._hide_outputs is False:
             return outputs
@@ -1282,20 +1285,20 @@ def _batch_ingest_run_ops(
         # form the partial body and ids
         for op in ops:
             if isinstance(op, SerializedRunOperation):
-                curr_dict = orjson.loads(op._none)
+                curr_dict = _orjson.loads(op._none)
                 if op.inputs:
-                    curr_dict["inputs"] = orjson.Fragment(op.inputs)
+                    curr_dict["inputs"] = _orjson.Fragment(op.inputs)
                 if op.outputs:
-                    curr_dict["outputs"] = orjson.Fragment(op.outputs)
+                    curr_dict["outputs"] = _orjson.Fragment(op.outputs)
                 if op.events:
-                    curr_dict["events"] = orjson.Fragment(op.events)
+                    curr_dict["events"] = _orjson.Fragment(op.events)
                 if op.attachments:
                     logger.warning(
                         "Attachments are not supported when use_multipart_endpoint "
                         "is False"
                     )
                 ids_and_partial_body[op.operation].append(
-                    (f"trace={op.trace_id},id={op.id}", orjson.dumps(curr_dict))
+                    (f"trace={op.trace_id},id={op.id}", _orjson.dumps(curr_dict))
                 )
             elif isinstance(op, SerializedFeedbackOperation):
                 logger.warning(
@@ -1321,20 +1324,20 @@
                 and body_size + len(body_deque[0][1]) > size_limit_bytes
             ):
                 self._post_batch_ingest_runs(
-                    orjson.dumps(body_chunks),
+                    _orjson.dumps(body_chunks),
                     _context=f"\n{key}: {'; '.join(context_ids[key])}",
                 )
                 body_size = 0
                 body_chunks.clear()
                 context_ids.clear()
             curr_id, curr_body = body_deque.popleft()
             body_size += len(curr_body)
-            body_chunks[key].append(orjson.Fragment(curr_body))
+            body_chunks[key].append(_orjson.Fragment(curr_body))
             context_ids[key].append(curr_id)
         if body_size:
             context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items())
             self._post_batch_ingest_runs(
-                orjson.dumps(body_chunks), _context="\n" + context
+                _orjson.dumps(body_chunks), _context="\n" + context
             )
 
     def batch_ingest_runs(
@@ -2759,7 +2762,7 @@ def create_dataset(
             "POST",
             "/datasets",
             headers={**self._headers, "Content-Type": "application/json"},
-            data=orjson.dumps(dataset),
+            data=_orjson.dumps(dataset),
         )
         ls_utils.raise_for_status_with_text(response)
 
@@ -5675,6 +5678,10 @@ def push_prompt(
         )
         return url
 
+    def cleanup(self) -> None:
+        """Manually trigger cleanup of the background thread."""
+        self._manual_cleanup = True
+
 
 def convert_prompt_to_openai_format(
     messages: Any,
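
Intended usage of the new hook is minimal; a sketch (assuming a default Client; the flag is observed by keep_thread_active in the first file of this diff):

from langsmith import Client

client = Client()
# ... trace some runs ...
client.cleanup()  # the background tracing thread exits its loop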
4 changes: 2 additions & 2 deletions python/poetry.lock


2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -31,7 +31,7 @@ pydantic = [
     { version = "^2.7.4", python = ">=3.12.4" },
 ]
 requests = "^2"
-orjson = "^3.9.14"
+orjson = { version = "^3.9.14", markers = "platform_python_implementation != 'PyPy'" }
 httpx = ">=0.23.0,<1"
 requests-toolbelt = "^1.0.0"
 
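
With the marker in place, installers that honor PEP 508 environment markers simply skip orjson on PyPy, and the shim's import-or-fallback can be observed directly. A sketch:

import platform

try:
    import orjson  # present on CPython installs, skipped on PyPy by the marker
    backend = "orjson"
except ImportError:
    backend = "stdlib json fallback"

print(platform.python_implementation(), "->", backend)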
8 changes: 4 additions & 4 deletions python/tests/unit_tests/test_client.py
@@ -22,7 +22,6 @@
 from unittest.mock import MagicMock, patch
 
 import dataclasses_json
-import orjson
 import pytest
 import requests
 from multipart import MultipartParser, MultipartPart, parse_options_header
@@ -33,6 +32,7 @@
 import langsmith.utils as ls_utils
 from langsmith import AsyncClient, EvaluationResult, run_trees
 from langsmith import schemas as ls_schemas
+from langsmith._internal import _orjson
 from langsmith._internal._serde import _serialize_json
 from langsmith.client import (
     Client,
@@ -848,7 +848,7 @@ class MyNamedTuple(NamedTuple):
         "set_with_class": set([MyClass(1)]),
         "my_mock": MagicMock(text="Hello, world"),
     }
-    res = orjson.loads(_dumps_json(to_serialize))
+    res = _orjson.loads(_dumps_json(to_serialize))
     assert (
         "model_dump" not in caplog.text
     ), f"Unexpected error logs were emitted: {caplog.text}"
@@ -898,7 +898,7 @@ def __repr__(self) -> str:
     my_cyclic = CyclicClass(other=CyclicClass(other=None))
     my_cyclic.other.other = my_cyclic  # type: ignore
 
-    res = orjson.loads(_dumps_json({"cyclic": my_cyclic}))
+    res = _orjson.loads(_dumps_json({"cyclic": my_cyclic}))
     assert res == {"cyclic": "my_cycles..."}
     expected = {"foo": "foo", "bar": 1}
 
@@ -1142,7 +1142,7 @@ def test_batch_ingest_run_splits_large_batches(
         op
         for call in mock_session.request.call_args_list
         for reqs in (
-            orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
+            _orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
         )
         for op in reqs
     ]