"""A Tracer implementation that records to LangChain endpoint."""
from __future__ import annotations
import copy
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID
from langsmith import Client
from langsmith import run_trees as rt
from langsmith import utils as ls_utils
from pydantic import PydanticDeprecationWarning
from tenacity import (
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
logger = logging.getLogger(__name__)
_LOGGED = set()
_EXECUTOR: Optional[ThreadPoolExecutor] = None
def log_error_once(method: str, exception: Exception) -> None:
    """Log an error once.

    Args:
        method: The method that raised the exception.
        exception: The exception that was raised.
    """
    global _LOGGED
    if (method, type(exception)) in _LOGGED:
        return
    _LOGGED.add((method, type(exception)))
    logger.error(exception)

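# Note: de-duplication is keyed on the (method, exception type) pair, so
# repeated failures of the same kind produce a single log line rather than
# one per attempt.
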
def wait_for_all_tracers() -> None:
    """Wait for all tracers to finish."""
    if rt._CLIENT is not None and rt._CLIENT.tracing_queue is not None:
        rt._CLIENT.tracing_queue.join()

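# Flush example (a hedged sketch; `chain` and `tracer` are assumed to be
# defined elsewhere, e.g. via the LangChainTracer class below):
#
#     chain.invoke({"question": "..."}, config={"callbacks": [tracer]})
#     wait_for_all_tracers()  # block until the shared tracing queue drains
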
def get_client() -> Client:
    """Get the client."""
    return rt.get_cached_client()


def _get_executor() -> ThreadPoolExecutor:
    """Get the executor."""
    global _EXECUTOR
    if _EXECUTOR is None:
        _EXECUTOR = ThreadPoolExecutor()
    return _EXECUTOR

def _run_to_dict(run: Run) -> dict:
    # TODO: Update once langsmith moves to Pydantic V2 and we can swap run.dict
    # for run.model_dump
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
        return {
            **run.dict(exclude={"child_runs", "inputs", "outputs"}),
            "inputs": run.inputs.copy() if run.inputs is not None else None,
            "outputs": run.outputs.copy() if run.outputs is not None else None,
        }

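# Shape of the resulting payload (a hedged sketch; the exact keys come from
# langsmith's Run schema, not from this module):
#
#     {"id": UUID(...), "name": "ChatOpenAI", "run_type": "llm", ...,
#      "inputs": {...}, "outputs": {...}}
#
# "inputs" and "outputs" are shallow-copied so the serialized payload is
# insulated from later mutation of the run's own dicts.
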
class LangChainTracer(BaseTracer):
    """Implementation of the SharedTracer that POSTs to the LangChain endpoint."""

    run_inline = True

    def __init__(
        self,
        example_id: Optional[Union[UUID, str]] = None,
        project_name: Optional[str] = None,
        client: Optional[Client] = None,
        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the LangChain tracer.

        Args:
            example_id: The example ID.
            project_name: The project name. Defaults to the tracer project.
            client: The client. Defaults to the global client.
            tags: The tags. Defaults to an empty list.
            kwargs: Additional keyword arguments.
        """
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.project_name = project_name or ls_utils.get_tracer_project()
        self.client = client or get_client()
        self.tags = tags or []
        self.latest_run: Optional[Run] = None

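    # Construction example (a hedged sketch; the example ID is illustrative):
    #
    #     from uuid import uuid4
    #
    #     tracer = LangChainTracer(
    #         project_name="my-project",
    #         tags=["prod"],
    #         example_id=str(uuid4()),  # str values are coerced to UUID above
    #     )
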
    def _start_trace(self, run: Run) -> None:
        if self.project_name:
            run.session_name = self.project_name
        if self.tags is not None:
            if run.tags:
                run.tags = sorted(set(run.tags + self.tags))
            else:
                run.tags = self.tags.copy()
        super()._start_trace(run)
        if run._client is None:
            run._client = self.client  # type: ignore

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Run:
        """Start a trace for an LLM run.

        Args:
            serialized: The serialized model.
            messages: The messages.
            run_id: The run ID.
            tags: The tags. Defaults to None.
            parent_run_id: The parent run ID. Defaults to None.
            metadata: The metadata. Defaults to None.
            name: The name. Defaults to None.
            kwargs: Additional keyword arguments.

        Returns:
            Run: The run.
        """
        start_time = datetime.now(timezone.utc)
        if metadata:
            kwargs.update({"metadata": metadata})
        chat_model_run = Run(
            id=run_id,
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
            extra=kwargs,
            events=[{"name": "start", "time": start_time}],
            start_time=start_time,
            run_type="llm",
            tags=tags,
            name=name,  # type: ignore[arg-type]
        )
        self._start_trace(chat_model_run)
        self._on_chat_model_start(chat_model_run)
        return chat_model_run

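    # Direct-call example (a hedged sketch; in practice a callback manager
    # invokes this, and HumanMessage is imported purely for illustration):
    #
    #     from uuid import uuid4
    #     from langchain_core.messages import HumanMessage
    #
    #     run = tracer.on_chat_model_start(
    #         serialized={},
    #         messages=[[HumanMessage(content="hello")]],
    #         run_id=uuid4(),
    #     )
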
    def _persist_run(self, run: Run) -> None:
        # TODO: Update once langsmith moves to Pydantic V2 and we can swap
        # run.copy for run.model_copy
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
            run_ = copy.copy(run)
        run_.reference_example_id = self.example_id
        self.latest_run = run_

    def get_run_url(self) -> str:
        """Get the LangSmith root run URL.

        Returns:
            str: The LangSmith root run URL.

        Raises:
            ValueError: If no traced run is found.
            ValueError: If the run URL cannot be found.
        """
        if not self.latest_run:
            msg = "No traced run found."
            raise ValueError(msg)
        # If this is the first run in a project, the project may not yet be created.
        # This method is only really useful for debugging flows, so we will assume
        # there is some tolerance for latency.
        for attempt in Retrying(
            stop=stop_after_attempt(5),
            wait=wait_exponential_jitter(),
            retry=retry_if_exception_type(ls_utils.LangSmithError),
        ):
            with attempt:
                return self.client.get_run_url(
                    run=self.latest_run, project_name=self.project_name
                )
        msg = "Failed to get run URL."
        raise ValueError(msg)

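    # Example (a hedged sketch; requires that this tracer instance has already
    # recorded at least one run):
    #
    #     url = tracer.get_run_url()  # retried while the project is being created
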
    def _get_tags(self, run: Run) -> list[str]:
        """Get combined tags for a run."""
        tags = set(run.tags or [])
        tags.update(self.tags or [])
        return list(tags)

    def _persist_run_single(self, run: Run) -> None:
        """Persist a run."""
        run_dict = _run_to_dict(run)
        run_dict["tags"] = self._get_tags(run)
        extra = run_dict.get("extra", {})
        extra["runtime"] = get_runtime_environment()
        run_dict["extra"] = extra
        try:
            self.client.create_run(**run_dict, project_name=self.project_name)
        except Exception as e:
            # Errors are swallowed by the thread executor so we need to log them here
            log_error_once("post", e)
            raise

    def _update_run_single(self, run: Run) -> None:
        """Update a run."""
        try:
            run_dict = _run_to_dict(run)
            run_dict["tags"] = self._get_tags(run)
            self.client.update_run(run.id, **run_dict)
        except Exception as e:
            # Errors are swallowed by the thread executor so we need to log them here
            log_error_once("patch", e)
            raise

    def _on_llm_start(self, run: Run) -> None:
        """Persist an LLM run."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._persist_run_single(run)

    def _llm_run_with_token_event(
        self,
        token: str,
        run_id: UUID,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Run:
        """Append token event to LLM run and return the run."""
        return super()._llm_run_with_token_event(
            # Drop the chunk; we don't need to save it
            token,
            run_id,
            chunk=None,
            parent_run_id=parent_run_id,
            **kwargs,
        )

    def _on_chat_model_start(self, run: Run) -> None:
        """Persist an LLM run."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._persist_run_single(run)

    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""
        self._update_run_single(run)

    def _on_llm_error(self, run: Run) -> None:
        """Process the LLM Run upon error."""
        self._update_run_single(run)

    def _on_chain_start(self, run: Run) -> None:
        """Process the Chain Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._persist_run_single(run)

    def _on_chain_end(self, run: Run) -> None:
        """Process the Chain Run."""
        self._update_run_single(run)

    def _on_chain_error(self, run: Run) -> None:
        """Process the Chain Run upon error."""
        self._update_run_single(run)

    def _on_tool_start(self, run: Run) -> None:
        """Process the Tool Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._persist_run_single(run)

    def _on_tool_end(self, run: Run) -> None:
        """Process the Tool Run."""
        self._update_run_single(run)

    def _on_tool_error(self, run: Run) -> None:
        """Process the Tool Run upon error."""
        self._update_run_single(run)

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""
        if run.parent_run_id is None:
            run.reference_example_id = self.example_id
        self._persist_run_single(run)

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""
        self._update_run_single(run)

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
        self._update_run_single(run)

    def wait_for_futures(self) -> None:
        """Wait for all queued tracing tasks to complete."""
        if self.client is not None and self.client.tracing_queue is not None:
            self.client.tracing_queue.join()