feat: Add get_chain_root_span utility for langchain instrumentation (
anticorrelator authored Oct 17, 2024
1 parent ac09490 commit 4337aa1
Showing 3 changed files with 146 additions and 7 deletions.
@@ -1,5 +1,5 @@
import logging
from typing import TYPE_CHECKING, Any, Callable, Collection, Optional
from typing import TYPE_CHECKING, Any, Callable, Collection, List, Optional
from uuid import UUID

from opentelemetry import trace as trace_api
@@ -64,6 +64,28 @@ def _uninstrument(self, **kwargs: Any) -> None:
def get_span(self, run_id: UUID) -> Optional[Span]:
return self._tracer.get_span(run_id) if self._tracer else None

def get_ancestors(self, run_id: UUID) -> List[Span]:
ancestors: List[Span] = []
tracer = self._tracer
assert tracer

run = tracer.run_map.get(str(run_id))
if not run:
return ancestors

ancestor_run_id = run.parent_run_id # start with the first ancestor

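        # Walk up the run tree, collecting the instrumentation span (when one exists)
        # for each ancestor run until the root is reached.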
while ancestor_run_id:
span = self.get_span(ancestor_run_id)
if span:
ancestors.append(span)

run = tracer.run_map.get(str(ancestor_run_id))
if not run:
break
ancestor_run_id = run.parent_run_id
return ancestors


class _BaseCallbackManagerInit:
__slots__ = ("_tracer",)
@@ -104,3 +126,27 @@ def get_current_span() -> Optional[Span]:
if not run_id:
return None
return LangChainInstrumentor().get_span(run_id)


def get_ancestor_spans() -> List[Span]:
"""
Retrieve the ancestor spans for the current LangChain run.
This function traverses the LangChain run tree from the current run's parent up to the root,
collecting the spans associated with each ancestor run. The list is ordered from the immediate
parent of the current run to the root of the run tree.
"""
import langchain_core

run_id: Optional[UUID] = None
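    # The callback manager stored in LangChain's child runnable config carries the id of
    # the run that is currently executing (exposed as its parent_run_id).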
config = langchain_core.runnables.config.var_child_runnable_config.get()
if not isinstance(config, dict):
        return []
for v in config.values():
if not isinstance(v, langchain_core.callbacks.BaseCallbackManager):
continue
if run_id := v.parent_run_id:
break
if not run_id:
return []
return LangChainInstrumentor().get_ancestors(run_id)
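
For context, a minimal usage sketch (not part of this diff): with the LangChain instrumentation active, get_current_span and get_ancestor_spans can be called from inside a runnable to inspect the span lineage of the current run. The function name inspect_lineage and the two-step chain below are illustrative assumptions, not code from this commit.

from langchain_core.runnables import RunnableLambda

from openinference.instrumentation.langchain import (
    LangChainInstrumentor,
    get_ancestor_spans,
    get_current_span,
)

LangChainInstrumentor().instrument()  # assumes a tracer provider has already been configured


def inspect_lineage(x: int) -> int:
    current = get_current_span()  # span for this RunnableLambda's own run
    for ancestor in get_ancestor_spans():  # ordered from immediate parent to root
        # SDK spans expose a name attribute even though the API Span type does not declare it.
        print(getattr(ancestor, "name", None), ancestor is current)
    return x + 1


chain = RunnableLambda(inspect_lineage) | RunnableLambda(inspect_lineage)
chain.invoke(1)  # each lambda sees the enclosing RunnableSequence span as its only ancestor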
@@ -35,6 +35,7 @@
from opentelemetry import trace as trace_api
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span
from opentelemetry.util.types import AttributeValue
from wrapt import ObjectProxy

@@ -114,10 +115,10 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:
assert self.run_map
self.run_map = _DictWithLock[str, Run](self.run_map)
self._tracer = tracer
self._spans_by_run: Dict[UUID, trace_api.Span] = _DictWithLock[UUID, trace_api.Span]()
self._spans_by_run: Dict[UUID, Span] = _DictWithLock[UUID, Span]()
self._lock = RLock() # handlers may be run in a thread by langchain

def get_span(self, run_id: UUID) -> Optional[trace_api.Span]:
def get_span(self, run_id: UUID) -> Optional[Span]:
return self._spans_by_run.get(run_id)

@audit_timing # type: ignore
@@ -140,6 +141,7 @@ def _start_trace(self, run: Run) -> None:
context=parent_context,
start_time=start_time_utc_nano,
)

# The following line of code is commented out to serve as a reminder that in a system
# of callbacks, attaching the context can be hazardous because there is no guarantee
# that the context will be detached. An error could happen between callbacks leaving
@@ -207,7 +209,7 @@ def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:


@audit_timing # type: ignore
def _record_exception(span: trace_api.Span, error: BaseException) -> None:
def _record_exception(span: Span, error: BaseException) -> None:
if isinstance(error, Exception):
span.record_exception(error)
return
@@ -229,7 +231,7 @@ def _record_exception(span: trace_api.Span, error: BaseException) -> None:


@audit_timing # type: ignore
def _update_span(span: trace_api.Span, run: Run) -> None:
def _update_span(span: Span, run: Run) -> None:
if run.error is None:
span.set_status(trace_api.StatusCode.OK)
else:
@@ -29,7 +29,7 @@
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.retrievers import KNNRetriever
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableLambda, RunnableSerializable
from langchain_openai import ChatOpenAI
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan
@@ -39,7 +39,10 @@
from respx import MockRouter

from openinference.instrumentation import using_attributes
from openinference.instrumentation.langchain import get_current_span
from openinference.instrumentation.langchain import (
get_ancestor_spans,
get_current_span,
)
from openinference.semconv.trace import (
DocumentAttributes,
EmbeddingAttributes,
@@ -92,6 +95,94 @@ async def f(_: Any) -> Optional[Span]:
}


async def test_get_ancestor_spans(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
n = 10 # Number of concurrent runs
loop = asyncio.get_running_loop()

ancestors_during_execution = []

def f(x: int) -> int:
current_span = get_current_span()
root_spans = get_ancestor_spans()
        assert root_spans is not None, "Ancestors should not be None during execution"
        assert len(root_spans) == 1, "Exactly one ancestor span is expected"
assert current_span is not root_spans[0], "Ancestor is distinct from the current span"
ancestors_during_execution.append(root_spans[0])
assert (
root_spans[0].name == "RunnableSequence" # type: ignore[attr-defined, unused-ignore]
), "RunnableSequence should be the outermost ancestor"
return x + 1

sequence: RunnableSerializable[int, int] = RunnableLambda[int, int](f) | RunnableLambda[
int, int
](f)

with ThreadPoolExecutor() as executor:
tasks = [loop.run_in_executor(executor, sequence.invoke, 1) for _ in range(n)]
await asyncio.gather(*tasks)

ancestors_after_execution = get_ancestor_spans()
assert ancestors_after_execution == [], "No ancestors after execution"

assert (
len(ancestors_during_execution) == 2 * n
), "Did not capture all ancestors during execution"

assert (
len(set(id(span) for span in ancestors_during_execution)) == n
), "Both Lambdas share the same ancestor"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 3 * n, f"Expected {3 * n} spans, but found {len(spans)}"


async def test_get_ancestor_spans_async(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
"""Test retrieving the current chain root span during RunnableLambda execution."""
if sys.version_info < (3, 11):
pytest.xfail("Async test may fail on Python versions below 3.11")
n = 10 # Number of concurrent runs

ancestors_during_execution = []

async def f(x: int) -> int:
current_span = get_current_span()
root_spans = get_ancestor_spans()
        assert root_spans is not None, "Ancestors should not be None during execution (async)"
        assert len(root_spans) == 1, "Exactly one ancestor span is expected"
assert current_span is not root_spans[0], "Ancestor is distinct from the current span"
ancestors_during_execution.append(root_spans[0])
assert (
root_spans[0].name == "RunnableSequence" # type: ignore[attr-defined, unused-ignore]
), "RunnableSequence should be the outermost ancestor"
await asyncio.sleep(0.01)
return x + 1

sequence: RunnableSerializable[int, int] = RunnableLambda[int, int](f) | RunnableLambda[
int, int
](f)

await asyncio.gather(*(sequence.ainvoke(1) for _ in range(n)))

ancestors_after_execution = get_ancestor_spans()
assert ancestors_after_execution == [], "No ancestors after execution"

assert (
len(ancestors_during_execution) == 2 * n
), "Did not capture all ancestors during execution"

assert (
len(set(id(span) for span in ancestors_during_execution)) == n
), "Both Lambdas share the same ancestor"

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 3 * n, f"Expected {3 * n} spans, but found {len(spans)}"


@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize("is_stream", [False, True])
@pytest.mark.parametrize("status_code", [200, 400])
