feat(llmobs): submit spans for streamed calls #10908

Open · wants to merge 3 commits into base: main
29 changes: 21 additions & 8 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -954,8 +954,6 @@ def _on_span_started(span: Span):
span.set_tag_str("langchain.request.inputs.%d.%s" % (idx, k), integration.trunc(str(v)))

def _on_span_finished(span: Span, streamed_chunks):
if span.error or not integration.is_pc_sampled_span(span):
return
if (
streamed_chunks
and langchain_core
@@ -970,6 +968,9 @@ def _on_span_finished(span: Span, streamed_chunks):
else:
# best effort to join chunks together
content = "".join([str(chunk) for chunk in streamed_chunks])
integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=content, operation="chain")
if span.error or not integration.is_pc_sampled_span(span):
Contributor: Do we want to set the LLM Obs tags before we do the integration.is_pc_sampled_span check?

Contributor (author): Good question. It looks like we recently changed llmobs_set_tags to run independently of is_pc_sampled_span (example). I followed the same logic here, but I think even before that we did something like is_pc_sampled_llmobs, which I believe always returned True.

Contributor: Another clarification question about this: do we move the span.error check to after setting tags because the span may contain an error that happened mid-stream, but we still want to capture the outputs that were returned?

Contributor (author): We re-check span.error in llmobs_set_tags for llm, chat, and chain; I kept that check in place before adding the output messages/value properties.
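
To illustrate the ordering discussed in this thread, here is a minimal, self-contained sketch (with hypothetical _Span and _Integration stand-ins and an illustrative "output.value" key, not the actual ddtrace code): the LLM Observability tags are set on every finished streamed span, the APM langchain.* tags stay gated on span error and prompt-completion sampling, and span.error is re-checked inside the LLM Observability tagging so an errored stream still reports an empty output.

class _Span:
    """Hypothetical stand-in for ddtrace's Span; only what the sketch needs."""

    def __init__(self, error=False):
        self.error = error
        self.tags = {}

    def set_tag_str(self, key, value):
        self.tags[key] = value


class _Integration:
    """Hypothetical stand-in for the LangChain integration object."""

    def is_pc_sampled_span(self, span):
        return True  # prompt-completion sampling decision for APM tags

    def trunc(self, text):
        return text[:128]

    def llmobs_set_tags(self, span, response, operation):
        # span.error is re-checked here: an errored stream still submits an
        # LLM Observability event, just with empty output content.
        output = "" if span.error else response
        span.set_tag_str("output.value", output)  # illustrative key, not the real constant


def on_span_finished(integration, span, streamed_chunks):
    content = "".join(str(chunk) for chunk in streamed_chunks)
    # 1. Always set LLM Observability tags, independent of error/sampling.
    integration.llmobs_set_tags(span, response=content, operation="llm")
    # 2. Gate only the APM langchain.* tags on error and sampling.
    if span.error or not integration.is_pc_sampled_span(span):
        return
    span.set_tag_str("langchain.response.content", integration.trunc(content))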

return
span.set_tag_str("langchain.response.outputs", integration.trunc(content))

return shared_stream(
@@ -989,6 +990,7 @@ def _on_span_finished(span: Span, streamed_chunks):
def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration
llm_provider = instance._llm_type
model = _extract_model_name(instance)

def _on_span_started(span: Span):
if not integration.is_pc_sampled_span(span):
@@ -1004,12 +1006,19 @@ def _on_span_started(span: Span):
span.set_tag_str("langchain.request.%s.parameters.%s.%s" % (llm_provider, param, k), str(v))

def _on_span_finished(span: Span, streamed_chunks):
if span.error or not integration.is_pc_sampled_span(span):
joined_chunks = streamed_chunks[0]
for chunk in streamed_chunks[1:]:
joined_chunks += chunk # base message types support __add__ for concatenation
integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=joined_chunks, operation="chat")
if (
span.error
or not integration.is_pc_sampled_span(span)
or streamed_chunks is None
or len(streamed_chunks) == 0
):
return
content = "".join([str(getattr(chunk, "content", chunk)) for chunk in streamed_chunks])
role = (
streamed_chunks[0].__class__.__name__.replace("Chunk", "") if streamed_chunks else None
) # AIMessageChunk --> AIMessage
content = str(getattr(joined_chunks, "content", joined_chunks))
role = joined_chunks.__class__.__name__.replace("Chunk", "") # AIMessageChunk --> AIMessage
span.set_tag_str("langchain.response.content", integration.trunc(content))
if role:
span.set_tag_str("langchain.response.message_type", role)
@@ -1032,13 +1041,15 @@ def _on_span_finished(span: Span, streamed_chunks):
on_span_finished=_on_span_finished,
api_key=_extract_api_key(instance),
provider=llm_provider,
model=model,
)


@with_traced_module
def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration
llm_provider = instance._llm_type
model = _extract_model_name(instance)

def _on_span_start(span: Span):
if not integration.is_pc_sampled_span(span):
@@ -1053,9 +1064,10 @@ def _on_span_start(span: Span):
span.set_tag_str("langchain.request.%s.parameters.%s.%s" % (llm_provider, param, k), str(v))

def _on_span_finished(span: Span, streamed_chunks):
content = "".join([str(chunk) for chunk in streamed_chunks])
integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=content, operation="llm")
if span.error or not integration.is_pc_sampled_span(span):
return
content = "".join([str(chunk) for chunk in streamed_chunks])
span.set_tag_str("langchain.response.content", integration.trunc(content))

return shared_stream(
@@ -1070,6 +1082,7 @@ def _on_span_finished(span: Span, streamed_chunks):
on_span_finished=_on_span_finished,
api_key=_extract_api_key(instance),
provider=llm_provider,
model=model,
)


3 changes: 3 additions & 0 deletions ddtrace/contrib/internal/langchain/utils.py
@@ -34,6 +34,7 @@ def __next__(self):
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_integration.metric(self._dd_span, "incr", "request.error", 1)
self._dd_span.finish()
raise


@@ -60,6 +61,7 @@ async def __anext__(self):
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
self._dd_integration.metric(self._dd_span, "incr", "request.error", 1)
self._dd_span.finish()
raise


@@ -79,6 +81,7 @@ def shared_stream(
"pin": pin,
"operation_id": f"{instance.__module__}.{instance.__class__.__name__}",
"interface_type": interface_type,
"submit_to_llmobs": True,
}

options.update(extra_options)
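For context on the utils.py change above (the added self._dd_span.finish() in the exception handlers): finishing the span when the underlying stream raises ensures an error mid-iteration still yields a finished, errored span instead of a leaked open one. A minimal sketch of that wrapper pattern, with hypothetical names (TracedStream, on_finished) rather than the actual ddtrace stream-response classes:

import sys


class TracedStream:
    """Illustrative wrapper around a streamed response (hypothetical names)."""

    def __init__(self, generator, span, on_finished):
        self._generator = generator
        self._span = span
        self._on_finished = on_finished
        self._chunks = []

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self._generator)
            self._chunks.append(chunk)
            return chunk
        except StopIteration:
            # Normal end of stream: tag and finish the span.
            self._on_finished(self._span, self._chunks)
            self._span.finish()
            raise
        except Exception:
            # Error mid-stream: record the exception and still finish the span
            # so the errored streamed call is submitted rather than left open.
            self._span.set_exc_info(*sys.exc_info())
            self._span.finish()
            raise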
78 changes: 66 additions & 12 deletions ddtrace/llmobs/_integrations/langchain.py
@@ -89,7 +89,7 @@ def _llmobs_set_tags(
elif operation == "chat":
self._llmobs_set_meta_tags_from_chat_model(span, args, kwargs, response, is_workflow=is_workflow)
elif operation == "chain":
self._llmobs_set_meta_tags_from_chain(span, inputs=kwargs, outputs=response)
self._llmobs_set_meta_tags_from_chain(span, args, kwargs, outputs=response)
elif operation == "embedding":
self._llmobs_set_meta_tags_from_embedding(span, args, kwargs, response, is_workflow=is_workflow)
elif operation == "retrieval":
@@ -129,16 +129,25 @@ def _llmobs_set_meta_tags_from_llm(

input_tag_key = INPUT_VALUE if is_workflow else INPUT_MESSAGES
output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES
stream = span.get_tag("langchain.request.stream")

prompts = get_argument_value(args, kwargs, 0, "prompts")
prompts = get_argument_value(args, kwargs, 0, "input" if stream else "prompts")
if isinstance(prompts, str) or not isinstance(prompts, list):
prompts = [prompts]

span.set_tag_str(input_tag_key, safe_json([{"content": str(prompt)} for prompt in prompts]))
if stream:
# chat and llm take the same input types for streamed calls
span.set_tag_str(input_tag_key, safe_json(self._handle_stream_input_messages(prompts)))
else:
span.set_tag_str(input_tag_key, safe_json([{"content": str(prompt)} for prompt in prompts]))

if span.error:
span.set_tag_str(output_tag_key, safe_json([{"content": ""}]))
return
message_content = [{"content": completion[0].text} for completion in completions.generations]
if stream:
message_content = [{"content": completions}] # single completion for streams
else:
message_content = [{"content": completion[0].text} for completion in completions.generations]
span.set_tag_str(output_tag_key, safe_json(message_content))

def _llmobs_set_meta_tags_from_chat_model(
@@ -155,20 +164,36 @@ def _llmobs_set_meta_tags_from_chat_model(

input_tag_key = INPUT_VALUE if is_workflow else INPUT_MESSAGES
output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES
stream = span.get_tag("langchain.request.stream")

input_messages = []
chat_messages = get_argument_value(args, kwargs, 0, "messages", optional=True) or []
for message_set in chat_messages:
for message in message_set:
content = message.get("content", "") if isinstance(message, dict) else getattr(message, "content", "")
role = getattr(message, "role", ROLE_MAPPING.get(message.type, ""))
input_messages.append({"content": str(content), "role": str(role)})
if stream:
chat_messages = get_argument_value(args, kwargs, 0, "input")
input_messages = self._handle_stream_input_messages(chat_messages)
else:
chat_messages = get_argument_value(args, kwargs, 0, "messages", optional=True) or []
if not isinstance(chat_messages, list):
chat_messages = [chat_messages]
for message_set in chat_messages:
for message in message_set:
content = (
message.get("content", "") if isinstance(message, dict) else getattr(message, "content", "")
)
role = getattr(message, "role", ROLE_MAPPING.get(message.type, ""))
input_messages.append({"content": str(content), "role": str(role)})
span.set_tag_str(input_tag_key, safe_json(input_messages))

if span.error:
span.set_tag_str(output_tag_key, json.dumps([{"content": ""}]))
return

output_messages = []
if stream:
content = chat_completions.content
role = chat_completions.__class__.__name__.replace("MessageChunk", "").lower() # AIMessageChunk --> ai
span.set_tag_str(output_tag_key, safe_json([{"content": content, "role": ROLE_MAPPING.get(role, "")}]))
return

for message_set in chat_completions.generations:
for chat_completion in message_set:
chat_completion_msg = chat_completion.message
@@ -196,9 +221,38 @@ def _extract_tool_calls(self, chat_completion_msg: Any) -> List[Dict[str, Any]]:
tool_calls_info.append(tool_call_info)
return tool_calls_info

def _llmobs_set_meta_tags_from_chain(self, span: Span, outputs: Any, inputs: Optional[Any] = None) -> None:
span.set_tag_str(SPAN_KIND, "workflow")
def _handle_stream_input_messages(self, inputs):
input_messages = []
if hasattr(inputs, "to_messages"): # isinstance(inputs, langchain_core.prompt_values.PromptValue)
inputs = inputs.to_messages()
elif not isinstance(inputs, list):
inputs = [inputs]
for inp in inputs:
inp_message = {}
content, role = None, None
if isinstance(inp, dict):
content = str(inp.get("content", ""))
role = inp.get("role")
elif hasattr(inp, "content"): # isinstance(inp, langchain_core.messages.BaseMessage)
content = str(inp.content)
role = inp.__class__.__name__
else:
content = str(inp)

inp_message["content"] = content
if role is not None:
inp_message["role"] = role
input_messages.append(inp_message)

return input_messages

def _llmobs_set_meta_tags_from_chain(self, span: Span, args, kwargs, outputs: Any) -> None:
span.set_tag_str(SPAN_KIND, "workflow")
stream = span.get_tag("langchain.request.stream")
if stream:
inputs = get_argument_value(args, kwargs, 0, "input")
else:
inputs = kwargs
if inputs is not None:
formatted_inputs = self.format_io(inputs)
span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs))
@@ -0,0 +1,4 @@
---
features:
- |
LLM Observability: LangChain streamed calls (``llm.stream``, ``chat_model.stream``, and ``chain.stream``) now submit spans to LLM Observability.
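
As a usage sketch of what the release note describes (assuming ddtrace and langchain-openai are installed, an OpenAI API key is configured, and the app/model names are illustrative):

from ddtrace.llmobs import LLMObs
from langchain_openai import ChatOpenAI

# Enable LLM Observability (alternatively set DD_LLMOBS_ENABLED=1 and DD_LLMOBS_ML_APP).
LLMObs.enable(ml_app="my-app")

chat = ChatOpenAI(model="gpt-4o-mini")
for chunk in chat.stream("Write a haiku about tracing"):
    print(chunk.content, end="", flush=True)
# When the stream completes (or errors), a chat span is submitted to LLM Observability.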